dicom_labeler.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553
  1. import sys
  2. import os
  3. import json
  4. import cv2
  5. import numpy as np
  6. from skimage.measure import find_contours
  7. from io import BytesIO
  8. import pydicom
  9. from PyQt5.QtWidgets import (
  10. QApplication, QMainWindow, QLabel, QPushButton, QColorDialog, QInputDialog,
  11. QFileDialog, QMessageBox, QVBoxLayout, QHBoxLayout, QWidget, QFrame, QSizePolicy, QComboBox
  12. )
  13. from PyQt5.QtGui import QImage, QPixmap, QPainter, QPen, QColor
  14. from PyQt5.QtCore import Qt, QPoint, QRect
  15. from PyQt5.QtWidgets import QSplitter, QGroupBox, QVBoxLayout
  16. from PyQt5.QtWidgets import QSlider
  class ImageLabel(QLabel):
      """Label that displays one DICOM slice at a time and lets the user draw ROIs.

      ROIs are kept per slice index in ``self.rois`` as dicts with the keys
      ``'label'``, ``'color'`` (QColor) and ``'points'`` (list of QPoint).
      Points are stored in un-zoomed widget coordinates: mouse positions are
      divided by ``zoom_factor`` on input and ``paintEvent`` scales by it once.
      """

      def __init__(self, parent=None):
          super().__init__(parent)
          self.setMouseTracking(True)
          # QLabel does not accept keyboard focus by default, so Ctrl+Z would
          # never reach keyPressEvent without this.
          self.setFocusPolicy(Qt.StrongFocus)
          self.image = QImage()
          self.rois = {}  # slice index -> list of ROI dicts
          self.current_roi = []  # points of the ROI currently being drawn
          self.drawing = False
          self.current_color = QColor(Qt.red)
          self.current_label = "ROI"
          self.undo_stack = []  # (slice index, ROI dict) in creation order
          self.slice_index = 0
          self.dicom_files = []
          self.dataset = None
          # Clickable square (top-left) that triggers mark_range_from_current().
          self.mark_area_rect = QRect(10, 10, 30, 30)
          self.marked_by_rule_slices = set()
          # Zoom parameters
          self.zoom_factor = 1.0
          self.zoom_step = 0.1
          self.min_zoom = 0.1
          self.max_zoom = 5.0

      def load_dicom_series(self, folder_path):
          """Use every regular file in *folder_path* as one slice and show the first.

          Files are ordered by name so the slice order is deterministic
          (``os.listdir`` order is arbitrary). All annotations of the previous
          series are discarded.

          Raises:
              ValueError: if the folder contains no files.
          """
          self.dicom_files = sorted(
              os.path.join(folder_path, name)
              for name in os.listdir(folder_path)
              if os.path.isfile(os.path.join(folder_path, name))
          )
          if not self.dicom_files:
              raise ValueError("No DICOM files"
                               " found in the selected folder.")
          # TODO: sort by InstanceNumber / SliceLocation instead of file name.
          self.slice_index = 0
          self.rois.clear()
          # Undo entries and rule marks refer to slices of the old series.
          self.undo_stack.clear()
          self.marked_by_rule_slices.clear()
          self.current_roi = []
          self.drawing = False
          self.load_slice()

      def load_slice(self):
          """Read the file at ``slice_index`` and display it as 8-bit grayscale."""
          if not 0 <= self.slice_index < len(self.dicom_files):
              return
          # force=True also reads files lacking the DICM preamble.
          self.dataset = pydicom.dcmread(self.dicom_files[self.slice_index], force=True)
          pixel_array = self.dataset.pixel_array
          # NOTE(review): assumes a single-frame 2-D grayscale slice — confirm
          # how multi-frame or RGB files should be handled.
          normalized = cv2.normalize(pixel_array, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX)
          normalized = np.ascontiguousarray(normalized.astype(np.uint8))
          height, width = normalized.shape[:2]
          # Pass the real row stride (QImage otherwise assumes 4-byte aligned
          # rows and skews odd widths) and copy() so the QImage owns its pixels
          # instead of pointing into the temporary numpy buffer.
          self.image = QImage(normalized.data, width, height, normalized.strides[0],
                              QImage.Format_Grayscale8).copy()
          self.setPixmap(QPixmap.fromImage(self.image))
          self.update()

      def next_slice(self):
          """Advance to the next slice, if any."""
          if self.slice_index < len(self.dicom_files) - 1:
              self.slice_index += 1
              self.load_slice()

      def previous_slice(self):
          """Go back to the previous slice, if any."""
          if self.slice_index > 0:
              self.slice_index -= 1
              self.load_slice()

      def set_current_color(self, color):
          """Set the QColor used for newly drawn ROIs."""
          self.current_color = color

      def set_current_label(self, label):
          """Set the text label attached to newly drawn ROIs."""
          self.current_label = label

      def mousePressEvent(self, event):
          """Left click: toggle rule marking in the mark area, otherwise start an ROI."""
          if event.button() != Qt.LeftButton:
              return
          scaled_pos = event.pos() / self.zoom_factor
          if self.mark_area_rect.contains(scaled_pos):
              if self.slice_index not in self.marked_by_rule_slices:
                  self.mark_range_from_current()
              self.update()
              return
          if self.image.isNull():
              return
          self.drawing = True
          self.current_roi = [scaled_pos]
          self.update()

      def mouseMoveEvent(self, event):
          """Extend the ROI being drawn with the (un-zoomed) cursor position."""
          if self.drawing:
              self.current_roi.append(event.pos() / self.zoom_factor)
              self.update()

      def mouseReleaseEvent(self, event):
          """Finish the current ROI; it is kept only if it has more than two points."""
          if event.button() != Qt.LeftButton or not self.drawing:
              return
          self.drawing = False
          if len(self.current_roi) > 2:
              roi = {
                  'label': self.current_label,
                  'color': self.current_color,
                  'points': self.current_roi
              }
              self.rois.setdefault(self.slice_index, []).append(roi)
              self.undo_stack.append((self.slice_index, roi))
          self.current_roi = []
          self.update()

      def paintEvent(self, event):
          """Draw the slice, its ROIs, the in-progress ROI and the mark-area button."""
          super().paintEvent(event)
          if self.image.isNull():
              return
          painter = QPainter(self)
          # One zoom transform for everything below: the mark-area hit test in
          # mousePressEvent divides by zoom_factor exactly once.
          painter.scale(self.zoom_factor, self.zoom_factor)
          painter.drawImage(self.rect(), self.image, self.image.rect())
          pen = QPen(Qt.red, 2, Qt.SolidLine)
          painter.setPen(pen)
          for roi in self.rois.get(self.slice_index, []):
              pen.setColor(roi['color'])
              painter.setPen(pen)
              painter.drawPolygon(*roi['points'])
          if self.current_roi:
              pen.setColor(Qt.blue)
              painter.setPen(pen)
              painter.drawPolyline(*self.current_roi)
          # Mark-area button: green outline, semi-transparent green fill.
          painter.setPen(QPen(Qt.green, 2))
          painter.setBrush(QColor(0, 255, 0, 100))
          painter.drawRect(self.mark_area_rect)
          if self.slice_index in self.marked_by_rule_slices:
              # Check mark for slices already marked by the rule.
              painter.drawLine(self.mark_area_rect.topLeft() + QPoint(5, 10),
                               self.mark_area_rect.center() + QPoint(0, 8))
              painter.drawLine(self.mark_area_rect.center() + QPoint(0, 8),
                               self.mark_area_rect.topRight() + QPoint(-4, 4))
          painter.end()

      def undo(self):
          """Remove the most recently added ROI (manual or auto-mark)."""
          if not self.undo_stack:
              return
          slice_idx, last_roi = self.undo_stack.pop()
          slice_rois = self.rois.get(slice_idx)
          if slice_rois and last_roi in slice_rois:
              slice_rois.remove(last_roi)
              if not slice_rois:
                  del self.rois[slice_idx]
          # Once a slice has no auto-mark left it may be marked again.
          if last_roi.get('auto') and not any(
                  roi.get('auto') for roi in self.rois.get(slice_idx, [])):
              self.marked_by_rule_slices.discard(slice_idx)
          self.update()

      def keyPressEvent(self, event):
          """Ctrl+Z undoes the last ROI; other keys are handled by QLabel."""
          if event.key() == Qt.Key_Z and event.modifiers() & Qt.ControlModifier:
              self.undo()
          else:
              super().keyPressEvent(event)

      def save_rois(self, file_path):
          """Write all ROIs to *file_path* as JSON.

          Top-level keys are slice indices (JSON turns them into strings); each
          ROI is stored as its label, colour name ('#rrggbb') and (x, y) points.
          """
          roi_data = {
              slice_idx: [
                  {
                      'label': roi['label'],
                      'color': roi['color'].name(),
                      'points': [(point.x(), point.y()) for point in roi['points']]
                  }
                  for roi in rois
              ]
              for slice_idx, rois in self.rois.items()
          }
          with open(file_path, 'w', encoding='utf-8') as file:
              json.dump(roi_data, file, indent=4)

      def mark_range_from_current(self):
          """Add an 'AutoMark' ROI to every slice from the current one to the nearer end.

          Slices in the first half (index <= len // 2) mark 0..current; the rest
          mark current..last. Already-marked slices are skipped.
          """
          total_slices = len(self.dicom_files)
          if total_slices == 0:
              return  # nothing loaded (e.g. the preview labels)
          current_index = self.slice_index
          if current_index <= total_slices // 2:
              indices_to_mark = range(0, current_index + 1)
          else:
              indices_to_mark = range(current_index, total_slices)
          for idx in indices_to_mark:
              if idx in self.marked_by_rule_slices:
                  continue  # do not stack duplicate auto-marks
              roi = {
                  'label': 'AutoMark',
                  'color': QColor(Qt.green),
                  'points': [
                      QPoint(10, 10), QPoint(40, 10),
                      QPoint(40, 40), QPoint(10, 40)
                  ],
                  'auto': True,  # lets undo() clear the rule mark
              }
              self.rois.setdefault(idx, []).append(roi)
              self.undo_stack.append((idx, roi))
              self.marked_by_rule_slices.add(idx)
          self.update()

      def wheelEvent(self, event):
          """Ctrl+wheel zooms within [min_zoom, max_zoom]; plain wheel changes slice."""
          angle = event.angleDelta().y()
          if angle == 0:
              return  # horizontal scroll: neither zoom nor navigation
          if event.modifiers() & Qt.ControlModifier:
              factor = 1.1 if angle > 0 else 0.9
              self.zoom_factor = min(self.max_zoom, max(self.min_zoom, self.zoom_factor * factor))
              self.update()
          elif angle > 0:
              self.previous_slice()
          else:
              self.next_slice()

      def zoom_in(self):
          """Increase zoom by zoom_step, never beyond max_zoom."""
          if self.zoom_factor < self.max_zoom:
              self.zoom_factor = min(self.max_zoom, self.zoom_factor + self.zoom_step)
              self.update()

      def zoom_out(self):
          """Decrease zoom by zoom_step, never below min_zoom."""
          if self.zoom_factor > self.min_zoom:
              self.zoom_factor = max(self.min_zoom, self.zoom_factor - self.zoom_step)
              self.update()
  class ROIDrawer(QMainWindow):
      """Main window: DICOM slice viewer with ROI drawing and threshold previews.

      Three panels sit side by side in a splitter: the annotatable slice
      (``image_label``), a filtration preview (``filtration_label``) and a
      segmentation preview (``segmentation_label``).
      """

      def __init__(self):
          super().__init__()
          # Threshold fractions; they match the sliders' initial positions
          # (50 -> 0.50, 30 -> 0.30) and are kept in sync by update_thresholds().
          self.threshold_brightness = 0.5
          self.contours_thr = 0.3
          self.initUI()

      def initUI(self):
          """Build widgets and layouts and connect the signals."""
          self.setWindowTitle("KneeSeg")
          screen_geometry = QApplication.desktop().availableGeometry()
          # Fixed-size window covering 90% of the available screen area.
          self.resize(int(screen_geometry.width() * 0.9), int(screen_geometry.height() * 0.9))
          self.setFixedSize(self.size())

          central_widget = QWidget(self)
          self.setCentralWidget(central_widget)

          # Threshold sliders in hundredths: 1..99 maps to 0.01..0.99.
          self.brightness_slider = self._make_threshold_slider(50)
          self.contour_slider = self._make_threshold_slider(30)

          # Sequence selector (placeholder entries for now).
          self.sequence_dropdown = QComboBox()
          self.sequence_dropdown.setFixedSize(250, 30)
          self.sequence_dropdown.setStyleSheet("font-size: 14px;")
          self.sequence_dropdown.addItems(["Sequence A", "Sequence B", "Sequence C"])

          threshold_layout = QVBoxLayout()
          brightness_label = QLabel("Порог яркости")  # "Brightness threshold"
          brightness_label.setStyleSheet("font-size: 14px;font-weight: bold;")
          contour_label = QLabel("Порог площади")  # "Area threshold"
          contour_label.setStyleSheet("font-size: 14px;font-weight: bold;")
          threshold_layout.addWidget(brightness_label)
          threshold_layout.addWidget(self.brightness_slider)
          threshold_layout.addWidget(contour_label)
          threshold_layout.addWidget(self.contour_slider)

          main_layout = QVBoxLayout(central_widget)
          top_controls_layout = QHBoxLayout()
          top_controls_layout.addStretch(1)  # push the dropdown to the right
          top_controls_layout.addWidget(self.sequence_dropdown)
          main_layout.addLayout(top_controls_layout)
          main_layout.addLayout(threshold_layout)

          # Three equally sized, expanding panels.
          splitter = QSplitter(Qt.Horizontal)
          self.image_label = ImageLabel(self)
          self.filtration_label = ImageLabel(self)
          self.segmentation_label = ImageLabel(self)
          panels = ((self.image_label, ""), (self.filtration_label, " "), (self.segmentation_label, ""))
          for panel, title in panels:
              panel.setAlignment(Qt.AlignCenter)
              splitter.addWidget(self.create_labeled_frame(title, panel))
              panel.setScaledContents(True)
              panel.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
          splitter.setSizes([1, 1, 1])
          for index in range(len(panels)):
              splitter.setStretchFactor(index, 1)

          button_layout = QHBoxLayout()
          main_layout.addWidget(splitter)
          main_layout.addLayout(button_layout)

          self._add_button(button_layout, 'Load DICOM ', 150, self.load_dicom_series,
                           style='QPushButton { font-size: 20px;}')
          # NOTE(review): '>>' steps to the PREVIOUS slice and '<<' to the NEXT
          # one (kept from the original) — confirm this mapping is intended.
          self._add_button(button_layout, '>>', 50, self.image_label.previous_slice)
          self._add_button(button_layout, '<<', 50, self.image_label.next_slice)
          self._add_button(button_layout, 'Select Color', 150, self.select_color)
          self._add_button(button_layout, 'Set Label', 100, self.set_label)
          self._add_button(button_layout, 'Undo', 100, self.image_label.undo)
          self._add_button(button_layout, 'Segmentation', 150, self._run_processing)
          self._add_button(button_layout, 'Save ROIs', 150, self.save_rois)
          button_layout.addStretch(1)

      def _make_threshold_slider(self, initial):
          """Create a horizontal 1..99 slider at *initial* wired to update_thresholds."""
          slider = QSlider(Qt.Horizontal)
          slider.setMinimum(1)
          slider.setMaximum(99)
          slider.setValue(initial)
          slider.setTickInterval(1)
          # Connected after setValue so construction does not emit into the slot.
          slider.valueChanged.connect(self.update_thresholds)
          return slider

      def _add_button(self, layout, text, width, slot, style='QPushButton { font-size: 20px; }'):
          """Create a 50-px-high push button, connect it to *slot* and add it to *layout*."""
          button = QPushButton(text, self)
          button.setFixedSize(width, 50)
          button.setStyleSheet(style)
          button.clicked.connect(slot)
          layout.addWidget(button)
          return button

      def update_thresholds(self):
          """Update segmentation thresholds based on slider values."""
          self.threshold_brightness = self.brightness_slider.value() / 100
          self.contours_thr = self.contour_slider.value() / 100

      def create_labeled_frame(self, title, widget):
          """Wrap *widget* in a QFrame with a bold, centred *title* above it."""
          frame = QFrame()
          layout = QVBoxLayout()
          label = QLabel(f"<b>{title}</b>")
          label.setAlignment(Qt.AlignCenter)
          layout.addWidget(label)
          layout.addWidget(widget)
          frame.setLayout(layout)
          return frame

      def load_dicom_series(self):
          """Ask for a folder and load it into the main image panel."""
          options = QFileDialog.Options()
          folder_path = QFileDialog.getExistingDirectory(
              self, "Select DICOM Series Folder", options=options)
          if not folder_path:
              return
          try:
              self.image_label.load_dicom_series(folder_path)
          except Exception as e:  # UI boundary: report any read failure instead of crashing
              QMessageBox.critical(self, "Error", str(e))

      def select_color(self):
          """Let the user pick the colour for new ROIs."""
          color = QColorDialog.getColor()
          if color.isValid():
              self.image_label.set_current_color(color)

      def set_label(self):
          """Let the user enter the text label for new ROIs."""
          label, ok = QInputDialog.getText(self, 'Set ROI Label', 'Enter label for ROI:')
          if ok and label:
              self.image_label.set_current_label(label)

      def save_rois(self):
          """Ask for a target file and save the main panel's ROIs as JSON."""
          options = QFileDialog.Options()
          file_path, _ = QFileDialog.getSaveFileName(
              self, "Save ROIs", "", "JSON Files (*.json);;All Files (*)", options=options)
          if file_path:
              self.image_label.save_rois(file_path)

      def _run_processing(self):
          """Run segmentation and filtration, warning only once if no image is loaded."""
          if self.image_label.image.isNull():
              QMessageBox.warning(self, "Warning", "No image loaded.")
              return
          self.apply_segmentation()
          self.apply_filtration()

      def _image_as_array(self):
          """Return the displayed slice as a contiguous (height, width) uint8 array."""
          image = self.image_label.image.convertToFormat(QImage.Format_Grayscale8)
          width, height = image.width(), image.height()
          stride = image.bytesPerLine()
          ptr = image.bits()
          ptr.setsize(stride * height)
          # Rows are padded to bytesPerLine (4-byte aligned); drop the padding.
          rows = np.frombuffer(ptr, dtype=np.uint8).reshape(height, stride)
          return np.ascontiguousarray(rows[:, :width])

      def _binary_mask(self, gray):
          """255 where *gray* >= threshold_brightness * max(gray), else 0 (uint8)."""
          threshold = self.threshold_brightness * np.max(gray)
          return np.where(gray >= threshold, 255, 0).astype(np.uint8)

      @staticmethod
      def _rgb_to_qimage(rgb):
          """Copy an (h, w, 3) uint8 RGB array into a QImage that owns its pixels."""
          rgb = np.ascontiguousarray(rgb)
          return QImage(rgb.data, rgb.shape[1], rgb.shape[0], rgb.strides[0],
                        QImage.Format_RGB888).copy()

      def apply_filtration(self):
          """Show the outer contours of the brightness mask in red in the filtration panel."""
          if self.image_label.image.isNull():
              QMessageBox.warning(self, "Warning", "No image loaded.")
              return
          gray = self._image_as_array()
          mask = self._binary_mask(gray)
          # [-2] picks the contour list under both OpenCV 3 and OpenCV 4 return shapes.
          contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
          color_image = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
          cv2.drawContours(color_image, contours, -1, (255, 0, 0), 2)  # red, thickness 2
          self.filtration_label.setPixmap(QPixmap.fromImage(self._rgb_to_qimage(color_image)))

      def apply_segmentation(self):
          """Show mask contours larger than contours_thr * largest area in yellow."""
          if self.image_label.image.isNull():
              QMessageBox.warning(self, "Warning", "No image loaded.")
              return
          gray = self._image_as_array()
          mask = self._binary_mask(gray)
          contours = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
          areas = [cv2.contourArea(cnt) for cnt in contours]
          area_threshold = self.contours_thr * max(areas) if areas else 0
          filtered_contours = [cnt for cnt, area in zip(contours, areas) if area > area_threshold]
          color_image = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
          cv2.drawContours(color_image, filtered_contours, -1, (255, 255, 0), 2)  # yellow in RGB
          self.segmentation_label.setPixmap(QPixmap.fromImage(self._rgb_to_qimage(color_image)))
  if __name__ == '__main__':
      # Start the Qt event loop with the main window and exit with its status.
      qt_app = QApplication(sys.argv)
      main_window = ROIDrawer()
      main_window.show()
      exit_code = qt_app.exec_()
      sys.exit(exit_code)