| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179 |
- from PIL import Image
- import torch.utils.data as data
- from torchvision import transforms
- import random
- import os
- import cv2 as cv
- import torch.nn.functional as F
- import torch
- import numpy as np
- def get_imgs_from_dir(path_a, path_b):
- list_a = []
- for root, dir, files in os.walk(path_a):
- for file in files:
- list_a.append(f'{path_a}/{file}')
- list_b = []
- list_b.append(path_b)
- list_b = list_b * len(list_a)
- return list_a, list_b
- class PadToSize:
- def __init__(self, size, fill=0):
- """
- size: (target_height, target_width) — желаемый размер после паддинга
- fill: значение для паддинга (по умолчанию 0)
- """
- if isinstance(size, int):
- self.target_h = self.target_w = size
- else:
- self.target_h, self.target_w = size
- self.fill = fill
- def __call__(self, img: torch.Tensor) -> torch.Tensor:
- """
- img: тензор формы [C, H, W]
- Возвращает: тензор [C, target_h, target_w], центрирован с нулевым паддингом вокруг, если исходный меньше.
- Если исходный больше по какой-то размерности — по этой размерности не обрезает, оставляет как есть.
- """
- if img.ndim != 3:
- raise ValueError(f"Ожидается тензор формы [C,H,W], получили {img.shape}")
- c, h, w = img.shape
- pad_h = max(self.target_h - h, 0)
- pad_w = max(self.target_w - w, 0)
- pad_top = pad_h // 2
- pad_bottom = pad_h - pad_top
- pad_left = pad_w // 2
- pad_right = pad_w - pad_left
- # порядок для F.pad: (pad_left, pad_right, pad_top, pad_bottom)
- padded = F.pad(img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=self.fill)
- return padded
- def normalize_per_image(img: np.ndarray, eps=1e-8):
- """
- img: numpy array, shape (H, W, C), dtype uint8 or float (если uint8 — будет приведён)
- Возвращает: img_norm — float32 в диапазоне [0,1], и массивы min_vals, max_vals (для каждого канала).
- Требуется, чтобы изображение уже было трёхмерным (H, W, C) — как и было в исходной логике.
- """
- # Приведём к float32
- img_c = img.astype(np.float32, copy=True)
- if img_c.ndim != 3:
- raise ValueError(f"Ожидается изображение 3D (H,W,C), получено shape={img.shape}")
- norm_img = np.zeros_like(img_c)
- H, W, C = img_c.shape
- min_vals = np.zeros((C,), dtype=np.float32)
- max_vals = np.zeros((C,), dtype=np.float32)
- # Для каждого канала отдельно
- for c in range(C):
- channel = img_c[:, :, c].copy()
- cmin = channel.min()
- cmax = channel.max()
- min_vals[c] = cmin
- max_vals[c] = cmax
- # Защита от случая, где всё одинаковое
- denom = (cmax - cmin + eps)
- norm_img[:, :, c] = (channel - cmin) / denom
- # теперь img в [0,1] по каждому каналу
- return norm_img, min_vals, max_vals
- class BaseDataset(data.Dataset):
- def __init__(self, cfg):
- rootA = cfg.rootA
- rootB = cfg.rootB
- self.return_name = cfg.return_name
- transform_list = []
- transform_list.append(transforms.ToTensor())
- transform_list.append(transforms.Resize((250, 250)))
- transform_list.append(PadToSize((256, 512)))
- transform = transforms.Compose(transform_list)
- imgsA = []
- imgsB = []
- imgs_A_1, imgs_B_1 = get_imgs_from_dir(rootA, rootB)
- # Фильтруем входы A: оставляем только файлы с формой (H, W, C)
- filtered_A = []
- for p in imgs_A_1:
- try:
- arr = np.load(p, allow_pickle=False)
- if getattr(arr, 'ndim', None) == 3:
- filtered_A.append(p)
- except Exception:
- # пропускаем некорректные файлы
- continue
- imgs_A_1 = filtered_A
- imgs_B_1 = [rootB] * len(imgs_A_1)
- imgsA += imgs_A_1
- imgsB += imgs_B_1
- random.shuffle(imgs_A_1)
- random.shuffle(imgs_B_1)
- total_len = len(imgsA)
- self.rootA = rootA
- self.rootB = rootB
- self.imgsA = imgsA # [:int(keep_percent*total_len)+1]
- self.imgsB = imgsB # [:int(keep_percent*total_len)+1]
- self.transform = transform
- def __getitem__(self, index):
- # index = 0
- pathA = self.imgsA[index]
- pathB = self.imgsB[index]
- imgA = np.load(pathA)
- imgA_not_normalized = imgA.copy()
- imgB = np.load(pathB)
- pd = imgB[:, :, 0]
- t1 = imgB[:, :, 1] * 1e-3
- t2 = imgB[:, :, 2] * 1e-3
- maps_list = [t1, t2, pd]
- maps_arr = np.array(maps_list)
- imgB = maps_arr.transpose(1, 2, 0)
- imgA, imgA_min_vals, imgA_max_vals = normalize_per_image(imgA, eps=1e-8)
- imgB, _, _ = normalize_per_image(imgB, eps=1e-8)
- name = pathA.split('/')[-1]
- imgA = self.transform(imgA)
- imgB = self.transform(imgB)
- imgA_not_normalized = self.transform(imgA_not_normalized)
- if not self.return_name:
- return imgA, imgB, imgA_not_normalized, imgA_max_vals
- else:
- return imgA, imgB, '{}_to_{}'.format(pathA.split('/')[-1].split('.')[0], pathB.split('/')[-1].split('.')[0]), imgA_not_normalized, imgA_max_vals
- def __len__(self):
- return len(self.imgsA)
- def get_dataset(cfg):
- dataset = BaseDataset(cfg)
- return dataset
|