dataset.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. from PIL import Image
  2. import torch.utils.data as data
  3. from torchvision import transforms
  4. import random
  5. import os
  6. import cv2 as cv
  7. import torch.nn.functional as F
  8. import torch
  9. import numpy as np
  10. def get_imgs_from_dir(path_a, path_b):
  11. list_a = []
  12. for root, dir, files in os.walk(path_a):
  13. for file in files:
  14. list_a.append(f'{path_a}/{file}')
  15. list_b = []
  16. list_b.append(path_b)
  17. list_b = list_b * len(list_a)
  18. return list_a, list_b
  19. class PadToSize:
  20. def __init__(self, size, fill=0):
  21. """
  22. size: (target_height, target_width) — желаемый размер после паддинга
  23. fill: значение для паддинга (по умолчанию 0)
  24. """
  25. if isinstance(size, int):
  26. self.target_h = self.target_w = size
  27. else:
  28. self.target_h, self.target_w = size
  29. self.fill = fill
  30. def __call__(self, img: torch.Tensor) -> torch.Tensor:
  31. """
  32. img: тензор формы [C, H, W]
  33. Возвращает: тензор [C, target_h, target_w], центрирован с нулевым паддингом вокруг, если исходный меньше.
  34. Если исходный больше по какой-то размерности — по этой размерности не обрезает, оставляет как есть.
  35. """
  36. if img.ndim != 3:
  37. raise ValueError(f"Ожидается тензор формы [C,H,W], получили {img.shape}")
  38. c, h, w = img.shape
  39. pad_h = max(self.target_h - h, 0)
  40. pad_w = max(self.target_w - w, 0)
  41. pad_top = pad_h // 2
  42. pad_bottom = pad_h - pad_top
  43. pad_left = pad_w // 2
  44. pad_right = pad_w - pad_left
  45. # порядок для F.pad: (pad_left, pad_right, pad_top, pad_bottom)
  46. padded = F.pad(img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=self.fill)
  47. return padded
  48. def normalize_per_image(img: np.ndarray, eps=1e-8):
  49. """
  50. img: numpy array, shape (H, W, C), dtype uint8 or float (если uint8 — будет приведён)
  51. Возвращает: img_norm — float32 в диапазоне [0,1], и массивы min_vals, max_vals (для каждого канала).
  52. Требуется, чтобы изображение уже было трёхмерным (H, W, C) — как и было в исходной логике.
  53. """
  54. # Приведём к float32
  55. img_c = img.astype(np.float32, copy=True)
  56. if img_c.ndim != 3:
  57. raise ValueError(f"Ожидается изображение 3D (H,W,C), получено shape={img.shape}")
  58. norm_img = np.zeros_like(img_c)
  59. H, W, C = img_c.shape
  60. min_vals = np.zeros((C,), dtype=np.float32)
  61. max_vals = np.zeros((C,), dtype=np.float32)
  62. # Для каждого канала отдельно
  63. for c in range(C):
  64. channel = img_c[:, :, c].copy()
  65. cmin = channel.min()
  66. cmax = channel.max()
  67. min_vals[c] = cmin
  68. max_vals[c] = cmax
  69. # Защита от случая, где всё одинаковое
  70. denom = (cmax - cmin + eps)
  71. norm_img[:, :, c] = (channel - cmin) / denom
  72. # теперь img в [0,1] по каждому каналу
  73. return norm_img, min_vals, max_vals
  74. class BaseDataset(data.Dataset):
  75. def __init__(self, cfg):
  76. rootA = cfg.rootA
  77. rootB = cfg.rootB
  78. self.return_name = cfg.return_name
  79. transform_list = []
  80. transform_list.append(transforms.ToTensor())
  81. transform_list.append(transforms.Resize((250, 250)))
  82. transform_list.append(PadToSize((256, 512)))
  83. transform = transforms.Compose(transform_list)
  84. imgsA = []
  85. imgsB = []
  86. imgs_A_1, imgs_B_1 = get_imgs_from_dir(rootA, rootB)
  87. # Фильтруем входы A: оставляем только файлы с формой (H, W, C)
  88. filtered_A = []
  89. for p in imgs_A_1:
  90. try:
  91. arr = np.load(p, allow_pickle=False)
  92. if getattr(arr, 'ndim', None) == 3:
  93. filtered_A.append(p)
  94. except Exception:
  95. # пропускаем некорректные файлы
  96. continue
  97. imgs_A_1 = filtered_A
  98. imgs_B_1 = [rootB] * len(imgs_A_1)
  99. imgsA += imgs_A_1
  100. imgsB += imgs_B_1
  101. random.shuffle(imgs_A_1)
  102. random.shuffle(imgs_B_1)
  103. total_len = len(imgsA)
  104. self.rootA = rootA
  105. self.rootB = rootB
  106. self.imgsA = imgsA # [:int(keep_percent*total_len)+1]
  107. self.imgsB = imgsB # [:int(keep_percent*total_len)+1]
  108. self.transform = transform
  109. def __getitem__(self, index):
  110. # index = 0
  111. pathA = self.imgsA[index]
  112. pathB = self.imgsB[index]
  113. imgA = np.load(pathA)
  114. imgA_not_normalized = imgA.copy()
  115. imgB = np.load(pathB)
  116. pd = imgB[:, :, 0]
  117. t1 = imgB[:, :, 1] * 1e-3
  118. t2 = imgB[:, :, 2] * 1e-3
  119. maps_list = [t1, t2, pd]
  120. maps_arr = np.array(maps_list)
  121. imgB = maps_arr.transpose(1, 2, 0)
  122. imgA, imgA_min_vals, imgA_max_vals = normalize_per_image(imgA, eps=1e-8)
  123. imgB, _, _ = normalize_per_image(imgB, eps=1e-8)
  124. name = pathA.split('/')[-1]
  125. imgA = self.transform(imgA)
  126. imgB = self.transform(imgB)
  127. imgA_not_normalized = self.transform(imgA_not_normalized)
  128. if not self.return_name:
  129. return imgA, imgB, imgA_not_normalized, imgA_max_vals
  130. else:
  131. return imgA, imgB, '{}_to_{}'.format(pathA.split('/')[-1].split('.')[0], pathB.split('/')[-1].split('.')[0]), imgA_not_normalized, imgA_max_vals
  132. def __len__(self):
  133. return len(self.imgsA)
  134. def get_dataset(cfg):
  135. dataset = BaseDataset(cfg)
  136. return dataset