瀏覽代碼

initial adding project files

spacexerq 3 周之前
父節點
當前提交
561514986d
共有 8 個文件被更改,包括 1399 次插入0 次删除
  1. 1 0
      .gitignore
  2. 477 0
      src/dicom_labeler.py
  3. 80 0
      src/draw.py
  4. 89 0
      src/draw_multi.py
  5. 154 0
      src/imgsegclassic.py
  6. 178 0
      src/labeler.py
  7. 67 0
      src/read_data.py
  8. 353 0
      src/shape_draw.py

+ 1 - 0
.gitignore

@@ -0,0 +1 @@
+/DATA/

+ 477 - 0
src/dicom_labeler.py

@@ -0,0 +1,477 @@
+import sys
+import os
+import json
+import cv2
+import numpy as np
+from skimage.measure import find_contours
+
+import pydicom
+from PyQt5.QtWidgets import (
+    QApplication, QMainWindow, QLabel, QPushButton, QColorDialog, QInputDialog,
+    QFileDialog, QMessageBox, QVBoxLayout, QHBoxLayout, QWidget, QFrame, QSizePolicy
+)
+from PyQt5.QtGui import QImage, QPixmap, QPainter, QPen, QColor
+from PyQt5.QtCore import Qt, QPoint
+from PyQt5.QtWidgets import QSplitter, QGroupBox, QVBoxLayout
+from PyQt5.QtWidgets import QSlider
+
+
+class ImageLabel(QLabel):
+    def __init__(self, parent=None):
+        super().__init__(parent)
+        self.setMouseTracking(True)
+        self.image = QImage()
+        self.rois = {}
+        self.current_roi = []
+        self.drawing = False
+        self.current_color = QColor(Qt.red)
+        self.current_label = "ROI"
+        self.undo_stack = []
+        self.slice_index = 0
+        self.dicom_files = []
+        self.dataset = None
+
+        # Zoom parameters
+        self.zoom_factor = 1.0
+        self.zoom_step = 0.1
+        self.min_zoom = 0.1
+        self.max_zoom = 5.0
+
+    def load_dicom_series(self, folder_path):
+        # Load all DICOM files from the folder
+        self.dicom_files = [os.path.join(folder_path, f) for f in os.listdir(folder_path)] # if f.lower().endswith('.dcm')
+        if not self.dicom_files:
+            raise ValueError("No DICOM files found in the selected folder.")
+
+        # Sort files by InstanceNumber or SliceLocation
+        # self.dicom_files.sort(key=lambda f: int(pydicom.dcmread(f).InstanceNumber))
+
+        self.slice_index = 0
+        self.rois.clear()
+        self.load_slice()
+
+    def load_slice(self):
+        if 0 <= self.slice_index < len(self.dicom_files):
+            self.dataset = pydicom.dcmread(self.dicom_files[self.slice_index], force=True)
+            print(self.dataset['0008', '103e'])
+            # if 'pd_tse_fs_sag' in self.dataset['0008', '103e'][0:9]:
+            pixel_array = self.dataset.pixel_array
+            print(f'This file comprises {pixel_array.shape[0]} slices. File"s name is ', self.dataset['0008', '103e'],self.dicom_files[self.slice_index])
+            image_normalized = cv2.normalize(pixel_array, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX)
+            image_normalized = image_normalized.astype(np.uint8)
+
+            # Normalize pixel array to 8-bit
+            image = QImage(image_normalized.data, pixel_array.shape[1], pixel_array.shape[0],
+                           QImage.Format_Grayscale8)
+            self.image = image
+            self.setPixmap(QPixmap.fromImage(self.image))
+            self.update()
+            # else:
+            #     pass
+
+
+    def next_slice(self):
+        if self.slice_index < len(self.dicom_files) - 1:
+            self.slice_index += 1
+            self.load_slice()
+
+
+    def previous_slice(self):
+        if self.slice_index > 0:
+            self.slice_index -= 1
+            self.load_slice()
+
+    def set_current_color(self, color):
+        self.current_color = color
+
+    def set_current_label(self, label):
+        self.current_label = label
+
+    def mousePressEvent(self, event):
+        if event.button() == Qt.LeftButton:
+            self.drawing = True
+            self.current_roi = [event.pos()]
+            self.update()
+
+    def mouseMoveEvent(self, event):
+        if self.drawing:
+            self.current_roi.append(event.pos())
+            self.update()
+
+    def mouseReleaseEvent(self, event):
+        if event.button() == Qt.LeftButton and self.drawing:
+            self.drawing = False
+            if len(self.current_roi) > 2:
+                if self.slice_index not in self.rois:
+                    self.rois[self.slice_index] = []
+                roi = {
+                    'label': self.current_label,
+                    'color': self.current_color,
+                    'points': self.current_roi
+                }
+                self.rois[self.slice_index].append(roi)
+                self.undo_stack.append((self.slice_index, roi))
+            self.current_roi = []
+            self.update()
+
+    def paintEvent(self, event):
+        super().paintEvent(event)
+        if not self.image.isNull():
+            painter = QPainter(self)
+            # Apply zoom transformation
+            transform = painter.transform()
+            transform.scale(self.zoom_factor, self.zoom_factor)
+            painter.setTransform(transform)
+
+            painter.drawImage(self.rect(), self.image, self.image.rect())
+            pen = QPen(Qt.red, 2, Qt.SolidLine)
+            painter.setPen(pen)
+            if self.slice_index in self.rois:
+                for roi in self.rois[self.slice_index]:
+                    pen.setColor(roi['color'])
+                    painter.setPen(pen)
+                    points = [QPoint(p.x(), p.y()) for p in roi['points']]
+                    painter.drawPolygon(*points)
+            if self.current_roi:
+                pen.setColor(Qt.blue)
+                painter.setPen(pen)
+                points = [QPoint(p.x(), p.y()) for p in self.current_roi]
+                painter.drawPolyline(*points)
+            painter.end()
+
+    def undo(self):
+        if self.undo_stack:
+            slice_idx, last_roi = self.undo_stack.pop()
+            if slice_idx in self.rois and last_roi in self.rois[slice_idx]:
+                self.rois[slice_idx].remove(last_roi)
+                if not self.rois[slice_idx]:
+                    del self.rois[slice_idx]
+            self.update()
+
+    def save_rois(self, file_path):
+        roi_data = {}
+        for slice_idx, rois in self.rois.items():
+            roi_data[slice_idx] = []
+            for roi in rois:
+                roi_data[slice_idx].append({
+                    'label': roi['label'],
+                    'color': roi['color'].name(),
+                    'points': [(point.x(), point.y()) for point in roi['points']]
+                })
+        with open(file_path, 'w') as file:
+            json.dump(roi_data, file, indent=4)
+
+
+
+    def wheelEvent(self, event):
+        if event.modifiers() & Qt.ControlModifier:
+            # Zooming
+            angle = event.angleDelta().y()
+            factor = 1.1 if angle > 0 else 0.9
+            self.zoom_factor *= factor
+            self.update()
+        else:
+            # Slice navigation
+            angle = event.angleDelta().y()
+            if angle > 0:
+                self.previous_slice()
+            else:
+                self.next_slice()
+
+    def zoom_in(self):
+        if self.zoom_factor < self.max_zoom:
+            self.zoom_factor += self.zoom_step
+            self.update()
+
+    def zoom_out(self):
+        if self.zoom_factor > self.min_zoom:
+            self.zoom_factor -= self.zoom_step
+            self.update()
+
+
+class ROIDrawer(QMainWindow):
+    def __init__(self):
+        super().__init__()
+        self.initUI()
+        self.threshold_brightness = 0.5  # Default values
+        self.contours_thr = 0.3
+
+    def initUI(self):
+        self.setWindowTitle("KneeSeg")
+        screen_geometry = QApplication.desktop().availableGeometry()
+        screen_width = screen_geometry.width()
+        screen_height = screen_geometry.height()
+        # Set window to 70% of the screen size
+        self.resize(int(screen_width * 0.9), int(screen_height * 0.9))
+        self.setFixedSize(self.size())
+
+        # Central widget
+        central_widget = QWidget(self)
+        self.setCentralWidget(central_widget)
+
+        # Create threshold sliders
+        self.brightness_slider = QSlider(Qt.Horizontal)
+        self.brightness_slider.setMinimum(1)  # 1 corresponds to 0.01
+        self.brightness_slider.setMaximum(99)  # 99 corresponds to 0.99
+        self.brightness_slider.setValue(50)  # Default to 0.50
+        self.brightness_slider.setTickInterval(1)
+        self.brightness_slider.valueChanged.connect(self.update_thresholds)
+
+        self.contour_slider = QSlider(Qt.Horizontal)
+        self.contour_slider.setMinimum(1)
+        self.contour_slider.setMaximum(99)
+        self.contour_slider.setValue(30)  # Default to 0.30
+        self.contour_slider.setTickInterval(1)
+        self.contour_slider.valueChanged.connect(self.update_thresholds)
+
+        # Add sliders to the UI layout
+        threshold_layout = QVBoxLayout()
+        brightness_label = QLabel("Порог яркости")
+        brightness_label.setStyleSheet("font-size: 14px;font-weight: bold;")
+
+        contour_label = QLabel("Порог площади")
+        contour_label.setStyleSheet("font-size: 14px;font-weight: bold;")
+
+        threshold_layout.addWidget(brightness_label)
+        threshold_layout.addWidget(self.brightness_slider)
+        threshold_layout.addWidget(contour_label)
+        threshold_layout.addWidget(self.contour_slider)
+
+        # threshold_layout = QVBoxLayout()
+        # threshold_layout.addWidget(QLabel("Brightness Threshold"))
+        # threshold_layout.addWidget(self.brightness_slider)
+        # threshold_layout.addWidget(QLabel("Contour Threshold"))
+        # threshold_layout.addWidget(self.contour_slider)
+
+        # Layouts
+        main_layout = QVBoxLayout(central_widget)
+        main_layout.addLayout(threshold_layout)
+
+        # top_layout = QHBoxLayout()
+        # bottom_layout = QHBoxLayout()
+
+        # Splitter
+        splitter = QSplitter(Qt.Horizontal)
+
+        # Image display with ROI drawing
+        self.image_label = ImageLabel(self)
+        self.image_label.setAlignment(Qt.AlignCenter)
+        image_frame = self.create_labeled_frame("", self.image_label)
+        splitter.addWidget(image_frame)
+
+        # Filtration result display
+        self.filtration_label = ImageLabel(self) #QLabel("Filtration results will be displayed here.")
+        self.filtration_label.setAlignment(Qt.AlignCenter)
+        filtration_frame = self.create_labeled_frame(" ", self.filtration_label)
+        splitter.addWidget(filtration_frame)
+
+        # Segmentation result display
+        self.segmentation_label = ImageLabel(self) # QLabel("Segmentation results will be displayed here.")
+        self.segmentation_label.setAlignment(Qt.AlignCenter)
+        # segmentation_frame = self.create_labeled_frame("Image Segmentation ", self.segmentation_label)
+        segmentation_frame = self.create_labeled_frame("", self.segmentation_label)
+
+        splitter.addWidget(segmentation_frame)
+
+        #Enable Scaled Contents for Each
+        self.image_label.setScaledContents(True)
+        self.filtration_label.setScaledContents(True)
+        self.segmentation_label.setScaledContents(True)
+
+        # # Set size policies to allow resizing
+        self.image_label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
+        self.filtration_label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
+        self.segmentation_label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
+        #
+        # # Set initial sizes for splitter sections to be equal
+        splitter.setSizes([1, 1, 1])
+        splitter.setStretchFactor(0, 1)  # First widget
+        splitter.setStretchFactor(1, 1)  # Second widget
+        splitter.setStretchFactor(2, 1)  # Third widget
+
+        # Buttons layout
+        button_layout = QHBoxLayout()
+
+        # Add splitter to the main layout
+        main_layout.addWidget(splitter)
+        main_layout.addLayout(button_layout)
+
+        # Buttons
+        load_button = QPushButton('Load DICOM ', self)
+        load_button.setFixedSize(150, 50)
+        load_button.setStyleSheet('QPushButton { font-size: 20px;}')
+        load_button.clicked.connect(self.load_dicom_series)
+        button_layout.addWidget(load_button)
+
+        color_button = QPushButton('Select Color', self)
+        color_button.setFixedSize(150, 50)
+        color_button.setStyleSheet('QPushButton { font-size: 20px; }')
+        color_button.clicked.connect(self.select_color)
+        button_layout.addWidget(color_button)
+
+        label_button = QPushButton('Set Label', self)
+        label_button.setFixedSize(100, 50)
+        label_button.setStyleSheet('QPushButton { font-size: 20px; }')
+        label_button.clicked.connect(self.set_label)
+        button_layout.addWidget(label_button)
+
+        undo_button = QPushButton('Undo', self)
+        undo_button.setFixedSize(100, 50)
+        undo_button.setStyleSheet('QPushButton { font-size: 20px; }')
+        undo_button.clicked.connect(self.image_label.undo)
+        button_layout.addWidget(undo_button)
+
+        prev_button = QPushButton('>>', self) # Previous Slice
+        prev_button.setFixedSize(50, 50)
+        prev_button.setStyleSheet('QPushButton { font-size: 20px; }')
+        prev_button.clicked.connect(self.image_label.previous_slice)
+        button_layout.addWidget(prev_button)
+
+        next_button = QPushButton('<<', self) # Next Slice
+        next_button.setFixedSize(50, 50)
+        next_button.setStyleSheet('QPushButton { font-size: 20px; }')
+        next_button.clicked.connect(self.image_label.next_slice)
+        button_layout.addWidget(next_button)
+
+       # Filtration and Segmentation Buttons
+       #  filtration_button = QPushButton(' Filtration', self)
+       #  filtration_button.setFixedSize(150, 50)
+       #  filtration_button.setStyleSheet('QPushButton { font-size: 20px; }')
+       #  filtration_button.clicked.connect(self.apply_filtration)
+       #  button_layout.addWidget(filtration_button)
+
+        segmentation_button = QPushButton('Segmentation', self)
+        segmentation_button.setFixedSize(150, 50)
+        segmentation_button.setStyleSheet('QPushButton { font-size: 20px; }')
+        segmentation_button.clicked.connect(lambda: (self.apply_segmentation(), self.apply_filtration()))
+        button_layout.addWidget(segmentation_button)
+
+        save_button = QPushButton('Save ROIs', self)
+        save_button.setFixedSize(150, 50)
+        save_button.setStyleSheet('QPushButton { font-size: 20px; }')
+        save_button.clicked.connect(self.save_rois)
+        button_layout.addWidget(save_button)
+
+        button_layout.addStretch(1)  # Push buttons to the top
+
+    def update_thresholds(self):
+        """Update segmentation thresholds based on slider values."""
+        self.threshold_brightness = self.brightness_slider.value() / 100
+        self.contours_thr = self.contour_slider.value() / 100
+
+    def create_labeled_frame(self, title, widget):
+        frame = QFrame()
+        layout = QVBoxLayout()
+        label = QLabel(f"<b>{title}</b>")
+        label.setAlignment(Qt.AlignCenter)
+        layout.addWidget(label)
+        layout.addWidget(widget)
+        frame.setLayout(layout)
+        return frame
+    def load_dicom_series(self):
+        options = QFileDialog.Options()
+        folder_path = QFileDialog.getExistingDirectory(
+            self, "Select DICOM Series Folder", options=options)
+        if folder_path:
+            try:
+                self.image_label.load_dicom_series(folder_path)
+                a = 10
+            except ValueError as e:
+                QMessageBox.critical(self, "Error", str(e))
+
+    def select_color(self):
+        color = QColorDialog.getColor()
+        if color.isValid():
+            self.image_label.set_current_color(color)
+
+    def set_label(self):
+        label, ok = QInputDialog.getText(self, 'Set ROI Label', 'Enter label for ROI:')
+        if ok and label:
+            self.image_label.set_current_label(label)
+
+    def save_rois(self):
+        options = QFileDialog.Options()
+        file_path, _ = QFileDialog.getSaveFileName(
+            self, "Save ROIs", "", "JSON Files (*.json);;All Files (*)", options=options)
+        if file_path:
+            self.image_label.save_rois(file_path)
+
+
+    def apply_filtration(self):
+        if self.image_label.image.isNull():
+            QMessageBox.warning(self, "Warning", "No image loaded.")
+            return
+
+        # Convert QImage to numpy array
+        image = self.image_label.image
+        width = image.width()
+        height = image.height()
+        ptr = image.bits()
+        ptr.setsize(image.byteCount())
+        arr = np.array(ptr).reshape(height, width, 1)  # Assuming grayscale
+
+        # Apply user-defined thresholds
+        threshold_brightness = self.threshold_brightness * np.max(arr)
+        segmented_im = np.copy(arr)
+        segmented_im[segmented_im < threshold_brightness] = 0
+        segmented_im[segmented_im >= threshold_brightness] = 255  # Binary mask
+
+        # Find contours in the binary mask
+        contours, _ = cv2.findContours(segmented_im.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+        # Convert grayscale image to BGR for color overlay
+        color_image = cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
+
+        # Draw contours on the color image
+        cv2.drawContours(color_image, contours, -1, (255, 0, 0), 2)  # Red color, thickness 2
+
+        # Convert back to QImage
+        q_image = QImage(color_image.data, color_image.shape[1], color_image.shape[0], QImage.Format_RGB888)
+        self.filtration_label.setPixmap(QPixmap.fromImage(q_image))
+
+    def apply_segmentation(self):
+        if self.image_label.image.isNull():
+            QMessageBox.warning(self, "Warning", "No image loaded.")
+            return
+
+        # Convert QImage to numpy array
+        image = self.image_label.image
+        width, height = image.width(), image.height()
+        ptr = image.bits()
+        ptr.setsize(image.byteCount())
+        arr = np.array(ptr).reshape(height, width, 1)  # Assuming grayscale
+
+        # Apply user-defined thresholds
+        threshold_brightness = self.threshold_brightness * np.max(arr)
+        segmented_im = np.copy(arr)
+        segmented_im[segmented_im < threshold_brightness] = 0
+        segmented_im[segmented_im >= threshold_brightness] = 255  # Binary mask
+
+        # Find contours
+        contours, _ = cv2.findContours(segmented_im.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+        contours_square = [cv2.contourArea(cnt) for cnt in contours]
+
+        # Apply user-defined contour threshold
+        if contours_square:
+            contours_thr = self.contours_thr * np.max(contours_square)
+        else:
+            contours_thr = 0
+
+        filtered_contours = [cnt for cnt in contours if cv2.contourArea(cnt) > contours_thr]
+
+        # Convert grayscale image to BGR for color overlay
+        color_image = cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
+        cv2.drawContours(color_image, filtered_contours, -1, (0, 255, 255), 2)  # Draw yellow contours
+
+        # Convert back to QImage
+        q_image = QImage(color_image.data, color_image.shape[1], color_image.shape[0], QImage.Format_RGB888)
+        self.segmentation_label.setPixmap(QPixmap.fromImage(q_image))
+
+
+
+if __name__ == '__main__':
+    app = QApplication(sys.argv)
+    drawer = ROIDrawer()
+    drawer.show()
+    sys.exit(app.exec_())
+

+ 80 - 0
src/draw.py

@@ -0,0 +1,80 @@
+import matplotlib.pyplot as plt
+import pydicom
+import cv2
+import numpy as np
+from skimage.measure import find_contours
+
+
+# def get_mask_contour(mask):
+#     mskray = mask.astype(np.uint8)
+#     edged_mask = cv2.Canny(mskray, np.min(mskray), np.max(mskray))
+#     return edged_mask
+
+# Global variables
+drawing = False  # True when the mouse button is pressed
+points = []  # Stores points of the ROI
+mask = None  # To store the final mask
+
+def dcm2cv (dicom_img):
+    # Normalize the pixel values to the range 0-255
+    image_normalized = cv2.normalize(dicom_img, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX)
+    # Convert to unsigned 8-bit integer type
+    image_normalized = image_normalized.astype(np.uint8)
+    cv_image = cv2.merge([image_normalized, image_normalized, image_normalized])
+    return cv_image
+# Mouse callback function
+def draw_roi(event, x, y, flags, param):
+    global drawing, points, mask
+
+    if event == cv2.EVENT_LBUTTONDOWN:  # Start drawing
+        drawing = True
+        points.append((x, y))
+
+    elif event == cv2.EVENT_MOUSEMOVE and drawing:  # Add points while moving
+        points.append((x, y))
+
+    elif event == cv2.EVENT_LBUTTONUP:  # Stop drawing
+        drawing = False
+
+    elif event == cv2.EVENT_RBUTTONDOWN:  # Right-click to complete the ROI
+        print('Drawing Complete')
+        if len(points) > 2:
+            mask = np.zeros_like(image, dtype=np.uint8)
+            pts = np.array(points, np.int32).reshape((-1, 1, 2))
+            cv2.fillPoly(mask, [pts], (255, 255, 255))
+            contours = find_contours(mask[:, :, 0], level=0.5)
+            # plt.imshow(image)
+            # for cn in range(len(contours)):
+            #     plt.plot(contours[cn][:, 1], contours[cn][:, 0], 'r')
+            # plt.show()
+
+
+# Load an image (replace this with actual MR image loading)
+path = r'E:\projects\knee_seg\data_tamplet\PA0\ST0\SE4\IM40'
+file_dcm = pydicom.read_file(path)
+image = file_dcm.pixel_array
+image = dcm2cv(image)
+
+
+cv2.namedWindow("Image")
+cv2.setMouseCallback("Image", draw_roi)
+
+while True:
+    temp_img = image.copy()
+    print(temp_img.shape)
+    if points:
+        cv2.polylines(temp_img, [np.array(points)], isClosed=False, color=(255, 0, 0), thickness=2)
+    cv2.imshow("Image", temp_img)
+
+    key = cv2.waitKey(1) & 0xFF
+    if key == 13 and mask is not None:  # Press 'enter' to save the mask
+        cv2.imwrite("mask.png", mask)
+        print("Mask saved as mask.png")
+        break
+    elif key == 27:  # Press 'Esc' to exit
+        # print(mask)
+        cv2.imwrite("mask.png", mask)
+        print("Break: Mask saved  ")
+        break
+
+cv2.destroyAllWindows()

+ 89 - 0
src/draw_multi.py

@@ -0,0 +1,89 @@
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.patches import Polygon
+from matplotlib.widgets import LassoSelector
+from matplotlib.path import Path
+import json
+import pydicom
+
+class ROIManager:
+    def __init__(self, ax, image):
+        self.ax = ax
+        self.image = image
+        self.rois = []
+        self.current_poly = None
+        self.history = []
+        self.colors = plt.cm.get_cmap('tab10')
+        self.color_index = 0
+
+        self.ax.imshow(self.image, cmap='gray')
+        self.lasso = LassoSelector(ax, onselect=self.on_select)
+        self.cid = self.ax.figure.canvas.mpl_connect('key_press_event', self.on_key_press)
+
+    def on_select(self, verts):
+        path = Path(verts)
+        color = self.colors(self.color_index % 10)
+        self.color_index += 1
+        patch = Polygon(verts, closed=True, edgecolor=color, facecolor='none', label=f'ROI {len(self.rois) + 1}')
+        self.rois.append({'path': path, 'patch': patch, 'verts': verts})
+        self.ax.add_patch(patch)
+        self.history.append(('add', patch))
+        self.ax.legend()
+        self.ax.figure.canvas.draw_idle()
+
+    def on_key_press(self, event):
+        if event.key == 'u':  # Undo
+            if self.history:
+                action, patch = self.history.pop()
+                if action == 'add':
+                    patch.remove()
+                    self.rois = [roi for roi in self.rois if roi['patch'] != patch]
+                elif action == 'remove':
+                    self.ax.add_patch(patch)
+                    self.rois.append({'path': patch.get_path(), 'patch': patch})
+                self.ax.legend()
+                self.ax.figure.canvas.draw_idle()
+        elif event.key == 'd':  # Delete selected ROI
+            if self.current_poly:
+                confirm = input(f"Do you want to delete {self.current_poly.get_label()}? (y/n): ")
+                if confirm.lower() == 'y':
+                    self.current_poly.remove()
+                    self.rois = [roi for roi in self.rois if roi['patch'] != self.current_poly]
+                    self.history.append(('remove', self.current_poly))
+                    self.current_poly = None
+                    self.ax.legend()
+                    self.ax.figure.canvas.draw_idle()
+        elif event.key == 's':  # Save ROIs
+            self.save_rois()
+
+    def select_roi(self, event):
+        for roi in self.rois:
+            if roi['path'].contains_point((event.xdata, event.ydata)):
+                self.current_poly = roi['patch']
+                break
+
+    def save_rois(self):
+        roi_data = [{'label': roi['patch'].get_label(), 'vertices': roi['verts']} for roi in self.rois]
+        with open('rois.json', 'w') as f:
+            json.dump(roi_data, f)
+        print("ROIs saved to rois.json")
+    def get_rois_data(self):
+        roi_data = [{'label': roi['patch'].get_label(), 'vertices': roi['verts']} for roi in self.rois]
+        return roi_data
+
+
+def main():
+    image = np.zeros((512, 512))  # Placeholder for MR image
+    # Load an image (replace this with actual MR image loading)
+    path = r'C:\Users\user\Desktop\knee_seg\LITVYAK_D.I\LITVYAK_D.I\2025-01-22 181006\IMG-0001-00001.dcm'
+    file_dcm = pydicom.dcmread(path)
+    image = file_dcm.pixel_array
+
+    fig, ax = plt.subplots()
+    roi_manager = ROIManager(ax, image)
+    fig.canvas.mpl_connect('button_press_event', roi_manager.select_roi)
+    plt.show()
+    print(roi_manager.get_rois_data())
+
+if __name__ == "__main__":
+    main()

+ 154 - 0
src/imgsegclassic.py

@@ -0,0 +1,154 @@
+import sys
+import cv2
+import numpy as np
+from PyQt5.QtWidgets import (
+    QApplication, QMainWindow, QLabel, QVBoxLayout, QHBoxLayout, QPushButton,
+    QSlider, QFileDialog, QWidget, QComboBox, QSpinBox
+)
+from PyQt5.QtCore import Qt
+from PyQt5.QtGui import QPixmap, QImage
+from skimage.segmentation import active_contour
+from scipy import ndimage as ndi
+from skimage.filters import threshold_otsu
+from skimage.segmentation import watershed
+from skimage.feature import peak_local_max
+
+class ImageSegmentationApp(QMainWindow):
+    def __init__(self):
+        super().__init__()
+        self.setWindowTitle("Image Segmentation Explorer")
+        self.image = None
+        self.segmented_image = None
+
+        self.initUI()
+
+    def initUI(self):
+        # Main layout
+        main_layout = QVBoxLayout()
+
+        # Image display
+        self.image_label = QLabel()
+        self.image_label.setAlignment(Qt.AlignCenter)
+        main_layout.addWidget(self.image_label)
+
+        # Controls layout
+        controls_layout = QHBoxLayout()
+
+        # Load image button
+        load_button = QPushButton("Load Image")
+        load_button.clicked.connect(self.load_image)
+        controls_layout.addWidget(load_button)
+
+        # Segmentation method selection
+        self.method_combo = QComboBox()
+        self.method_combo.addItems([
+            "Thresholding", "Active Contours", "Region-Based", "Watershed"
+        ])
+        self.method_combo.currentIndexChanged.connect(self.update_parameters)
+        controls_layout.addWidget(self.method_combo)
+
+        # Parameter controls
+        self.param_label = QLabel("Parameter:")
+        controls_layout.addWidget(self.param_label)
+
+        self.param_slider = QSlider(Qt.Horizontal)
+        self.param_slider.setMinimum(0)
+        self.param_slider.setMaximum(255)
+        self.param_slider.setValue(128)
+        self.param_slider.valueChanged.connect(self.apply_segmentation)
+        controls_layout.addWidget(self.param_slider)
+
+        self.param_spinbox = QSpinBox()
+        self.param_spinbox.setMinimum(0)
+        self.param_spinbox.setMaximum(255)
+        self.param_spinbox.setValue(128)
+        self.param_spinbox.valueChanged.connect(self.param_slider.setValue)
+        self.param_slider.valueChanged.connect(self.param_spinbox.setValue)
+        controls_layout.addWidget(self.param_spinbox)
+
+        main_layout.addLayout(controls_layout)
+
+        # Apply segmentation button
+        apply_button = QPushButton("Apply Segmentation")
+        apply_button.clicked.connect(self.apply_segmentation)
+        main_layout.addWidget(apply_button)
+
+        # Set central widget
+        container = QWidget()
+        container.setLayout(main_layout)
+        self.setCentralWidget(container)
+
+    def load_image(self):
+        options = QFileDialog.Options()
+        file_name, _ = QFileDialog.getOpenFileName(
+            self, "Open Image File", "", "Images (*.png *.jpg *.bmp)", options=options
+        )
+        if file_name:
+            self.image = cv2.imread(file_name, cv2.IMREAD_GRAYSCALE)
+            self.display_image(self.image)
+
+    def display_image(self, image):
+        height, width = image.shape
+        bytes_per_line = width
+        q_image = QImage(image.data, width, height, bytes_per_line, QImage.Format_Grayscale8)
+        pixmap = QPixmap.fromImage(q_image)
+        self.image_label.setPixmap(pixmap.scaled(
+            self.image_label.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation
+        ))
+
+    def update_parameters(self):
+        method = self.method_combo.currentText()
+        if method == "Thresholding":
+            self.param_label.setText("Threshold:")
+            self.param_slider.setMaximum(255)
+            self.param_slider.setValue(128)
+        elif method == "Active Contours":
+            self.param_label.setText("Iterations:")
+            self.param_slider.setMaximum(500)
+            self.param_slider.setValue(100)
+        elif method == "Region-Based":
+            self.param_label.setText("Sigma:")
+            self.param_slider.setMaximum(10)
+            self.param_slider.setValue(2)
+        elif method == "Watershed":
+            self.param_label.setText("Markers:")
+            self.param_slider.setMaximum(10)
+            self.param_slider.setValue(2)
+
+    def apply_segmentation(self):
+        if self.image is None:
+            return
+
+        method = self.method_combo.currentText()
+        param_value = self.param_slider.value()
+
+        if method == "Thresholding":
+            _, self.segmented_image = cv2.threshold(
+                self.image, param_value, 255, cv2.THRESH_BINARY
+            )
+        elif method == "Active Contours":
+            s = np.linspace(0, 2 * np.pi, 400)
+            x = 100 + 50 * np.cos(s)
+            y = 100 + 50 * np.sin(s)
+            init = np.array([x, y]).T
+            self.segmented_image = active_contour(
+                self.image, init, max_iterations=param_value
+            )
+        elif method == "Region-Based":
+            denoised = ndi.gaussian_filter(self.image, sigma=param_value)
+            self.segmented_image = denoised > threshold_otsu(denoised)
+        elif method == "Watershed":
+            distance = ndi.distance_transform_edt(self.image)
+            local_maxi = peak_local_max(
+                distance, indices=False, footprint=np.ones((3, 3)), labels=self.image
+            )
+            markers, _ = ndi.label(local_maxi)
+            self.segmented_image = watershed(-distance, markers, mask=self.image)
+
+        self.display_image(self.segmented_image.astype(np.uint8) * 255)
+
+if __name__ == "__main__":
+    app = QApplication(sys.argv)
+    window = ImageSegmentationApp()
+    window.show()
+    sys.exit(app.exec_())

+ 178 - 0
src/labeler.py

@@ -0,0 +1,178 @@
+import sys
+import json
+from PyQt5.QtWidgets import (
+    QApplication, QMainWindow, QLabel, QPushButton, QColorDialog, QInputDialog,
+    QFileDialog, QMessageBox, QVBoxLayout, QHBoxLayout, QWidget
+)
+from PyQt5.QtGui import QImage, QPixmap, QPainter, QPen, QColor
+from PyQt5.QtCore import Qt, QPoint
+
+class ImageLabel(QLabel):
+    def __init__(self, parent=None):
+        super().__init__(parent)
+        self.setMouseTracking(True)
+        self.image = QImage()
+        self.rois = []
+        self.current_roi = []
+        self.drawing = False
+        self.current_color = QColor(Qt.red)
+        self.current_label = "ROI"
+        self.undo_stack = []
+
+    def load_image(self, file_path):
+        self.image = QImage(file_path)
+        if self.image.isNull():
+            raise ValueError("Failed to load image.")
+        self.setPixmap(QPixmap.fromImage(self.image))
+        self.rois.clear()
+        self.current_roi.clear()
+        self.update()
+
+    def set_current_color(self, color):
+        self.current_color = color
+
+    def set_current_label(self, label):
+        self.current_label = label
+
+    def mousePressEvent(self, event):
+        if event.button() == Qt.LeftButton:
+            self.drawing = True
+            self.current_roi = [event.pos()]
+            self.update()
+
+    def mouseMoveEvent(self, event):
+        if self.drawing:
+            self.current_roi.append(event.pos())
+            self.update()
+
+    def mouseReleaseEvent(self, event):
+        if event.button() == Qt.LeftButton and self.drawing:
+            self.drawing = False
+            if len(self.current_roi) > 2:
+                self.rois.append({
+                    'label': self.current_label,
+                    'color': self.current_color,
+                    'points': self.current_roi
+                })
+                self.undo_stack.append(self.rois[-1])
+            self.current_roi = []
+            self.update()
+
+    def paintEvent(self, event):
+        super().paintEvent(event)
+        if not self.image.isNull():
+            painter = QPainter(self)
+            painter.drawImage(self.rect(), self.image, self.image.rect())
+            pen = QPen(Qt.red, 2, Qt.SolidLine)
+            painter.setPen(pen)
+            for roi in self.rois:
+                pen.setColor(roi['color'])
+                painter.setPen(pen)
+                points = [QPoint(p.x(), p.y()) for p in roi['points']]
+                painter.drawPolygon(*points)
+            if self.current_roi:
+                pen.setColor(Qt.blue)
+                painter.setPen(pen)
+                points = [QPoint(p.x(), p.y()) for p in self.current_roi]
+                painter.drawPolyline(*points)
+            painter.end()
+
+    def undo(self):
+        if self.undo_stack:
+            last_action = self.undo_stack.pop()
+            if last_action in self.rois:
+                self.rois.remove(last_action)
+            self.update()
+
+    def save_rois(self, file_path):
+        roi_data = []
+        for roi in self.rois:
+            roi_data.append({
+                'label': roi['label'],
+                'color': roi['color'].name(),
+                'points': [(point.x(), point.y()) for point in roi['points']]
+            })
+        with open(file_path, 'w') as file:
+            json.dump(roi_data, file, indent=4)
+
+class ROIDrawer(QMainWindow):
+    def __init__(self):
+        super().__init__()
+        self.initUI()
+
+    def initUI(self):
+        self.setWindowTitle('ROI Drawer')
+        self.setGeometry(100, 100, 1000, 600)
+
+        # Central widget
+        central_widget = QWidget(self)
+        self.setCentralWidget(central_widget)
+
+        # Layouts
+        main_layout = QHBoxLayout(central_widget)
+        image_layout = QVBoxLayout()
+        button_layout = QVBoxLayout()
+
+        # Image display
+        self.image_label = ImageLabel(self)
+        image_layout.addWidget(self.image_label)
+
+        # Buttons
+        load_button = QPushButton('Load Image', self)
+        load_button.clicked.connect(self.load_image)
+        button_layout.addWidget(load_button)
+
+        color_button = QPushButton('Select Color', self)
+        color_button.clicked.connect(self.select_color)
+        button_layout.addWidget(color_button)
+
+        label_button = QPushButton('Set Label', self)
+        label_button.clicked.connect(self.set_label)
+        button_layout.addWidget(label_button)
+
+        undo_button = QPushButton('Undo', self)
+        undo_button.clicked.connect(self.image_label.undo)
+        button_layout.addWidget(undo_button)
+
+        save_button = QPushButton('Save ROIs', self)
+        save_button.clicked.connect(self.save_rois)
+        button_layout.addWidget(save_button)
+
+        button_layout.addStretch(1)  # Push buttons to the top
+
+        # Add layouts to main layout
+        main_layout.addLayout(image_layout, 4)
+        main_layout.addLayout(button_layout, 1)
+
+    def load_image(self):
+        options = QFileDialog.Options()
+        file_path, _ = QFileDialog.getOpenFileName(
+            self, "Open Image File", "", "Images (*.png *.jpg *.bmp *.tiff)", options=options)
+        if file_path:
+            try:
+                self.image_label.load_image(file_path)
+            except ValueError as e:
+                QMessageBox.critical(self, "Error", str(e))
+
+    def select_color(self):
+        color = QColorDialog.getColor()
+        if color.isValid():
+            self.image_label.set_current_color(color)
+
+    def set_label(self):
+        label, ok = QInputDialog.getText(self, 'Set ROI Label', 'Enter label for ROI:')
+        if ok and label:
+            self.image_label.set_current_label(label)
+
+    def save_rois(self):
+        options = QFileDialog.Options()
+        file_path, _ = QFileDialog.getSaveFileName(
+            self, "Save ROIs", "", "JSON Files (*.json);;All Files (*)", options=options)
+        if file_path:
+            self.image_label.save_rois(file_path)
+
+if __name__ == '__main__':
+    app = QApplication(sys.argv)
+    drawer = ROIDrawer()
+    drawer.show()
+    sys.exit(app.exec_())

+ 67 - 0
src/read_data.py

@@ -0,0 +1,67 @@
+import os
+import numpy as np
+import matplotlib.pyplot as plt
+import pydicom
+import cv2
+from skimage.filters import gaussian
+from skimage.segmentation import active_contour
+from skimage.measure import find_contours
+
+
+def get_mask_contour(mask):
+    mskray = mask.astype(np.uint8)
+    edged_mask = cv2.Canny(mskray, np.min(mskray), np.max(mskray))
+    return edged_mask
+
+
+folder_pth = r'C:\Users\user\Desktop\knee_seg\LITVYAK_D.I\LITVYAK_D.I\2025-01-22 181006'
+# folder_pth = r'E:\projects\knee_seg\Юшкевич_Оба колена\Yushkevich_L\YUSHKEVICH_A.V. 172\2025-01-31 090519'
+# folder_pth = r'E:\projects\knee_seg\LITVYAK_D.I\2025-01-22 181006'
+# folder_pth = r'E:\projects\knee_seg\data_tamplet\PA0\ST0\SE4'
+mx = []
+list_files = os.listdir(folder_pth)
+
+pd_fils = []
+for i, fname in enumerate(list_files):
+    file_pth = os.path.join(folder_pth, list_files[i])
+    file_dcm = pydicom.dcmread(file_pth)
+    print(file_dcm['0008', '103e'][0:9])
+    if 'Sag PD' in file_dcm['0008', '103e'][0:9]:
+        pd_fils.append(file_pth)
+
+
+for f in range (2, len(pd_fils) - 2):
+    file_dcm = pydicom.dcmread(pd_fils[f])
+    im = file_dcm.pixel_array
+    counts, bins = np.histogram(im, 100)
+    threshold = 0.50 * (np.max(im) - 0.05 * np.max(im))
+    segmented_im = np.copy(im)
+    segmented_im[segmented_im < threshold] = 0
+    contours = find_contours(segmented_im, level=0.5)
+
+    plt.subplot(1, 3, 1)
+    plt.imshow(im, cmap='gray')
+    plt.title ('Исходное изображение')
+    plt.subplot(1, 3, 2)
+    plt.imshow(im, 'gray')
+    plt.title ('Фильтрация по яркости')
+    contours_square = []
+    for c in range(len(contours)):
+        contours_square.append((contours[c].shape[0]*contours[c].shape[1]))
+    contours_thr = 0.3 * np.max (contours_square)
+    # plot all
+    for cn in range(len(contours)):
+
+        plt.plot(contours[cn][:, 1], contours[cn][:, 0], 'r')
+
+    plt.subplot(1, 3, 3)
+    plt.imshow(im, 'gray')
+    plt.title ('Фильтрация по размеру объекта')
+
+    for cn in range(len(contours)):
+        if contours[cn].shape[0]*contours[cn].shape[1]> contours_thr:
+            # contours[cn] = 0
+            plt.plot(contours[cn][:, 1], contours[cn][:, 0], 'r')
+
+    plt.show()
+

+ 353 - 0
src/shape_draw.py

@@ -0,0 +1,353 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2018, Felix Schill.
+# Distributed under the (new) BSD License. See LICENSE.txt for more info.
+# vispy: gallery 2
+"""
+Draw and Edit Shapes with Mouse
+===============================
+
+Simple demonstration of drawing and editing shapes with the mouse
+This demo implements mouse picking on visuals and markers using the
+vispy scene and "visual_at" mechanism.
+Left mouse button on empty space creates new objects. Objects can be
+selected by clicking, and moved by dragging. Dragging control points
+changes the size of the object.
+Vispy takes care of coordinate transforms from screen to ViewBox - the
+demo works on different zoom levels.
+Lastly, additional objects are added to the view in a fixed position as
+"buttons" to select which type of object is being created. Selecting
+the arrow symbol will switch into select/pan mode where the left drag
+moves the workplane or moves objects/controlpoints.
+"""
+
+import numpy as np
+
+from vispy import app, scene
+from vispy.color import Color
+
+
+class ControlPoints(scene.visuals.Compound):
+    def __init__(self, parent):
+        scene.visuals.Compound.__init__(self, [])
+        self.unfreeze()
+        self.parent = parent
+        self._center = [0, 0]
+        self._width = 0.0
+        self._height = 0.0
+        self.selected_cp = None
+        self.opposed_cp = None
+
+        self.control_points = [scene.visuals.Markers(parent=self)
+                               for i in range(0, 4)]
+        for c in self.control_points:
+            c.set_data(pos=np.array([[0, 0]],
+                                    dtype=np.float32),
+                       symbol="s",
+                       edge_color="red",
+                       size=6)
+            c.interactive = True
+        self.freeze()
+
+    def update_bounds(self):
+        self._center = [0.5 * (self.parent.bounds(0)[1] +
+                               self.parent.bounds(0)[0]),
+                        0.5 * (self.parent.bounds(1)[1] +
+                               self.parent.bounds(1)[0])]
+        self._width = self.parent.bounds(0)[1] - self.parent.bounds(0)[0]
+        self._height = self.parent.bounds(1)[1] - self.parent.bounds(1)[0]
+        self.update_points()
+
+    def update_points(self):
+        self.control_points[0].set_data(
+            pos=np.array([[self._center[0] - 0.5 * self._width,
+                           self._center[1] + 0.5 * self._height]]))
+        self.control_points[1].set_data(
+            pos=np.array([[self._center[0] + 0.5 * self._width,
+                           self._center[1] + 0.5 * self._height]]))
+        self.control_points[2].set_data(
+            pos=np.array([[self._center[0] + 0.5 * self._width,
+                           self._center[1] - 0.5 * self._height]]))
+        self.control_points[3].set_data(
+            pos=np.array([[self._center[0] - 0.5 * self._width,
+                           self._center[1] - 0.5 * self._height]]))
+
+    def select(self, val, obj=None):
+        self.visible(val)
+        self.selected_cp = None
+        self.opposed_cp = None
+
+        if obj is not None:
+            n_cp = len(self.control_points)
+            for i in range(0, n_cp):
+                c = self.control_points[i]
+                if c == obj:
+                    self.selected_cp = c
+                    self.opposed_cp = \
+                        self.control_points[int((i + n_cp / 2)) % n_cp]
+
+    def start_move(self, start):
+        None
+
+    def move(self, end):
+        if not self.parent.editable:
+            return
+        if self.selected_cp is not None:
+            self._width = 2 * (end[0] - self._center[0])
+            self._height = 2 * (end[1] - self._center[1])
+            self.update_points()
+            self.parent.update_from_controlpoints()
+
+    def visible(self, v):
+        for c in self.control_points:
+            c.visible = v
+
+    def get_center(self):
+        return self._center
+
+    def set_center(self, val):
+        self._center = val
+        self.update_points()
+
+
+class EditVisual(scene.visuals.Compound):
+    def __init__(self, editable=True, selectable=True, on_select_callback=None,
+                 callback_argument=None, *args, **kwargs):
+        scene.visuals.Compound.__init__(self, [], *args, **kwargs)
+        self.unfreeze()
+        self.editable = editable
+        self._selectable = selectable
+        self._on_select_callback = on_select_callback
+        self._callback_argument = callback_argument
+        self.control_points = ControlPoints(parent=self)
+        self.drag_reference = [0, 0]
+        self.freeze()
+
+    def add_subvisual(self, visual):
+        scene.visuals.Compound.add_subvisual(self, visual)
+        visual.interactive = True
+        self.control_points.update_bounds()
+        self.control_points.visible(False)
+
+    def select(self, val, obj=None):
+        if self.selectable:
+            self.control_points.visible(val)
+            if self._on_select_callback is not None:
+                self._on_select_callback(self._callback_argument)
+
+    def start_move(self, start):
+        self.drag_reference = start[0:2] - self.control_points.get_center()
+
+    def move(self, end):
+        if self.editable:
+            shift = end[0:2] - self.drag_reference
+            self.set_center(shift)
+
+    def update_from_controlpoints(self):
+        None
+
+    @property
+    def selectable(self):
+        return self._selectable
+
+    @selectable.setter
+    def selectable(self, val):
+        self._selectable = val
+
+    @property
+    def center(self):
+        return self.control_points.get_center()
+
+    @center.setter
+    # this method redirects to set_center. Override set_center in subclasses.
+    def center(self, val):
+        self.set_center(val)
+
+    # override this method in subclass
+    def set_center(self, val):
+        self.control_points.set_center(val[0:2])
+
+    def select_creation_controlpoint(self):
+        self.control_points.select(True, self.control_points.control_points[2])
+
+
+class EditRectVisual(EditVisual):
+    def __init__(self, center=[0, 0], width=20, height=20, *args, **kwargs):
+        EditVisual.__init__(self, *args, **kwargs)
+        self.unfreeze()
+        self.rect = scene.visuals.Rectangle(center=center, width=width,
+                                            height=height,
+                                            color=Color("#e88834"),
+                                            border_color="white",
+                                            radius=0, parent=self)
+        self.rect.interactive = True
+
+        self.freeze()
+        self.add_subvisual(self.rect)
+        self.control_points.update_bounds()
+        self.control_points.visible(False)
+
+    def set_center(self, val):
+        self.control_points.set_center(val[0:2])
+        self.rect.center = val[0:2]
+
+    def update_from_controlpoints(self):
+        try:
+            self.rect.width = abs(self.control_points._width)
+        except ValueError:
+            None
+        try:
+            self.rect.height = abs(self.control_points._height)
+        except ValueError:
+            None
+
+
+class EditEllipseVisual(EditVisual):
+    def __init__(self, center=[0, 0], radius=[2, 2], *args, **kwargs):
+        EditVisual.__init__(self, *args, **kwargs)
+        self.unfreeze()
+        self.ellipse = scene.visuals.Ellipse(center=center, radius=radius,
+                                             color=Color("#e88834"),
+                                             border_color="white",
+                                             parent=self)
+        self.ellipse.interactive = True
+
+        self.freeze()
+        self.add_subvisual(self.ellipse)
+        self.control_points.update_bounds()
+        self.control_points.visible(False)
+
+    def set_center(self, val):
+        self.control_points.set_center(val)
+        self.ellipse.center = val
+
+    def update_from_controlpoints(self):
+        try:
+            self.ellipse.radius = [0.5 * abs(self.control_points._width),
+                                   0.5 * abs(self.control_points._height)]
+        except ValueError:
+            None
+
+
+class Canvas(scene.SceneCanvas):
+    """ A simple test canvas for drawing demo """
+
+    def __init__(self):
+        scene.SceneCanvas.__init__(self, keys='interactive',
+                                   size=(800, 800))
+
+        self.unfreeze()
+
+        self.view = self.central_widget.add_view()
+        self.view.camera = scene.PanZoomCamera(rect=(-100, -100, 200, 200),
+                                               aspect=1.0)
+        # the left mouse button pan has to be disabled in the camera, as it
+        # interferes with dragging line points
+        # Proposed change in camera: make mouse buttons configurable
+        self.view.camera._viewbox.events.mouse_move.disconnect(
+            self.view.camera.viewbox_mouse_event)
+
+        scene.visuals.Text("Click and drag to add objects, " +
+                           "right-click to delete.",
+                           color='w',
+                           anchor_x='left',
+                           parent=self.view,
+                           pos=(20, 30))
+
+        self.select_arrow = \
+            EditVisual(parent=self.view, editable=False,
+                       on_select_callback=self.set_creation_mode,
+                       callback_argument=None)
+        arrow = scene.visuals.Arrow(parent=self.select_arrow,
+                                    pos=np.array([[50, 60], [60, 70]]),
+                                    arrows=np.array([[60, 70, 50, 60]]),
+                                    width=5, arrow_size=15.0,
+                                    arrow_type="angle_60",
+                                    color="w",
+                                    arrow_color="w",
+                                    method="agg"
+                                    )
+        self.select_arrow.add_subvisual(arrow)
+
+        self.rect_button = \
+            EditRectVisual(parent=self.view, editable=False,
+                           on_select_callback=self.set_creation_mode,
+                           callback_argument=EditRectVisual,
+                           center=[50, 120], width=30, height=30)
+        self.ellipse_button = \
+            EditEllipseVisual(parent=self.view,
+                              editable=False,
+                              on_select_callback=self.set_creation_mode,
+                              callback_argument=EditEllipseVisual,
+                              center=[50, 170],
+                              radius=[15, 10])
+
+        self.objects = []
+        self.show()
+        self.selected_point = None
+        self.selected_object = None
+        self.creation_mode = EditRectVisual
+        self.mouse_start_pos = [0, 0]
+        scene.visuals.GridLines(parent=self.view.scene)
+        self.freeze()
+
+    def set_creation_mode(self, object_kind):
+        self.creation_mode = object_kind
+
+    def on_mouse_press(self, event):
+
+        tr = self.scene.node_transform(self.view.scene)
+        pos = tr.map(event.pos)
+        self.view.interactive = False
+        selected = self.visual_at(event.pos)
+        self.view.interactive = True
+        if self.selected_object is not None:
+            self.selected_object.select(False)
+            self.selected_object = None
+
+        if event.button == 1:
+            if selected is not None:
+                self.selected_object = selected.parent
+                # update transform to selected object
+                tr = self.scene.node_transform(self.selected_object)
+                pos = tr.map(event.pos)
+
+                self.selected_object.select(True, obj=selected)
+                self.selected_object.start_move(pos)
+                self.mouse_start_pos = event.pos
+
+            # create new object:
+            if self.selected_object is None and self.creation_mode is not None:
+                # new_object = EditRectVisual(parent=self.view.scene)
+                new_object = self.creation_mode(parent=self.view.scene)
+                self.objects.append(new_object)
+                new_object.select_creation_controlpoint()
+                new_object.set_center(pos[0:2])
+                self.selected_object = new_object.control_points
+
+        if event.button == 2:  # right button deletes object
+            if selected is not None and selected.parent in self.objects:
+                self.objects.remove(selected.parent)
+                selected.parent.parent = None
+                self.selected_object = None
+
+    def on_mouse_move(self, event):
+
+        if event.button == 1:
+            if self.selected_object is not None:
+                self.view.camera._viewbox.events.mouse_move.disconnect(
+                    self.view.camera.viewbox_mouse_event)
+                # update transform to selected object
+                tr = self.scene.node_transform(self.selected_object)
+                pos = tr.map(event.pos)
+
+                self.selected_object.move(pos[0:2])
+            else:
+                self.view.camera._viewbox.events.mouse_move.connect(
+                    self.view.camera.viewbox_mouse_event)
+        else:
+            None
+
+
+if __name__ == '__main__':
+    canvas = Canvas()
+    app.run()