utils.py

import numpy as np
import os
import torch
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from flow_model.model.losses import VGGLoss
from flow_model.model.network.hf import HierarchyFlow
from flow_model.model.utils.dataset import get_dataset
from flow_model.model.utils.sampler import DistributedGivenIterationSampler, DistributedTestSampler
from tensorboardX import SummaryWriter
import logging
from flow_model.model.utils.log_helper import init_log
from flow_model.model.trainers.MR_signal_model import *
import yaml
from easydict import EasyDict
import matplotlib.pyplot as plt
import h5py
from cmap import Colormap


def transform(image):
    """Rescale a NumPy image to roughly [-1, 1] and return it as a (1, 1, H, W) tensor."""
    max_val = np.max(image)
    min_val = np.min(image)
    scale = 2.0 / (max_val - min_val + 1e-8)  # add a small epsilon to avoid division by zero
    offset = -1.0 - min_val * scale
    # Apply the scaling
    image_scaled = image * scale + offset
    return torch.from_numpy(image_scaled).unsqueeze(0).unsqueeze(0)
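

# Illustrative usage sketch (not part of the original module): `transform`
# takes a 2-D NumPy array and yields a (1, 1, H, W) tensor scaled to roughly
# [-1, 1]. The array shape below is a made-up example.
#
#     slice_2d = np.random.rand(256, 256)
#     slice_tensor = transform(slice_2d)
#     print(slice_tensor.shape)  # torch.Size([1, 1, 256, 256])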


def load_checkpoint(checkpoint_fpath, model, optimizer, location):
    """Restore model and optimizer state from a checkpoint; return them with the saved step."""
    checkpoint = torch.load(checkpoint_fpath, map_location=location)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, checkpoint['step']
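

# Illustrative resume sketch (assumptions: the checkpoint path, the optimizer
# choice, and the HierarchyFlow constructor arguments are placeholders; the
# real ones come from the project config, which is not shown here).
#
#     model = HierarchyFlow(...)                        # built from the config
#     optimizer = torch.optim.Adam(model.parameters())
#     model, optimizer, step = load_checkpoint(
#         'checkpoints/latest.pth', model, optimizer, location='cpu')
#     print(f'resuming from step {step}')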


def denormalize_per_image(img_tensor, min_vals, max_vals):
    """Undo per-channel min-max normalization for a 3-channel image tensor."""
    img_denorm = torch.zeros_like(img_tensor)
    img_denorm[:, 0] = img_tensor[:, 0] * (max_vals[0] - min_vals[0]) + min_vals[0]
    img_denorm[:, 1] = img_tensor[:, 1] * (max_vals[1] - min_vals[1]) + min_vals[1]
    img_denorm[:, 2] = img_tensor[:, 2] * (max_vals[2] - min_vals[2]) + min_vals[2]
    return img_denorm
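

# Illustrative usage sketch: map a [0, 1]-normalized tensor back to its
# original per-channel ranges. The min/max lists are made-up values, not
# taken from the original code.
#
#     mins, maxs = [0.0, 0.0, 0.0], [4.0, 2.0, 1.5]
#     normed = torch.rand(1, 3, 64, 64)                # values in [0, 1]
#     restored = denormalize_per_image(normed, mins, maxs)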


def create_h5_file(list_of_phantoms, phantom_name, dir_for_save, offset, resolution, ph_size):
    """Write five phantom maps plus offset/resolution metadata to an HDF5 file."""
    offset = np.expand_dims(offset, 1)
    resolution = np.expand_dims(resolution, 1)
    rescaled_phantoms = []
    for i in range(len(list_of_phantoms)):
        if (i == 0) or (i == 4):
            # Channels 0 and 4 are stored as-is.
            rescaled_phantoms.append(list_of_phantoms[i])
        else:
            # r_ph = list_of_phantoms[i] * 1e3
            r_ph = list_of_phantoms[i]
            # Invert positive values elementwise; leave non-positive entries at zero.
            r_ph = np.where(r_ph > 0, 1 / r_ph, 0)
            rescaled_phantoms.append(r_ph)
    five_phant = np.array(rescaled_phantoms).transpose(1, 2, 0)
    with h5py.File(f'{dir_for_save}/{phantom_name}', 'w') as f:
        grp = f.create_group("sample")
        dset = grp.create_dataset('data', (ph_size[0], ph_size[1], 5), dtype="f8")
        dset[:, :, :] = five_phant[:, :, :]
        dset_1 = grp.create_dataset('offset', (3, 1), dtype="f8")
        dset_1[:, :] = offset[:, :]
        dset_2 = grp.create_dataset('resolution', (3, 1), dtype="f8")
        dset_2[:, :] = resolution[:, :]
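

# Illustrative usage sketch: write five 2-D maps into `sample/data` along with
# 3x1 offset/resolution datasets. Shapes and values are made-up placeholders.
#
#     maps = [np.zeros((128, 128)) for _ in range(5)]
#     create_h5_file(maps, 'phantom_000.h5', './phantoms',
#                    offset=np.zeros(3), resolution=np.ones(3),
#                    ph_size=(128, 128))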