|
- import sys
- import os
- import json
- import cv2
- import numpy as np
- import pydicom
- from PyQt5.QtWidgets import (
- QApplication, QMainWindow, QLabel, QPushButton, QColorDialog, QInputDialog,
- QFileDialog, QMessageBox, QVBoxLayout, QHBoxLayout, QWidget, QFrame, QSizePolicy,
- QSlider, QComboBox, QCheckBox
- )
- from PyQt5.QtGui import QImage, QPixmap, QPainter, QPen, QColor
- from PyQt5.QtCore import Qt, QPoint, QRect
- from PyQt5.QtWidgets import QSplitter
- # =============================
- # Unified Image Viewer + ROI
- # =============================
class ImageLabel(QLabel):
    """QLabel that shows one DICOM slice and lets the user draw polygon ROIs.

    ROIs are kept per slice in ``self.rois`` (``{slice_index: [roi, ...]}``),
    where every ROI is a dict with the keys ``'label'``, ``'color'`` (QColor)
    and ``'points'`` (list of QPoint).  Points are stored in widget
    coordinates divided by ``zoom_factor`` -- the same space ``paintEvent``
    draws in, where the slice image is stretched over ``self.rect()``.

    A small green rectangle (``mark_area_rect``) acts as an on-image checkbox:
    left click toggles the mark for a range of slices, right click removes all
    marks.  Marks are stored both in ``marked_by_rule_slices`` and as service
    ROIs labelled ``'AutoMark'``.
    """

    # Label of the service ROIs created by the on-image checkbox.
    AUTO_MARK_LABEL = 'AutoMark'

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setMouseTracking(True)
        # Key events (Ctrl+Z) are only delivered to widgets that accept focus.
        self.setFocusPolicy(Qt.StrongFocus)
        self.image = QImage()
        self.rois = {}
        self.current_roi = []
        self.drawing = False
        self.current_color = QColor(Qt.red)
        self.current_label = "ROI"
        # Stack of (slice_index, roi) pairs, most recent last.
        self.undo_stack = []
        # Series & slice
        self.slice_index = 0
        self.all_dicom_files = []
        self.filtered_series_files = []  # files of the selected series
        self.dataset = None
        # ROI propagation: the first drawn ROI is copied to all slices.
        self.base_roi = None
        # Spacing values; refreshed from the DICOM header in load_slice().
        self.spacing_xy = (0.0, 0.0)
        self.spacing_z = 0.0
        self.slice_thickness = 1.0
        self.spacing_between_slices = 1.0
        # Volume per slice, filled in by the owning window.
        self.volume_by_slice = {}
        # Visual "checkbox" drawn on top of the image.
        self.mark_area_rect = QRect(10, 10, 30, 30)
        self.marked_by_rule_slices = set()
        # Zoom
        self.zoom_factor = 1.0
        self.zoom_step = 0.1
        self.min_zoom = 0.1
        self.max_zoom = 5.0
        # Callback the window can set to recompute derived views on change.
        self.on_slice_changed = None

    # ------------- Loading logic -------------
    def load_folder(self, folder_path):
        """Collect the regular files of ``folder_path`` as DICOM candidates.

        Sub-directories are skipped because they can never be read as a slice.

        Raises:
            ValueError: if the folder contains no files.
            OSError: if the folder cannot be listed.
        """
        candidates = (os.path.join(folder_path, name)
                      for name in sorted(os.listdir(folder_path)))
        self.all_dicom_files = [p for p in candidates if os.path.isfile(p)]
        if not self.all_dicom_files:
            raise ValueError("В папке нет DICOM-файлов.")

    def filter_series(self, predicate):
        """Select the files accepted by ``predicate`` and show the middle slice.

        Args:
            predicate: callable(dataset) -> bool, evaluated on the header of
                every file (pixel data is not read).

        Raises:
            ValueError: if no file matches.
        """
        self.filtered_series_files = []
        matched = []  # (path, InstanceNumber) -- each header is read only once
        for path in self.all_dicom_files:
            try:
                ds = pydicom.dcmread(path, stop_before_pixels=True, force=True)
                if predicate(ds):
                    matched.append((path, ds.get("InstanceNumber", 0)))
            except Exception:
                # Best effort: anything unreadable is simply not part of the series.
                continue
        if not matched:
            raise ValueError("Подходящая серия не найдена по выбранному фильтру.")
        # Sort by InstanceNumber when possible, otherwise by file path.
        try:
            matched.sort(key=lambda item: int(item[1]))
        except (TypeError, ValueError):
            matched.sort(key=lambda item: item[0])
        self.filtered_series_files = [path for path, _ in matched]

        # All per-slice state is keyed by slice index, so it is meaningless
        # for a different series and must be dropped as a whole.
        self.slice_index = len(self.filtered_series_files) // 2
        self.rois.clear()
        self.base_roi = None
        self.volume_by_slice.clear()
        self.undo_stack.clear()
        self.marked_by_rule_slices.clear()
        self.current_roi = []
        self.drawing = False
        self.load_slice()

    @staticmethod
    def _float_tag(dataset, name, default):
        """Return ``dataset.<name>`` as float, or ``default`` if missing/invalid."""
        try:
            return float(getattr(dataset, name, default))
        except (TypeError, ValueError):
            return float(default)

    def load_slice(self):
        """Read the slice at ``slice_index`` and display it; no-op if out of range."""
        if not 0 <= self.slice_index < len(self.filtered_series_files):
            return
        self.dataset = pydicom.dcmread(self.filtered_series_files[self.slice_index], force=True)
        # In-plane spacing; falls back to 1x1 when the tag is absent or malformed.
        try:
            px = self.dataset.PixelSpacing
            self.spacing_xy = (float(px[0]), float(px[1]))
        except Exception:
            self.spacing_xy = (1.0, 1.0)
        # Both z-spacing candidates are stored so the window can switch between them.
        self.slice_thickness = self._float_tag(self.dataset, 'SliceThickness', 1.0)
        self.spacing_between_slices = self._float_tag(
            self.dataset, 'SpacingBetweenSlices', self.slice_thickness)

        arr = self.dataset.pixel_array
        img8 = np.ascontiguousarray(
            cv2.normalize(arr, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8))
        height, width = img8.shape[:2]
        # QImage does not own the buffer it is constructed on, and without an
        # explicit bytesPerLine it assumes 32-bit aligned rows.  Pass the real
        # stride and deep-copy so the image survives img8 being released.
        self.image = QImage(img8.data, width, height, img8.strides[0],
                            QImage.Format_Grayscale8).copy()

        # If a base ROI exists, make sure it is present on a slice shown for
        # the first time.
        if self.base_roi is not None and self.slice_index not in self.rois:
            self.rois[self.slice_index] = [self._copy_roi(self.base_roi)]
        self.setPixmap(QPixmap.fromImage(self.image))
        self.update()
        if self.on_slice_changed:
            self.on_slice_changed()

    # ------------- Navigation -------------
    def next_slice(self):
        """Advance to the following slice, if there is one."""
        if self.slice_index < len(self.filtered_series_files) - 1:
            self.slice_index += 1
            self.load_slice()

    def previous_slice(self):
        """Go back to the preceding slice, if there is one."""
        if self.slice_index > 0:
            self.slice_index -= 1
            self.load_slice()

    # ------------- ROI ops -------------
    def set_current_color(self, color):
        """Set the colour used for ROIs drawn from now on."""
        self.current_color = color

    def set_current_label(self, label):
        """Set the label used for ROIs drawn from now on."""
        self.current_label = label

    def _copy_roi(self, roi):
        """Return an independent copy of ``roi`` (new QColor and QPoint objects)."""
        return {
            'label': roi['label'],
            'color': QColor(roi['color']),
            'points': [QPoint(p.x(), p.y()) for p in roi['points']],
        }

    def _strip_auto_marks(self, slice_idx):
        """Remove the service ROIs from one slice; drop the entry if it gets empty."""
        if slice_idx not in self.rois:
            return
        kept = [roi for roi in self.rois[slice_idx]
                if roi.get('label') != self.AUTO_MARK_LABEL]
        if kept:
            self.rois[slice_idx] = kept
        else:
            del self.rois[slice_idx]

    def _unmark_slice(self, slice_idx):
        """Remove the service checkmark from a single slice."""
        self.marked_by_rule_slices.discard(slice_idx)
        self._strip_auto_marks(slice_idx)

    def _unmark_all_slices(self):
        """Remove every checkmark and every service ROI."""
        self.marked_by_rule_slices.clear()
        for idx in list(self.rois.keys()):
            self._strip_auto_marks(idx)

    def mousePressEvent(self, event):
        """Handle a click on the checkbox area, or start drawing a new ROI."""
        scaled_pos = event.pos() / self.zoom_factor
        if self.mark_area_rect.contains(scaled_pos):
            idx = self.slice_index
            if event.button() == Qt.LeftButton:
                if idx in self.marked_by_rule_slices:
                    self._unmark_slice(idx)
                else:
                    self._mark_range_from_current()
            elif event.button() == Qt.RightButton:
                self._unmark_all_slices()
            self.update()
            return
        # Without this branch self.drawing never becomes True and the
        # move/release handlers can never produce an ROI.
        if event.button() == Qt.LeftButton and not self.image.isNull():
            self.drawing = True
            self.current_roi = [scaled_pos]
            self.update()

    def mouseMoveEvent(self, event):
        """Extend the polyline of the ROI that is being drawn."""
        if self.drawing:
            scaled_pos = event.pos() / self.zoom_factor
            self.current_roi.append(scaled_pos)
            self.update()

    def mouseReleaseEvent(self, event):
        """Finish the current ROI; polylines with fewer than 3 points are dropped."""
        if event.button() != Qt.LeftButton or not self.drawing:
            return
        self.drawing = False
        if len(self.current_roi) > 2:
            roi = {
                'label': self.current_label,
                'color': QColor(self.current_color),
                'points': [QPoint(p.x(), p.y()) for p in self.current_roi],
            }
            # The first ROI becomes the base ROI and is added to all slices.
            if self.base_roi is None:
                self.base_roi = self._copy_roi(roi)
                for idx in range(len(self.filtered_series_files)):
                    self.rois.setdefault(idx, []).append(self._copy_roi(self.base_roi))
            # The drawn ROI itself is also kept on this slice as an adjustment.
            self.rois.setdefault(self.slice_index, []).append(roi)
            self.undo_stack.append((self.slice_index, roi))
        self.current_roi = []
        self.update()
        if self.on_slice_changed:
            self.on_slice_changed()

    # ------------- Painting -------------
    def paintEvent(self, event):
        """Paint the slice, its ROIs, the polyline in progress and the checkbox."""
        super().paintEvent(event)
        if self.image.isNull():
            return
        painter = QPainter(self)
        # Everything below is drawn in zoomed coordinates.
        t = painter.transform()
        t.scale(self.zoom_factor, self.zoom_factor)
        painter.setTransform(t)
        painter.drawImage(self.rect(), self.image, self.image.rect())
        pen = QPen(Qt.red, 2)
        painter.setPen(pen)
        # Stored ROIs of the current slice.
        for roi in self.rois.get(self.slice_index, []):
            pen.setColor(roi['color'])
            painter.setPen(pen)
            pts = [QPoint(p.x(), p.y()) for p in roi['points']]
            if len(pts) >= 3:
                painter.drawPolygon(*pts)
        # Polyline of the ROI being drawn.
        if self.current_roi:
            pen.setColor(Qt.blue)
            painter.setPen(pen)
            pts = [QPoint(p.x(), p.y()) for p in self.current_roi]
            painter.drawPolyline(*pts)
        # Green checkbox area, with a tick when the slice is marked.
        pen = QPen(Qt.green, 2)
        painter.setPen(pen)
        painter.setBrush(QColor(0, 255, 0, 100))
        painter.drawRect(self.mark_area_rect)
        if self.slice_index in self.marked_by_rule_slices:
            tick_bottom = self.mark_area_rect.center() + QPoint(0, 8)
            painter.drawLine(self.mark_area_rect.topLeft() + QPoint(5, 10), tick_bottom)
            painter.drawLine(tick_bottom, self.mark_area_rect.topRight() + QPoint(-4, 4))
        painter.end()

    # ------------- Undo / Clear -------------
    def undo(self):
        """Remove the most recently added ROI (user-drawn or AutoMark)."""
        if not self.undo_stack:
            return
        slice_idx, last_roi = self.undo_stack.pop()
        if slice_idx in self.rois and last_roi in self.rois[slice_idx]:
            self.rois[slice_idx].remove(last_roi)
            if not self.rois[slice_idx]:
                del self.rois[slice_idx]
            # Keep the checkmark in sync: once the last service ROI of a slice
            # is undone, the slice is no longer marked.
            if last_roi.get('label') == self.AUTO_MARK_LABEL and not any(
                    r.get('label') == self.AUTO_MARK_LABEL
                    for r in self.rois.get(slice_idx, [])):
                self.marked_by_rule_slices.discard(slice_idx)
        self.update()
        if self.on_slice_changed:
            self.on_slice_changed()

    def clear_slice(self):
        """Remove every ROI (including the service mark) from the current slice."""
        # An empty list (rather than a deleted key) keeps load_slice() from
        # re-adding the base ROI to this slice.
        self.rois[self.slice_index] = []
        self.marked_by_rule_slices.discard(self.slice_index)
        self.update()
        if self.on_slice_changed:
            self.on_slice_changed()

    # ------------- Auto-mark rule -------------
    def _mark_range_from_current(self):
        """Mark a run of slices ending or starting at the current one.

        If the current slice lies in the first half (middle included) slices
        ``0..current`` are marked, otherwise ``current..last``.  Slices that
        are already marked are left alone so service ROIs do not pile up.
        """
        total = len(self.filtered_series_files)
        if total == 0:
            return
        current = self.slice_index
        if current <= total // 2:
            indices = range(0, current + 1)
        else:
            indices = range(current, total)
        for idx in indices:
            if idx in self.marked_by_rule_slices:
                continue
            roi = {
                'label': self.AUTO_MARK_LABEL,
                'color': QColor(Qt.green),
                'points': [QPoint(10, 10), QPoint(40, 10), QPoint(40, 40), QPoint(10, 40)],
            }
            self.rois.setdefault(idx, []).append(self._copy_roi(roi))
            # undo() locates the stored copy by equality, not identity.
            self.undo_stack.append((idx, roi))
            self.marked_by_rule_slices.add(idx)
        self.update()

    # ------------- Wheel: zoom or navigate -------------
    def wheelEvent(self, event):
        """Ctrl+wheel zooms; plain wheel steps through the slices."""
        delta = event.angleDelta().y()
        if delta == 0:
            # No vertical component (e.g. horizontal scroll): nothing to do.
            return
        if event.modifiers() & Qt.ControlModifier:
            factor = 1.1 if delta > 0 else 0.9
            self.zoom_factor = max(self.min_zoom,
                                   min(self.max_zoom, self.zoom_factor * factor))
            self.update()
        elif delta > 0:
            self.previous_slice()
        else:
            self.next_slice()

    def keyPressEvent(self, event):
        """Ctrl+Z undoes the last ROI; other keys get the default handling."""
        if event.modifiers() & Qt.ControlModifier and event.key() == Qt.Key_Z:
            self.undo()
        else:
            super().keyPressEvent(event)
- # =============================
- # Unified Main Window
- # =============================
class ROIDrawer(QMainWindow):
    """Main window: slice viewer with ROI drawing plus two derived views.

    The left pane is the interactive ``ImageLabel``; the middle pane shows the
    contours of the thresholded ROI content ("filtration"); the right pane
    shows the contours that pass the area filter ("segmentation") together
    with the volume estimated for the slice.
    """

    BUTTON_STYLE = 'QPushButton { font-size: 20px; }'

    def __init__(self):
        super().__init__()
        self.threshold_brightness = 0.5
        self.contours_thr = 0.3
        # Toggle between SliceThickness and SpacingBetweenSlices as z-spacing.
        self.use_slice_thickness = True
        self.series_predicates = []  # list of (name, callable(dataset) -> bool)
        self._build_ui()
        self._init_series_filters()

    # ---------- UI ----------
    def _build_ui(self):
        """Create all widgets and wire up their signals."""
        self.setWindowTitle("KneeSeg - Unified")
        screen = QApplication.desktop().availableGeometry()
        self.resize(int(screen.width() * 0.95), int(screen.height() * 0.95))
        self.setFixedSize(self.size())
        central = QWidget(self)
        self.setCentralWidget(central)
        main_layout = QVBoxLayout(central)

        main_layout.addLayout(self._build_top_controls())
        main_layout.addLayout(self._build_threshold_controls())

        self.pixel_count_label = QLabel("Суммарный объем (мл): 0.00")
        self.pixel_count_label.setStyleSheet("font-size: 16pt; font-weight: bold;")
        main_layout.addWidget(self.pixel_count_label)

        main_layout.addWidget(self._build_views())
        main_layout.addLayout(self._build_buttons())

        # React to sequence selection.
        self.sequence_dropdown.currentIndexChanged.connect(self._on_sequence_changed)

    def _build_top_controls(self):
        """Top row: series dropdown and the z-spacing source toggle."""
        row = QHBoxLayout()
        self.sequence_dropdown = QComboBox()
        self.sequence_dropdown.setFixedSize(300, 30)
        self.sequence_dropdown.setStyleSheet("font-size: 14px;")
        row.addStretch(1)
        row.addWidget(self.sequence_dropdown)
        self.z_toggle = QCheckBox("Использовать SliceThickness (иначе SpacingBetweenSlices)")
        self.z_toggle.setChecked(True)
        # Connected after setChecked() so building the UI triggers no recompute.
        self.z_toggle.stateChanged.connect(self._toggle_z_source)
        row.addWidget(self.z_toggle)
        return row

    def _make_percent_slider(self, initial):
        """Return a horizontal 1..99 slider preset to ``initial``."""
        slider = QSlider(Qt.Horizontal)
        slider.setMinimum(1)
        slider.setMaximum(99)
        slider.setValue(initial)
        # Connected last: the views do not exist yet while the UI is built.
        slider.valueChanged.connect(self._update_thresholds)
        return slider

    def _build_threshold_controls(self):
        """Brightness and area threshold sliders with their captions."""
        column = QVBoxLayout()
        caption_style = "font-size: 14px;font-weight: bold;"
        b_lbl = QLabel("Порог яркости")
        b_lbl.setStyleSheet(caption_style)
        self.brightness_slider = self._make_percent_slider(50)
        c_lbl = QLabel("Порог площади")
        c_lbl.setStyleSheet(caption_style)
        self.contour_slider = self._make_percent_slider(30)
        for widget in (b_lbl, self.brightness_slider, c_lbl, self.contour_slider):
            column.addWidget(widget)
        return column

    def _build_views(self):
        """Splitter holding the interactive image and the two derived views."""
        splitter = QSplitter(Qt.Horizontal)

        self.image_label = ImageLabel(self)
        self.image_label.setAlignment(Qt.AlignCenter)
        self.image_label.on_slice_changed = self._recompute_views
        splitter.addWidget(self._make_frame("", self.image_label))

        self.filtration_label = QLabel()
        self.filtration_label.setAlignment(Qt.AlignCenter)
        splitter.addWidget(self._make_frame(" ", self.filtration_label))

        self.segmentation_label = QLabel()
        self.segmentation_label.setAlignment(Qt.AlignCenter)
        splitter.addWidget(self._make_frame("", self.segmentation_label))

        for view in (self.image_label, self.filtration_label, self.segmentation_label):
            view.setScaledContents(True)
        splitter.setSizes([1, 1, 1])
        for pane in range(3):
            splitter.setStretchFactor(pane, 1)
        return splitter

    def _add_button(self, layout, text, width, slot):
        """Create a 50px-high push button, connect it and append it to ``layout``."""
        button = QPushButton(text, self)
        button.setFixedSize(width, 50)
        button.setStyleSheet(self.BUTTON_STYLE)
        button.clicked.connect(slot)
        layout.addWidget(button)
        return button

    def _build_buttons(self):
        """Bottom row of action buttons."""
        row = QHBoxLayout()
        viewer = self.image_label
        self._add_button(row, 'Load DICOM', 150, self._load_folder)
        # '<<' moves to the next slice and '>>' to the previous one; this
        # pairing is kept exactly as in the original tool.
        self._add_button(row, '<<', 50, viewer.next_slice)
        self._add_button(row, '>>', 50, viewer.previous_slice)
        self._add_button(row, 'Select Color', 150, self._select_color)
        self._add_button(row, 'Set Label', 120, self._set_label)
        self._add_button(row, 'Undo', 100, viewer.undo)
        self._add_button(row, 'Clear slice', 140, viewer.clear_slice)
        self._add_button(row, 'Segmentation', 160, self._recompute_views)
        self._add_button(row, 'Save ROIs', 150, self._save_rois)
        row.addStretch(1)
        return row

    def _make_frame(self, title, widget):
        """Wrap ``widget`` in a QFrame with a centred bold title above it."""
        frame = QFrame()
        lay = QVBoxLayout()
        title_lbl = QLabel(f"<b>{title}</b>")
        title_lbl.setAlignment(Qt.AlignCenter)
        lay.addWidget(title_lbl)
        lay.addWidget(widget)
        frame.setLayout(lay)
        return frame

    # ---------- Series selection ----------
    @staticmethod
    def _preset_predicates():
        """Return the built-in (name, predicate) series filters."""
        return [
            ("Auto: All files", lambda ds: True),
            ("Contains 'Sag PD'",
             lambda ds: 'SeriesDescription' in ds and 'Sag PD' in str(ds.SeriesDescription)),
            ("Contains 'T1 Cube'",
             lambda ds: 'SeriesDescription' in ds and 'T1 Cube' in str(ds.SeriesDescription)),
        ]

    def _refill_dropdown(self):
        """Replace the dropdown items with the names in ``series_predicates``.

        Signals are blocked while the items are replaced: otherwise clearing
        and refilling fires ``currentIndexChanged`` and starts a series load
        (with an error dialog when no folder has been opened yet).
        """
        self.sequence_dropdown.blockSignals(True)
        try:
            self.sequence_dropdown.clear()
            self.sequence_dropdown.addItems([name for name, _ in self.series_predicates])
        finally:
            self.sequence_dropdown.blockSignals(False)

    def _init_series_filters(self):
        """Install the preset filters; folder-specific ones are added on load."""
        self.series_predicates = self._preset_predicates()
        self._refill_dropdown()

    def _populate_series_from_folder(self):
        """Add one exact-match filter per SeriesDescription found in the folder."""
        descriptions = []
        for path in self.image_label.all_dicom_files:
            try:
                ds = pydicom.dcmread(path, stop_before_pixels=True, force=True)
                desc = str(getattr(ds, 'SeriesDescription', ''))
            except Exception:
                # Best effort: unreadable files contribute no series name.
                continue
            if desc and desc not in descriptions:
                descriptions.append(desc)
        # Start again from the presets so reloading a folder does not
        # accumulate duplicate entries.
        self.series_predicates = self._preset_predicates()
        for desc in descriptions:
            # d=desc binds the value now; a plain closure would see the last one.
            self.series_predicates.append((
                f"Series: {desc}",
                lambda D, d=desc: str(getattr(D, 'SeriesDescription', '')) == d,
            ))
        self._refill_dropdown()

    # ---------- Actions ----------
    def _load_folder(self):
        """Ask for a folder, index its files and show the selected series."""
        folder = QFileDialog.getExistingDirectory(self, "Select DICOM Series Folder")
        if not folder:
            return
        try:
            self.image_label.load_folder(folder)
            self._populate_series_from_folder()
        except Exception as e:
            QMessageBox.critical(self, "Error", str(e))
            return
        # Apply the filter currently selected in the dropdown.
        self._on_sequence_changed()

    def _on_sequence_changed(self):
        """Load the series matching the dropdown selection."""
        if not self.image_label.all_dicom_files:
            return  # nothing opened yet -- there is nothing to filter
        idx = self.sequence_dropdown.currentIndex()
        if idx < 0 or idx >= len(self.series_predicates):
            return
        _, pred = self.series_predicates[idx]
        try:
            self.image_label.filter_series(pred)
        except Exception as e:
            QMessageBox.critical(self, "Error", str(e))

    def _toggle_z_source(self):
        """Switch the z-spacing source and refresh the views."""
        self.use_slice_thickness = self.z_toggle.isChecked()
        self._recompute_views()

    def _update_thresholds(self):
        """Read both sliders (percent) into fractions and refresh the views."""
        self.threshold_brightness = self.brightness_slider.value() / 100.0
        self.contours_thr = self.contour_slider.value() / 100.0
        self._recompute_views()

    def _select_color(self):
        """Let the user pick the colour for subsequently drawn ROIs."""
        col = QColorDialog.getColor()
        if col.isValid():
            self.image_label.set_current_color(col)

    def _set_label(self):
        """Let the user type the label for subsequently drawn ROIs."""
        label, ok = QInputDialog.getText(self, 'Set ROI Label', 'Enter label for ROI:')
        if ok and label:
            self.image_label.set_current_label(label)

    def _save_rois(self):
        """Write all ROIs to a JSON file chosen by the user.

        The file maps slice index to a list of ``{label, color, points}``
        objects; points are written exactly as stored by the image widget.
        """
        path, _ = QFileDialog.getSaveFileName(
            self, "Save ROIs", "", "JSON Files (*.json);;All Files (*)")
        if not path:
            return
        roi_data = {
            slice_idx: [
                {
                    'label': roi['label'],
                    'color': roi['color'].name(),
                    'points': [(p.x(), p.y()) for p in roi['points']],
                }
                for roi in rois
            ]
            for slice_idx, rois in self.image_label.rois.items()
        }
        try:
            with open(path, 'w', encoding='utf-8') as f:
                json.dump(roi_data, f, ensure_ascii=False, indent=2)
        except OSError as e:
            # Report the failure instead of letting it escape from a Qt slot.
            QMessageBox.critical(self, "Error", str(e))

    # ---------- Processing ----------
    def _recompute_views(self):
        """Refresh the filtration and segmentation panes for the current slice."""
        self._apply_filtration()
        self._apply_segmentation()

    def _grab_gray_np(self):
        """Return the displayed slice as a contiguous (h, w) uint8 array, or None."""
        img = self.image_label.image
        if img.isNull():
            return None
        if img.format() != QImage.Format_Grayscale8:
            img = img.convertToFormat(QImage.Format_Grayscale8)
        width, height = img.width(), img.height()
        # QImage rows may be padded, so reshape by the real stride and cut
        # the padding off instead of assuming stride == width.
        stride = img.bytesPerLine()
        ptr = img.bits()
        ptr.setsize(img.byteCount())
        rows = np.array(ptr).reshape(height, stride)
        return np.ascontiguousarray(rows[:, :width])

    def _build_mask_from_rois(self, shape):
        """Rasterise the ROIs of the current slice into a 0/255 mask of ``shape``.

        ROI points live in the image widget's display space, in which the
        slice is painted stretched over the whole widget.  They are rescaled
        to image pixels here so the mask covers what the user outlined.
        NOTE(review): service 'AutoMark' ROIs are rasterised as well, exactly
        as before -- confirm that they are meant to contribute to the mask.
        """
        mask = np.zeros(shape, dtype=np.uint8)
        height, width = shape[:2]
        viewer = self.image_label
        scale_x = width / viewer.width() if viewer.width() > 0 else 1.0
        scale_y = height / viewer.height() if viewer.height() > 0 else 1.0
        for roi in viewer.rois.get(viewer.slice_index, []):
            if len(roi['points']) < 3:
                continue
            pts = np.array([[p.x() * scale_x, p.y() * scale_y] for p in roi['points']])
            cv2.fillPoly(mask, [np.rint(pts).astype(np.int32)], 255)
        return mask

    def _roi_threshold_contours(self, gray):
        """Return the external contours of the bright pixels inside the ROIs.

        The threshold is ``threshold_brightness`` times the brightest pixel
        found inside the ROI mask (0 when the mask is empty or all black).
        """
        inside = self._build_mask_from_rois(gray.shape) == 255
        peak = int(gray[inside].max()) if inside.any() else 0
        thr_val = self.threshold_brightness * peak
        seg = np.zeros_like(gray)
        seg[inside & (gray >= thr_val)] = 255
        # [-2] picks the contour list for both the 2-tuple and the 3-tuple
        # return shape used by different OpenCV major versions.
        return cv2.findContours(seg, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    @staticmethod
    def _bgr_to_pixmap(bgr):
        """Convert an OpenCV BGR image to a QPixmap.

        Colours drawn with OpenCV are in BGR order while Format_RGB888 expects
        RGB, so the channels are swapped first; the row stride is passed
        explicitly because rows of 3-byte pixels are not 32-bit aligned for
        every width.
        """
        rgb = np.ascontiguousarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        height, width = rgb.shape[:2]
        qimg = QImage(rgb.data, width, height, rgb.strides[0], QImage.Format_RGB888)
        # fromImage() copies the pixels, so rgb may be released afterwards.
        return QPixmap.fromImage(qimg)

    def _apply_filtration(self):
        """Show all thresholded contours (blue, BGR (255, 0, 0)) in the middle pane."""
        gray = self._grab_gray_np()
        if gray is None:
            return
        contours = self._roi_threshold_contours(gray)
        canvas = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
        cv2.drawContours(canvas, contours, -1, (255, 0, 0), 2)
        self.filtration_label.setPixmap(self._bgr_to_pixmap(canvas))

    def _apply_segmentation(self):
        """Show the area-filtered contours and update the volume figures.

        Contours whose area exceeds ``contours_thr`` times the largest contour
        area are kept.  The slice volume is their summed area multiplied by
        the voxel size; the total shown in the header is the sum over the
        slices for which a volume has been stored so far.
        """
        gray = self._grab_gray_np()
        if gray is None:
            return
        viewer = self.image_label
        contours = self._roi_threshold_contours(gray)
        areas = [cv2.contourArea(c) for c in contours]
        area_thr = self.contours_thr * max(areas, default=0)
        filtered = [c for c, area in zip(contours, areas) if area > area_thr]

        # Volume: in-plane spacing times the selected z-spacing, mm^3 -> ml.
        spacing_x, spacing_y = viewer.spacing_xy
        z = viewer.slice_thickness if self.use_slice_thickness else viewer.spacing_between_slices
        voxel_mm3 = float(spacing_x) * float(spacing_y) * float(z)
        pixel_area_sum = sum(area for area in areas if area > area_thr)
        volume_ml = pixel_area_sum * voxel_mm3 / 1000.0
        viewer.volume_by_slice[viewer.slice_index] = volume_ml
        total_ml = sum(viewer.volume_by_slice.values())
        self.pixel_count_label.setText(f"Суммарный объем (мл): {total_ml:.2f}")

        # Yellow (BGR (0, 255, 255)) filtered contours and per-slice caption.
        canvas = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
        cv2.drawContours(canvas, filtered, -1, (0, 255, 255), 2)
        cv2.putText(canvas,
                    f"Slice: {viewer.slice_index} Volume (ml): {volume_ml:.3f}",
                    (10, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 1)
        self.segmentation_label.setPixmap(self._bgr_to_pixmap(canvas))
def main():
    """Create the Qt application, show the main window and run the event loop.

    The process exits with the status returned by the event loop.
    """
    application = QApplication(sys.argv)
    window = ROIDrawer()
    window.show()
    exit_status = application.exec_()
    sys.exit(exit_status)


if __name__ == '__main__':
    main()
|