"""KneeSeg: a PyQt5 DICOM slice viewer with freehand ROI drawing and
threshold-based contour filtration/segmentation previews."""

import sys
import os
import json
from io import BytesIO  # noqa: F401  (currently unused)

import cv2
import numpy as np
import pydicom
from skimage.measure import find_contours  # noqa: F401  (currently unused)
from PyQt5.QtCore import Qt, QPoint, QRect
from PyQt5.QtGui import QImage, QPixmap, QPainter, QPen, QColor, QPolygon
from PyQt5.QtWidgets import (
    QApplication, QMainWindow, QLabel, QPushButton, QColorDialog,
    QInputDialog, QFileDialog, QMessageBox, QVBoxLayout, QHBoxLayout,
    QWidget, QFrame, QSizePolicy, QComboBox, QSplitter, QGroupBox, QSlider,
)

# Label given to the ROIs created by ImageLabel.mark_range_from_current().
AUTO_MARK_LABEL = 'AutoMark'


class ImageLabel(QLabel):
    """QLabel that shows one DICOM slice at a time and lets the user draw ROIs.

    ROI points are stored in un-zoomed widget coordinates, i.e. the coordinate
    system paintEvent() draws in after it applies the zoom transform.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setMouseTracking(True)
        # Required so the label receives key presses (Ctrl+Z undo).
        self.setFocusPolicy(Qt.StrongFocus)
        self.image = QImage()
        self.rois = {}  # slice index -> list of ROI dicts
        self.current_roi = []  # points of the polygon currently being drawn
        self.drawing = False
        self.current_color = QColor(Qt.red)
        self.current_label = "ROI"
        self.undo_stack = []  # (slice index, ROI dict) in creation order
        self.slice_index = 0
        self.dicom_files = []
        self.dataset = None
        # Clickable square that triggers mark_range_from_current().
        self.mark_area_rect = QRect(10, 10, 30, 30)
        self.marked_by_rule_slices = set()
        # Zoom parameters
        self.zoom_factor = 1.0
        self.zoom_step = 0.1
        self.min_zoom = 0.1
        self.max_zoom = 5.0

    def _to_image_coords(self, pos):
        """Map a widget position to the un-zoomed coordinates used for drawing."""
        return pos / self.zoom_factor

    def load_dicom_series(self, folder_path):
        """Load every regular file in *folder_path* as one slice of a series.

        Raises:
            ValueError: if the folder contains no files.
        """
        # No '.dcm' extension filter: DICOM files frequently have no extension.
        # Sort by file name so the slice order is deterministic.
        # TODO: sort by InstanceNumber/SliceLocation once all inputs carry them.
        candidates = (os.path.join(folder_path, name)
                      for name in os.listdir(folder_path))
        self.dicom_files = sorted(path for path in candidates
                                  if os.path.isfile(path))
        if not self.dicom_files:
            raise ValueError("No DICOM files"
                             " found in the selected folder.")
        # Reset all per-series state so ROIs and marks do not leak across series.
        self.slice_index = 0
        self.rois.clear()
        self.undo_stack.clear()
        self.marked_by_rule_slices.clear()
        self.current_roi = []
        self.drawing = False
        self.load_slice()

    def load_slice(self):
        """Read the file at self.slice_index and display it as an 8-bit image."""
        if not 0 <= self.slice_index < len(self.dicom_files):
            return
        self.dataset = pydicom.dcmread(self.dicom_files[self.slice_index],
                                       force=True)
        # NOTE(review): assumes a single-frame monochrome slice (2-D pixel_array).
        pixel_array = self.dataset.pixel_array
        # Min-max normalize to 0..255 and make the buffer C-contiguous.
        image_normalized = np.ascontiguousarray(
            cv2.normalize(pixel_array, None, alpha=0, beta=255,
                          norm_type=cv2.NORM_MINMAX).astype(np.uint8))
        height, width = image_normalized.shape[:2]
        # Pass bytesPerLine explicitly (numpy rows are not padded to 4 bytes as
        # QImage would otherwise assume) and copy() so the QImage owns its
        # pixels instead of pointing into a buffer freed when we return.
        self.image = QImage(image_normalized.data, width, height,
                            image_normalized.strides[0],
                            QImage.Format_Grayscale8).copy()
        self.setPixmap(QPixmap.fromImage(self.image))
        self.update()

    def next_slice(self):
        """Advance to the next slice, if any."""
        if self.slice_index < len(self.dicom_files) - 1:
            self.slice_index += 1
            self.load_slice()

    def previous_slice(self):
        """Go back to the previous slice, if any."""
        if self.slice_index > 0:
            self.slice_index -= 1
            self.load_slice()

    def set_current_color(self, color):
        """Set the QColor used for ROIs drawn from now on."""
        self.current_color = color

    def set_current_label(self, label):
        """Set the text label attached to ROIs drawn from now on."""
        self.current_label = label

    def mousePressEvent(self, event):
        """Left click: toggle rule marking in the mark square, else start an ROI."""
        if event.button() != Qt.LeftButton:
            super().mousePressEvent(event)
            return
        pos = self._to_image_coords(event.pos())
        if self.mark_area_rect.contains(pos):
            if self.slice_index not in self.marked_by_rule_slices:
                self.mark_range_from_current()
            self.update()
            return
        if self.image.isNull():
            return  # nothing to draw on (e.g. the preview panes)
        self.drawing = True
        self.current_roi = [pos]
        self.update()

    def mouseMoveEvent(self, event):
        """Extend the polygon being drawn."""
        if self.drawing:
            self.current_roi.append(self._to_image_coords(event.pos()))
            self.update()

    def mouseReleaseEvent(self, event):
        """Finish the polygon; keep it only if it has more than two points."""
        if event.button() != Qt.LeftButton or not self.drawing:
            return
        self.drawing = False
        if len(self.current_roi) > 2:
            roi = {
                'label': self.current_label,
                'color': self.current_color,
                'points': self.current_roi,
            }
            self.rois.setdefault(self.slice_index, []).append(roi)
            self.undo_stack.append((self.slice_index, roi))
        self.current_roi = []
        self.update()

    def paintEvent(self, event):
        """Draw the zoomed slice, its ROIs, the in-progress polygon and the mark square."""
        super().paintEvent(event)
        if self.image.isNull():
            return
        painter = QPainter(self)
        try:
            # A single zoom transform for everything, matching the inverse
            # mapping in _to_image_coords().
            painter.scale(self.zoom_factor, self.zoom_factor)
            painter.drawImage(self.rect(), self.image, self.image.rect())

            pen = QPen(Qt.red, 2, Qt.SolidLine)
            for roi in self.rois.get(self.slice_index, []):
                pen.setColor(roi['color'])
                painter.setPen(pen)
                painter.drawPolygon(QPolygon(list(roi['points'])))

            if self.current_roi:
                pen.setColor(QColor(Qt.blue))
                painter.setPen(pen)
                painter.drawPolyline(QPolygon(list(self.current_roi)))

            painter.setPen(QPen(Qt.green, 2))
            painter.setBrush(QColor(0, 255, 0, 100))  # semi-transparent green
            painter.drawRect(self.mark_area_rect)
            if self.slice_index in self.marked_by_rule_slices:
                # Check mark inside the square.
                painter.drawLine(self.mark_area_rect.topLeft() + QPoint(5, 10),
                                 self.mark_area_rect.center() + QPoint(0, 8))
                painter.drawLine(self.mark_area_rect.center() + QPoint(0, 8),
                                 self.mark_area_rect.topRight() + QPoint(-4, 4))
        finally:
            painter.end()

    def undo(self):
        """Remove the most recently added ROI (manual or auto-marked)."""
        if not self.undo_stack:
            return
        slice_idx, last_roi = self.undo_stack.pop()
        slice_rois = self.rois.get(slice_idx, [])
        # Remove by identity so an equal-looking ROI is not removed instead.
        for i, roi in enumerate(slice_rois):
            if roi is last_roi:
                del slice_rois[i]
                break
        if slice_idx in self.rois and not slice_rois:
            del self.rois[slice_idx]
        # Allow rule marking on this slice again once its auto mark is gone.
        if last_roi.get('auto') and not any(r.get('auto') for r in slice_rois):
            self.marked_by_rule_slices.discard(slice_idx)
        self.update()

    def keyPressEvent(self, event):
        """Ctrl+Z undoes the last ROI; other keys go to QLabel."""
        if event.modifiers() & Qt.ControlModifier and event.key() == Qt.Key_Z:
            self.undo()
        else:
            super().keyPressEvent(event)

    def save_rois(self, file_path):
        """Write all ROIs to *file_path* as JSON keyed by slice index.

        Colors are stored as '#rrggbb' strings and points as [x, y] pairs.
        """
        roi_data = {
            slice_idx: [
                {
                    'label': roi['label'],
                    'color': roi['color'].name(),
                    'points': [(point.x(), point.y()) for point in roi['points']],
                }
                for roi in rois
            ]
            for slice_idx, rois in self.rois.items()
        }
        with open(file_path, 'w', encoding='utf-8') as file:
            json.dump(roi_data, file, indent=4)

    def mark_range_from_current(self):
        """Auto-mark slices from the current one to the nearer end of the series.

        In the first half (index <= total // 2), slices 0..current are marked;
        otherwise current..last. Slices that are already marked are skipped.
        """
        if not self.dicom_files:
            return
        current_index = self.slice_index
        total_slices = len(self.dicom_files)
        if current_index <= total_slices // 2:
            indices_to_mark = range(0, current_index + 1)
        else:
            indices_to_mark = range(current_index, total_slices)
        for idx in indices_to_mark:
            if idx in self.marked_by_rule_slices:
                continue  # avoid stacking duplicate AutoMark ROIs
            roi = {
                'label': AUTO_MARK_LABEL,
                'color': QColor(Qt.green),
                'points': [QPoint(10, 10), QPoint(40, 10),
                           QPoint(40, 40), QPoint(10, 40)],
                'auto': True,  # in-memory flag only; not written by save_rois
            }
            self.rois.setdefault(idx, []).append(roi)
            self.undo_stack.append((idx, roi))
            self.marked_by_rule_slices.add(idx)  # marked
        self.update()

    def wheelEvent(self, event):
        """Ctrl+wheel zooms (clamped to min/max); plain wheel changes slice."""
        angle = event.angleDelta().y()
        if angle == 0:
            return  # horizontal-only scroll
        if event.modifiers() & Qt.ControlModifier:
            factor = 1.1 if angle > 0 else 0.9
            self.zoom_factor = min(self.max_zoom,
                                   max(self.min_zoom, self.zoom_factor * factor))
            self.update()
        elif angle > 0:
            self.previous_slice()
        else:
            self.next_slice()

    def zoom_in(self):
        """Increase zoom by zoom_step, not exceeding max_zoom."""
        self.zoom_factor = min(self.max_zoom, self.zoom_factor + self.zoom_step)
        self.update()

    def zoom_out(self):
        """Decrease zoom by zoom_step, not going below min_zoom."""
        self.zoom_factor = max(self.min_zoom, self.zoom_factor - self.zoom_step)
        self.update()


class ROIDrawer(QMainWindow):
    """Main window: DICOM viewer pane plus filtration and segmentation previews."""

    def __init__(self):
        super().__init__()
        # Defaults match the initial slider positions (50 and 30 out of 100).
        self.threshold_brightness = 0.5
        self.contours_thr = 0.3
        self.initUI()

    def _make_threshold_slider(self, initial_value):
        """Create a 1..99 slider (maps to 0.01..0.99) wired to update_thresholds."""
        slider = QSlider(Qt.Horizontal)
        slider.setMinimum(1)
        slider.setMaximum(99)
        slider.setValue(initial_value)
        slider.setTickInterval(1)
        slider.valueChanged.connect(self.update_thresholds)
        return slider

    @staticmethod
    def _add_button(layout, text, width, handler):
        """Append a fixed-size push button calling *handler* to *layout*."""
        button = QPushButton(text)
        button.setFixedSize(width, 50)
        button.setStyleSheet('QPushButton { font-size: 20px; }')
        button.clicked.connect(handler)
        layout.addWidget(button)
        return button

    def initUI(self):
        """Build the window: sequence selector, sliders, three panes, buttons."""
        self.setWindowTitle("KneeSeg")
        screen_geometry = QApplication.desktop().availableGeometry()
        # Fixed-size window at 90% of the available screen area.
        self.resize(int(screen_geometry.width() * 0.9),
                    int(screen_geometry.height() * 0.9))
        self.setFixedSize(self.size())

        central_widget = QWidget(self)
        self.setCentralWidget(central_widget)
        main_layout = QVBoxLayout(central_widget)

        # Sequence drop-down (items are placeholders for now).
        self.sequence_dropdown = QComboBox()
        self.sequence_dropdown.setFixedSize(250, 30)
        self.sequence_dropdown.setStyleSheet("font-size: 14px;")
        self.sequence_dropdown.addItems(["Sequence A", "Sequence B", "Sequence C"])
        top_controls_layout = QHBoxLayout()
        top_controls_layout.addStretch(1)  # push the drop-down to the right
        top_controls_layout.addWidget(self.sequence_dropdown)
        main_layout.addLayout(top_controls_layout)

        # Threshold sliders.
        self.brightness_slider = self._make_threshold_slider(50)  # 0.50
        self.contour_slider = self._make_threshold_slider(30)  # 0.30
        threshold_layout = QVBoxLayout()
        brightness_label = QLabel("Порог яркости")
        brightness_label.setStyleSheet("font-size: 14px;font-weight: bold;")
        contour_label = QLabel("Порог площади")
        contour_label.setStyleSheet("font-size: 14px;font-weight: bold;")
        threshold_layout.addWidget(brightness_label)
        threshold_layout.addWidget(self.brightness_slider)
        threshold_layout.addWidget(contour_label)
        threshold_layout.addWidget(self.contour_slider)
        main_layout.addLayout(threshold_layout)

        # Three equally stretched panes: source image, filtration, segmentation.
        self.image_label = ImageLabel(self)
        self.filtration_label = ImageLabel(self)
        self.segmentation_label = ImageLabel(self)
        for pane in (self.image_label, self.filtration_label,
                     self.segmentation_label):
            pane.setAlignment(Qt.AlignCenter)
            pane.setScaledContents(True)
            pane.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        splitter = QSplitter(Qt.Horizontal)
        splitter.addWidget(self.create_labeled_frame("", self.image_label))
        splitter.addWidget(self.create_labeled_frame(" ", self.filtration_label))
        splitter.addWidget(self.create_labeled_frame("", self.segmentation_label))
        splitter.setSizes([1, 1, 1])
        for index in range(3):
            splitter.setStretchFactor(index, 1)
        main_layout.addWidget(splitter)

        # Buttons row.
        button_layout = QHBoxLayout()
        main_layout.addLayout(button_layout)
        self._add_button(button_layout, 'Load DICOM ', 150, self.load_dicom_series)
        # NOTE(review): '>>' steps to the previous slice and '<<' to the next;
        # confirm this mapping is intended.
        self._add_button(button_layout, '>>', 50, self.image_label.previous_slice)
        self._add_button(button_layout, '<<', 50, self.image_label.next_slice)
        self._add_button(button_layout, 'Select Color', 150, self.select_color)
        self._add_button(button_layout, 'Set Label', 100, self.set_label)
        self._add_button(button_layout, 'Undo', 100, self.image_label.undo)
        self._add_button(button_layout, 'Segmentation', 150,
                         self._run_segmentation_and_filtration)
        self._add_button(button_layout, 'Save ROIs', 150, self.save_rois)
        button_layout.addStretch(1)  # keep the buttons packed to the left

    def update_thresholds(self):
        """Update segmentation thresholds based on slider values."""
        self.threshold_brightness = self.brightness_slider.value() / 100
        self.contours_thr = self.contour_slider.value() / 100

    def create_labeled_frame(self, title, widget):
        """Wrap *widget* in a QFrame with a centered title label above it."""
        frame = QFrame()
        layout = QVBoxLayout(frame)
        label = QLabel(title)
        label.setAlignment(Qt.AlignCenter)
        layout.addWidget(label)
        layout.addWidget(widget)
        return frame

    def load_dicom_series(self):
        """Ask for a folder and load it into the main image pane."""
        options = QFileDialog.Options()
        folder_path = QFileDialog.getExistingDirectory(
            self, "Select DICOM Series Folder", options=options)
        if not folder_path:
            return
        try:
            self.image_label.load_dicom_series(folder_path)
        except Exception as e:  # UI boundary: report instead of aborting the app
            QMessageBox.critical(self, "Error", str(e))

    def select_color(self):
        """Pick the color for new ROIs, starting from the current one."""
        color = QColorDialog.getColor(self.image_label.current_color, self)
        if color.isValid():
            self.image_label.set_current_color(color)

    def set_label(self):
        """Prompt for the text label attached to new ROIs."""
        label, ok = QInputDialog.getText(self, 'Set ROI Label', 'Enter label for ROI:')
        if ok and label:
            self.image_label.set_current_label(label)

    def save_rois(self):
        """Ask for a destination and save the main pane's ROIs as JSON."""
        options = QFileDialog.Options()
        file_path, _ = QFileDialog.getSaveFileName(
            self, "Save ROIs", "", "JSON Files (*.json);;All Files (*)",
            options=options)
        if file_path:
            self.image_label.save_rois(file_path)

    @staticmethod
    def _qimage_to_gray_array(image):
        """Copy a Format_Grayscale8 QImage into a (height, width) uint8 array.

        QImage scanlines are padded to 4 bytes, so rows are read using
        bytesPerLine() and the padding columns are dropped.
        """
        width, height = image.width(), image.height()
        bytes_per_line = image.bytesPerLine()
        ptr = image.bits()
        ptr.setsize(height * bytes_per_line)
        rows = np.array(ptr).reshape(height, bytes_per_line)
        return np.ascontiguousarray(rows[:, :width])

    @staticmethod
    def _binary_mask(gray, relative_threshold):
        """Return a 0/255 uint8 mask of pixels >= relative_threshold * max(gray)."""
        cutoff = relative_threshold * np.max(gray)
        return np.where(gray >= cutoff, 255, 0).astype(np.uint8)

    @staticmethod
    def _rgb_array_to_qimage(rgb):
        """Build a self-owned RGB888 QImage from a (h, w, 3) uint8 array."""
        height, width = rgb.shape[:2]
        return QImage(rgb.data, width, height, rgb.strides[0],
                      QImage.Format_RGB888).copy()

    def _run_segmentation_and_filtration(self):
        """Run both previews, warning only once when no image is loaded."""
        if self.image_label.image.isNull():
            QMessageBox.warning(self, "Warning", "No image loaded.")
            return
        self.apply_segmentation()
        self.apply_filtration()

    def apply_filtration(self):
        """Show outer contours of the brightness-threshold mask in red."""
        if self.image_label.image.isNull():
            QMessageBox.warning(self, "Warning", "No image loaded.")
            return
        gray = self._qimage_to_gray_array(self.image_label.image)
        mask = self._binary_mask(gray, self.threshold_brightness)
        # [-2] picks the contours from both the OpenCV 3.x and 4.x return shapes.
        contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                                    cv2.CHAIN_APPROX_SIMPLE)[-2]
        color_image = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
        cv2.drawContours(color_image, contours, -1, (255, 0, 0), 2)  # red, width 2
        self.filtration_label.setPixmap(
            QPixmap.fromImage(self._rgb_array_to_qimage(color_image)))

    def apply_segmentation(self):
        """Show contours whose area exceeds contours_thr * largest area, in yellow."""
        if self.image_label.image.isNull():
            QMessageBox.warning(self, "Warning", "No image loaded.")
            return
        gray = self._qimage_to_gray_array(self.image_label.image)
        mask = self._binary_mask(gray, self.threshold_brightness)
        contours = cv2.findContours(mask, cv2.RETR_TREE,
                                    cv2.CHAIN_APPROX_SIMPLE)[-2]
        areas = [cv2.contourArea(cnt) for cnt in contours]
        area_cutoff = self.contours_thr * max(areas) if areas else 0
        filtered_contours = [cnt for cnt, area in zip(contours, areas)
                             if area > area_cutoff]
        color_image = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
        cv2.drawContours(color_image, filtered_contours, -1, (255, 255, 0), 2)  # yellow
        self.segmentation_label.setPixmap(
            QPixmap.fromImage(self._rgb_array_to_qimage(color_image)))


if __name__ == '__main__':
    app = QApplication(sys.argv)
    drawer = ROIDrawer()
    drawer.show()
    sys.exit(app.exec_())