from PIL import Image import torch.utils.data as data from torchvision import transforms import random import os import cv2 as cv import torch.nn.functional as F import torch import numpy as np def get_imgs_from_dir(path_a, path_b): list_a = [] for root, dir, files in os.walk(path_a): for file in files: list_a.append(f'{path_a}/{file}') list_b = [] list_b.append(path_b) list_b = list_b * len(list_a) return list_a, list_b class PadToSize: def __init__(self, size, fill=0): """ size: (target_height, target_width) — желаемый размер после паддинга fill: значение для паддинга (по умолчанию 0) """ if isinstance(size, int): self.target_h = self.target_w = size else: self.target_h, self.target_w = size self.fill = fill def __call__(self, img: torch.Tensor) -> torch.Tensor: """ img: тензор формы [C, H, W] Возвращает: тензор [C, target_h, target_w], центрирован с нулевым паддингом вокруг, если исходный меньше. Если исходный больше по какой-то размерности — по этой размерности не обрезает, оставляет как есть. """ if img.ndim != 3: raise ValueError(f"Ожидается тензор формы [C,H,W], получили {img.shape}") c, h, w = img.shape pad_h = max(self.target_h - h, 0) pad_w = max(self.target_w - w, 0) pad_top = pad_h // 2 pad_bottom = pad_h - pad_top pad_left = pad_w // 2 pad_right = pad_w - pad_left # порядок для F.pad: (pad_left, pad_right, pad_top, pad_bottom) padded = F.pad(img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=self.fill) return padded def normalize_per_image(img: np.ndarray, eps=1e-8): """ img: numpy array, shape (H, W, C), dtype uint8 or float (если uint8 — будет приведён) Возвращает: img_norm — float32 в диапазоне [0,1], и массивы min_vals, max_vals (для каждого канала). Требуется, чтобы изображение уже было трёхмерным (H, W, C) — как и было в исходной логике. """ # Приведём к float32 img_c = img.astype(np.float32, copy=True) if img_c.ndim != 3: raise ValueError(f"Ожидается изображение 3D (H,W,C), получено shape={img.shape}") norm_img = np.zeros_like(img_c) H, W, C = img_c.shape min_vals = np.zeros((C,), dtype=np.float32) max_vals = np.zeros((C,), dtype=np.float32) # Для каждого канала отдельно for c in range(C): channel = img_c[:, :, c].copy() cmin = channel.min() cmax = channel.max() min_vals[c] = cmin max_vals[c] = cmax # Защита от случая, где всё одинаковое denom = (cmax - cmin + eps) norm_img[:, :, c] = (channel - cmin) / denom # теперь img в [0,1] по каждому каналу return norm_img, min_vals, max_vals class BaseDataset(data.Dataset): def __init__(self, cfg): rootA = cfg.rootA rootB = cfg.rootB self.return_name = cfg.return_name transform_list = [] transform_list.append(transforms.ToTensor()) transform_list.append(transforms.Resize((250, 250))) transform_list.append(PadToSize((256, 512))) transform = transforms.Compose(transform_list) imgsA = [] imgsB = [] imgs_A_1, imgs_B_1 = get_imgs_from_dir(rootA, rootB) # Фильтруем входы A: оставляем только файлы с формой (H, W, C) filtered_A = [] for p in imgs_A_1: try: arr = np.load(p, allow_pickle=False) if getattr(arr, 'ndim', None) == 3: filtered_A.append(p) except Exception: # пропускаем некорректные файлы continue imgs_A_1 = filtered_A imgs_B_1 = [rootB] * len(imgs_A_1) imgsA += imgs_A_1 imgsB += imgs_B_1 random.shuffle(imgs_A_1) random.shuffle(imgs_B_1) total_len = len(imgsA) self.rootA = rootA self.rootB = rootB self.imgsA = imgsA # [:int(keep_percent*total_len)+1] self.imgsB = imgsB # [:int(keep_percent*total_len)+1] self.transform = transform def __getitem__(self, index): # index = 0 pathA = self.imgsA[index] pathB = self.imgsB[index] imgA = np.load(pathA) imgA_not_normalized = imgA.copy() imgB = np.load(pathB) pd = imgB[:, :, 0] t1 = imgB[:, :, 1] * 1e-3 t2 = imgB[:, :, 2] * 1e-3 maps_list = [t1, t2, pd] maps_arr = np.array(maps_list) imgB = maps_arr.transpose(1, 2, 0) imgA, imgA_min_vals, imgA_max_vals = normalize_per_image(imgA, eps=1e-8) imgB, _, _ = normalize_per_image(imgB, eps=1e-8) name = pathA.split('/')[-1] imgA = self.transform(imgA) imgB = self.transform(imgB) imgA_not_normalized = self.transform(imgA_not_normalized) if not self.return_name: return imgA, imgB, imgA_not_normalized, imgA_max_vals else: return imgA, imgB, '{}_to_{}'.format(pathA.split('/')[-1].split('.')[0], pathB.split('/')[-1].split('.')[0]), imgA_not_normalized, imgA_max_vals def __len__(self): return len(self.imgsA) def get_dataset(cfg): dataset = BaseDataset(cfg) return dataset