123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553 |
- import sys
- import os
- import json
- import cv2
- import numpy as np
- from skimage.measure import find_contours
- from io import BytesIO
- import pydicom
- from PyQt5.QtWidgets import (
- QApplication, QMainWindow, QLabel, QPushButton, QColorDialog, QInputDialog,
- QFileDialog, QMessageBox, QVBoxLayout, QHBoxLayout, QWidget, QFrame, QSizePolicy, QComboBox
- )
- from PyQt5.QtGui import QImage, QPixmap, QPainter, QPen, QColor
- from PyQt5.QtCore import Qt, QPoint, QRect
- from PyQt5.QtWidgets import QSplitter, QGroupBox, QVBoxLayout
- from PyQt5.QtWidgets import QSlider
class ImageLabel(QLabel):
    """QLabel that shows one DICOM slice at a time and lets the user annotate it.

    ROIs are freehand polygons kept per slice index in ``self.rois`` (each ROI is a
    dict with 'label', 'color' and 'points'). The small square ``mark_area_rect``
    in the top-left corner acts as a button that auto-marks a contiguous range of
    slices (see ``mark_range_from_current``). ROI points and the mark square are
    stored in unzoomed widget coordinates; ``paintEvent`` applies the zoom.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setMouseTracking(True)
        # Accept keyboard focus so the Ctrl+Z shortcut in keyPressEvent is delivered.
        self.setFocusPolicy(Qt.StrongFocus)
        self.image = QImage()
        self.rois = {}              # slice index -> list of ROI dicts
        self.current_roi = []       # vertices of the polygon currently being drawn
        self.drawing = False
        self.current_color = QColor(Qt.red)
        self.current_label = "ROI"
        self.undo_stack = []        # (slice index, ROI dict) in creation order
        self.slice_index = 0
        self.dicom_files = []
        self.dataset = None
        self.mark_area_rect = QRect(10, 10, 30, 30)
        self.marked_by_rule_slices = set()
        # Zoom parameters
        self.zoom_factor = 1.0
        self.zoom_step = 0.1
        self.min_zoom = 0.1
        self.max_zoom = 5.0

    def load_dicom_series(self, folder_path):
        """Use every regular file in *folder_path* as a slice and display the first one.

        Files are sorted by path so slice navigation is deterministic
        (``os.listdir`` order is arbitrary). Sub-directories are skipped.

        Raises:
            ValueError: if the folder contains no files.
        """
        candidate_paths = (os.path.join(folder_path, name) for name in os.listdir(folder_path))
        self.dicom_files = sorted(path for path in candidate_paths if os.path.isfile(path))
        if not self.dicom_files:
            raise ValueError("No DICOM files"
                             " found in the selected folder.")
        # Reset all per-series annotation state so nothing leaks from a previous series.
        self.slice_index = 0
        self.rois.clear()
        self.undo_stack.clear()
        self.marked_by_rule_slices.clear()
        self.current_roi = []
        self.drawing = False
        self.load_slice()

    def load_slice(self):
        """Read the file at ``slice_index``, min-max normalise it to 8 bits and show it.

        Does nothing when ``slice_index`` is out of range.
        """
        if not 0 <= self.slice_index < len(self.dicom_files):
            return
        path = self.dicom_files[self.slice_index]
        self.dataset = pydicom.dcmread(path, force=True)
        pixel_array = self.dataset.pixel_array
        # NOTE(review): assumes single-frame, single-sample (2-D) pixel data; multi-frame
        # or RGB data would need a frame/channel selection here.
        height, width = pixel_array.shape[:2]
        series = self.dataset.get('SeriesDescription', '')
        print(f"Loaded {path}: {height}x{width} px, series '{series}'")
        normalized = cv2.normalize(pixel_array, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX)
        normalized = np.ascontiguousarray(normalized.astype(np.uint8))
        # Pass the real row stride (QImage otherwise expects 4-byte aligned rows) and
        # copy: QImage does not own the numpy buffer, which is freed when we return.
        self.image = QImage(normalized.data, width, height, normalized.strides[0],
                            QImage.Format_Grayscale8).copy()
        self.setPixmap(QPixmap.fromImage(self.image))
        self.update()

    def next_slice(self):
        """Advance to the next slice, if any."""
        if self.slice_index < len(self.dicom_files) - 1:
            self.slice_index += 1
            self.load_slice()

    def previous_slice(self):
        """Go back to the previous slice, if any."""
        if self.slice_index > 0:
            self.slice_index -= 1
            self.load_slice()

    def set_current_color(self, color):
        """Set the QColor used for ROIs drawn from now on."""
        self.current_color = color

    def set_current_label(self, label):
        """Set the label attached to ROIs drawn from now on."""
        self.current_label = label

    def mousePressEvent(self, event):
        """Left click: toggle-mark via the corner square, otherwise start drawing an ROI."""
        if event.button() != Qt.LeftButton:
            super().mousePressEvent(event)
            return
        scaled_pos = event.pos() / self.zoom_factor
        if self.mark_area_rect.contains(scaled_pos):
            if self.dicom_files and self.slice_index not in self.marked_by_rule_slices:
                self.mark_range_from_current()
            self.update()
            return
        if self.image.isNull():
            return  # nothing to annotate yet
        self.drawing = True
        self.current_roi = [scaled_pos]
        self.update()

    def mouseMoveEvent(self, event):
        """Extend the ROI being drawn (points stored unzoomed, like in mousePressEvent)."""
        if self.drawing:
            self.current_roi.append(event.pos() / self.zoom_factor)
            self.update()

    def mouseReleaseEvent(self, event):
        """Finish the current stroke; keep it as an ROI if it has at least three vertices."""
        if event.button() != Qt.LeftButton or not self.drawing:
            return
        self.drawing = False
        if len(self.current_roi) > 2:
            roi = {
                'label': self.current_label,
                'color': self.current_color,
                'points': self.current_roi,
            }
            self.rois.setdefault(self.slice_index, []).append(roi)
            self.undo_stack.append((self.slice_index, roi))
        self.current_roi = []
        self.update()

    def paintEvent(self, event):
        """Draw the slice, its ROIs, the in-progress stroke and the mark square."""
        super().paintEvent(event)
        if self.image.isNull():
            return
        painter = QPainter(self)
        # Apply the zoom exactly once; everything below is in unzoomed coordinates,
        # which is also what mousePressEvent hit-tests against.
        transform = painter.transform()
        transform.scale(self.zoom_factor, self.zoom_factor)
        painter.setTransform(transform)
        painter.drawImage(self.rect(), self.image, self.image.rect())

        pen = QPen(Qt.red, 2, Qt.SolidLine)
        painter.setPen(pen)
        for roi in self.rois.get(self.slice_index, []):
            pen.setColor(roi['color'])
            painter.setPen(pen)
            painter.drawPolygon(*[QPoint(p.x(), p.y()) for p in roi['points']])
        if len(self.current_roi) > 1:
            pen.setColor(Qt.blue)
            painter.setPen(pen)
            painter.drawPolyline(*[QPoint(p.x(), p.y()) for p in self.current_roi])

        # Mark square: semi-transparent green, with a tick once the slice is marked.
        painter.setPen(QPen(Qt.green, 2))
        painter.setBrush(QColor(0, 255, 0, 100))
        painter.drawRect(self.mark_area_rect)
        if self.slice_index in self.marked_by_rule_slices:
            painter.drawLine(self.mark_area_rect.topLeft() + QPoint(5, 10),
                             self.mark_area_rect.center() + QPoint(0, 8))
            painter.drawLine(self.mark_area_rect.center() + QPoint(0, 8),
                             self.mark_area_rect.topRight() + QPoint(-4, 4))
        painter.end()

    def undo(self):
        """Remove the most recently added ROI (drawn or auto-marked)."""
        if not self.undo_stack:
            return
        slice_idx, last_roi = self.undo_stack.pop()
        slice_rois = self.rois.get(slice_idx)
        if slice_rois is not None and last_roi in slice_rois:
            slice_rois.remove(last_roi)
            if not slice_rois:
                del self.rois[slice_idx]
        # Once its AutoMark ROI is gone the slice is no longer marked, so the tick
        # disappears and the mark square can be used on it again.
        if last_roi['label'] == 'AutoMark' and not any(
                roi['label'] == 'AutoMark' for roi in self.rois.get(slice_idx, [])):
            self.marked_by_rule_slices.discard(slice_idx)
        self.update()

    def keyPressEvent(self, event):
        """Ctrl+Z undoes the last ROI; other keys go to the default handler."""
        if event.key() == Qt.Key_Z and event.modifiers() & Qt.ControlModifier:
            self.undo()
        else:
            super().keyPressEvent(event)

    def save_rois(self, file_path):
        """Write all ROIs to *file_path* as JSON: {slice: [{label, color, points}]}."""
        roi_data = {
            slice_idx: [
                {
                    'label': roi['label'],
                    'color': roi['color'].name(),
                    'points': [(point.x(), point.y()) for point in roi['points']],
                }
                for roi in rois
            ]
            for slice_idx, rois in self.rois.items()
        }
        with open(file_path, 'w', encoding='utf-8') as file:
            json.dump(roi_data, file, indent=4)

    def mark_range_from_current(self):
        """Add an 'AutoMark' ROI to a run of slices ending/starting at the current one.

        In the first half of the series the run is slice 0 .. current; otherwise it
        is current .. last slice. Slices that are already marked are skipped so that
        repeated marking does not stack duplicate AutoMark ROIs.
        """
        total_slices = len(self.dicom_files)
        current_index = self.slice_index
        if current_index <= total_slices // 2:
            indices_to_mark = range(0, current_index + 1)
        else:
            indices_to_mark = range(current_index, total_slices)
        for idx in indices_to_mark:
            if idx in self.marked_by_rule_slices:
                continue
            roi = {
                'label': 'AutoMark',
                'color': QColor(Qt.green),
                'points': [
                    QPoint(10, 10), QPoint(40, 10),
                    QPoint(40, 40), QPoint(10, 40),
                ],
            }
            self.rois.setdefault(idx, []).append(roi)
            self.undo_stack.append((idx, roi))
            self.marked_by_rule_slices.add(idx)
        self.update()

    def _set_zoom(self, value):
        """Set the zoom factor clamped to [min_zoom, max_zoom] and repaint."""
        self.zoom_factor = min(self.max_zoom, max(self.min_zoom, value))
        self.update()

    def wheelEvent(self, event):
        """Ctrl+wheel zooms; plain wheel steps through slices (up = previous)."""
        angle = event.angleDelta().y()
        if angle == 0:
            return  # horizontal-only scroll: neither zoom nor navigate
        if event.modifiers() & Qt.ControlModifier:
            self._set_zoom(self.zoom_factor * (1.1 if angle > 0 else 0.9))
        elif angle > 0:
            self.previous_slice()
        else:
            self.next_slice()

    def zoom_in(self):
        """Increase zoom by ``zoom_step`` without exceeding ``max_zoom``."""
        if self.zoom_factor < self.max_zoom:
            self._set_zoom(self.zoom_factor + self.zoom_step)

    def zoom_out(self):
        """Decrease zoom by ``zoom_step`` without going below ``min_zoom``."""
        if self.zoom_factor > self.min_zoom:
            self._set_zoom(self.zoom_factor - self.zoom_step)
class ROIDrawer(QMainWindow):
    """Main application window.

    Layout, top to bottom: a series dropdown (placeholder entries), two threshold
    sliders, three side-by-side image panes (source slice with ROI drawing,
    filtration overlay, segmentation overlay) and a row of action buttons.
    """

    def __init__(self):
        super().__init__()
        # Slider-backed thresholds as fractions in [0.01, 0.99]; defaults match the sliders.
        self.threshold_brightness = 0.5
        self.contours_thr = 0.3
        self.initUI()

    def initUI(self):
        """Create and lay out all widgets of the main window."""
        self.setWindowTitle("KneeSeg")
        screen_geometry = QApplication.desktop().availableGeometry()
        # Size the window to 90% of the available screen area and keep it fixed.
        self.resize(int(screen_geometry.width() * 0.9), int(screen_geometry.height() * 0.9))
        self.setFixedSize(self.size())

        central_widget = QWidget(self)
        self.setCentralWidget(central_widget)
        main_layout = QVBoxLayout(central_widget)
        main_layout.addLayout(self._build_top_controls())
        main_layout.addLayout(self._build_threshold_controls())
        # The image panes must exist before the buttons, which connect to self.image_label.
        main_layout.addWidget(self._build_image_splitter())
        main_layout.addLayout(self._build_button_row())

    def _build_top_controls(self):
        """Return a row holding the series dropdown, pushed to the right edge."""
        # Dropdown listing the image series (sequences).
        self.sequence_dropdown = QComboBox()
        self.sequence_dropdown.setFixedSize(250, 30)
        self.sequence_dropdown.setStyleSheet("font-size: 14px;")
        # Placeholder entries.
        self.sequence_dropdown.addItems(["Sequence A", "Sequence B", "Sequence C"])
        layout = QHBoxLayout()
        layout.addStretch(1)
        layout.addWidget(self.sequence_dropdown)
        return layout

    @staticmethod
    def _make_threshold_slider(default):
        """Return a horizontal slider over 1..99 (read as 0.01..0.99) set to *default*."""
        slider = QSlider(Qt.Horizontal)
        slider.setMinimum(1)
        slider.setMaximum(99)
        slider.setValue(default)
        slider.setTickInterval(1)
        return slider

    def _build_threshold_controls(self):
        """Create the brightness and contour-area sliders and return their layout."""
        self.brightness_slider = self._make_threshold_slider(50)
        self.contour_slider = self._make_threshold_slider(30)
        # Connected only after both sliders exist: update_thresholds reads both.
        self.brightness_slider.valueChanged.connect(self.update_thresholds)
        self.contour_slider.valueChanged.connect(self.update_thresholds)
        layout = QVBoxLayout()
        # Captions: "brightness threshold" and "area threshold".
        for caption, slider in (("Порог яркости", self.brightness_slider),
                                ("Порог площади", self.contour_slider)):
            label = QLabel(caption)
            label.setStyleSheet("font-size: 14px;font-weight: bold;")
            layout.addWidget(label)
            layout.addWidget(slider)
        return layout

    def _add_image_pane(self, splitter, title):
        """Append a stretchable ImageLabel pane with *title* to *splitter*; return the label."""
        label = ImageLabel(self)
        label.setAlignment(Qt.AlignCenter)
        label.setScaledContents(True)
        label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        splitter.addWidget(self.create_labeled_frame(title, label))
        return label

    def _build_image_splitter(self):
        """Create the three equally sized image panes (source, filtration, segmentation)."""
        splitter = QSplitter(Qt.Horizontal)
        self.image_label = self._add_image_pane(splitter, "")
        self.filtration_label = self._add_image_pane(splitter, " ")
        self.segmentation_label = self._add_image_pane(splitter, "")
        splitter.setSizes([1, 1, 1])
        for index in range(3):
            splitter.setStretchFactor(index, 1)
        return splitter

    def _build_button_row(self):
        """Create the action buttons and return their row layout."""
        buttons = [
            ('Load DICOM ', 150, self.load_dicom_series),
            # NOTE(review): '>>' steps to the PREVIOUS slice and '<<' to the NEXT one,
            # mirroring the wheel (up = previous); confirm this labelling is intended.
            ('>>', 50, self.image_label.previous_slice),
            ('<<', 50, self.image_label.next_slice),
            ('Select Color', 150, self.select_color),
            ('Set Label', 100, self.set_label),
            ('Undo', 100, self.image_label.undo),
            ('Segmentation', 150, self.run_analysis),
            ('Save ROIs', 150, self.save_rois),
        ]
        layout = QHBoxLayout()
        for text, width, slot in buttons:
            button = QPushButton(text, self)
            button.setFixedSize(width, 50)
            button.setStyleSheet('QPushButton { font-size: 20px; }')
            button.clicked.connect(slot)
            layout.addWidget(button)
        layout.addStretch(1)  # keep the buttons packed to the left
        return layout

    def update_thresholds(self):
        """Update segmentation thresholds based on slider values."""
        self.threshold_brightness = self.brightness_slider.value() / 100
        self.contours_thr = self.contour_slider.value() / 100

    def create_labeled_frame(self, title, widget):
        """Wrap *widget* in a frame with a bold, centred *title* above it."""
        frame = QFrame()
        layout = QVBoxLayout()
        label = QLabel(f"<b>{title}</b>")
        label.setAlignment(Qt.AlignCenter)
        layout.addWidget(label)
        layout.addWidget(widget)
        frame.setLayout(layout)
        return frame

    def load_dicom_series(self):
        """Ask for a folder and load it as a DICOM series into the source pane."""
        folder_path = QFileDialog.getExistingDirectory(
            self, "Select DICOM Series Folder", options=QFileDialog.Options())
        if not folder_path:
            return
        try:
            self.image_label.load_dicom_series(folder_path)
        except Exception as error:
            # GUI boundary: an unhandled exception in a PyQt5 slot aborts the whole
            # application, so report any listing/read/decode failure instead.
            QMessageBox.critical(self, "Error", str(error))

    def select_color(self):
        """Let the user pick the colour for new ROIs."""
        color = QColorDialog.getColor()
        if color.isValid():
            self.image_label.set_current_color(color)

    def set_label(self):
        """Let the user enter the label for new ROIs."""
        label, ok = QInputDialog.getText(self, 'Set ROI Label', 'Enter label for ROI:')
        if ok and label:
            self.image_label.set_current_label(label)

    def save_rois(self):
        """Ask for a target file and save all ROIs of the source pane as JSON."""
        file_path, _ = QFileDialog.getSaveFileName(
            self, "Save ROIs", "", "JSON Files (*.json);;All Files (*)",
            options=QFileDialog.Options())
        if not file_path:
            return
        try:
            self.image_label.save_rois(file_path)
        except OSError as error:
            QMessageBox.critical(self, "Error", str(error))

    def run_analysis(self):
        """Run segmentation and filtration on the current slice (one warning if none)."""
        if self.image_label.image.isNull():
            QMessageBox.warning(self, "Warning", "No image loaded.")
            return
        self.apply_segmentation()
        self.apply_filtration()

    def _current_image_array(self):
        """Return the source slice as a 2-D uint8 array, or None (with a warning) if absent."""
        image = self.image_label.image
        if image.isNull():
            QMessageBox.warning(self, "Warning", "No image loaded.")
            return None
        image = image.convertToFormat(QImage.Format_Grayscale8)
        width, height = image.width(), image.height()
        ptr = image.bits()
        ptr.setsize(image.byteCount())
        # Scanlines are padded to 4-byte boundaries, so reshape by bytesPerLine and
        # crop the padding rather than assuming width == bytes per row.
        rows = np.array(ptr).reshape(height, image.bytesPerLine())
        return np.ascontiguousarray(rows[:, :width])

    def _binary_mask(self, gray):
        """Return 255 where *gray* >= threshold_brightness * max(gray), else 0."""
        threshold = self.threshold_brightness * np.max(gray)
        return np.where(gray >= threshold, 255, 0).astype(np.uint8)

    @staticmethod
    def _show_contours(target_label, gray, contours, rgb_color):
        """Draw *contours* in *rgb_color* over *gray* and show the result in *target_label*."""
        rgb = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
        cv2.drawContours(rgb, contours, -1, rgb_color, 2)
        # Explicit bytesPerLine: width * 3 is generally not 4-byte aligned.
        q_image = QImage(rgb.data, rgb.shape[1], rgb.shape[0], rgb.strides[0],
                         QImage.Format_RGB888)
        target_label.setPixmap(QPixmap.fromImage(q_image))  # fromImage copies the pixels

    def apply_filtration(self):
        """Outline the external contours of the brightness mask in red."""
        gray = self._current_image_array()
        if gray is None:
            return
        # [-2] picks the contour list under both OpenCV 3 (3 results) and 4 (2 results).
        contours = cv2.findContours(self._binary_mask(gray), cv2.RETR_EXTERNAL,
                                    cv2.CHAIN_APPROX_SIMPLE)[-2]
        self._show_contours(self.filtration_label, gray, contours, (255, 0, 0))  # red

    def apply_segmentation(self):
        """Outline mask contours larger than contours_thr * largest area in yellow."""
        gray = self._current_image_array()
        if gray is None:
            return
        contours = cv2.findContours(self._binary_mask(gray), cv2.RETR_TREE,
                                    cv2.CHAIN_APPROX_SIMPLE)[-2]
        areas = [cv2.contourArea(contour) for contour in contours]
        min_area = self.contours_thr * max(areas) if areas else 0
        kept = [contour for contour, area in zip(contours, areas) if area > min_area]
        self._show_contours(self.segmentation_label, gray, kept, (255, 255, 0))  # yellow
if __name__ == '__main__':
    # Script entry point: build the Qt application, show the main window and
    # exit with the event loop's return code.
    application = QApplication(sys.argv)
    main_window = ROIDrawer()
    main_window.show()
    exit_code = application.exec_()
    sys.exit(exit_code)
|