# main.py
  1. import torch
  2. from torch.utils.data import DataLoader
  3. from flow_model.model.network.hf import HierarchyFlow
  4. from flow_model.model.utils.dataset import get_dataset
  5. from flow_model.utils import *
  6. from flow_model.model.trainers.MR_signal_model import *
  7. import yaml
  8. from easydict import EasyDict
  9. #import matplotlib.pyplot as plt
  10. import h5py
  11. from datetime import datetime
  12. #from cmap import Colormap
  13. def flow_model(upload_cfg, mode='base', device_pref=None):
  14. """Запускает генерацию фантома.
  15. Parameters:
  16. upload_cfg: конфигурация загрузки/датасета.
  17. mode: 'base' или 'upload'.
  18. device_pref: None|'cpu'|'cuda' — предпочтительное устройство.
  19. - None: автоматически выбрать CUDA, если доступна.
  20. - 'cuda': пытаться использовать CUDA, если недоступна — фоллбек на CPU.
  21. - 'cpu': принудительно CPU.
  22. """
  23. if device_pref == 'cpu':
  24. device = torch.device('cpu')
  25. elif device_pref == 'cuda':
  26. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  27. else:
  28. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  29. with open('./flow_model/configs/config.yaml') as f:
  30. cfg = yaml.load(f, Loader=yaml.FullLoader)
  31. cfg = EasyDict(cfg)
  32. model = HierarchyFlow(cfg.network.pad_size, cfg.network.in_channel, cfg.network.out_channels,
  33. cfg.network.weight_type)
  34. model = model.to(device)
  35. optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
  36. model, optimizer, resumed_step = load_checkpoint('flow_model/checkpoints/27900.ckpt.pth.tar',
  37. model, optimizer, device)
  38. if mode == 'base':
  39. test_dataset = get_dataset(cfg.dataset.test)
  40. else:
  41. test_dataset = get_dataset(upload_cfg.dataset.calc)
  42. test_loader = DataLoader(
  43. test_dataset,
  44. batch_size=1,
  45. shuffle=False)
  46. model.eval()
  47. with torch.no_grad():
  48. for batch_id, batch in enumerate(test_loader):
  49. content_images = batch[0].to(device)
  50. style_images = batch[1].to(device)
  51. names = batch[2]
  52. not_normalized_image = batch[3].cpu()
  53. max_vals = batch[4].cpu()
  54. outputs = model(content_images, style_images)
  55. outputs = torch.clamp(outputs, 0, 1)
  56. denormalized_outputs = denormalize_per_image(outputs, [0.0, 0.0, 0.0], [3.5, 0.45, 1.0])
  57. denormalized_outputs = denormalized_outputs.cpu().numpy()
  58. pd = denormalized_outputs[0, 2, :, :][3:253, 131:381]
  59. pd[pd < 0.19] = 0
  60. t1 = denormalized_outputs[0, 0, :, :][3:253, 131:381]
  61. t1[pd == 0] = 0
  62. t2 = denormalized_outputs[0, 1, :, :][3:253, 131:381]
  63. t2[pd == 0] = 0
  64. t1_for_ph = t1 * 1e3
  65. t2_for_ph = t2 * 1e3
  66. PHANTOM_SIZE = (250, 250)
  67. OFFSET = np.array([0.0, 0.0, 0.0])
  68. RESOLUTION = np.array([1.0, 1.0, 1.0])
  69. zer = np.zeros(t2.shape)
  70. t2_s_for_ph = t2_for_ph.copy()
  71. # Имя файла фантома по шаблону phantom_YYYYMMDD_HHMMSS_microsec.h5
  72. timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
  73. phantom_fname = f'phantom_{timestamp}.h5'
  74. create_h5_file([pd, t1_for_ph, t2_for_ph, t2_s_for_ph, zer],
  75. phantom_fname,
  76. 'flow_model/phantoms_h5',
  77. OFFSET,
  78. RESOLUTION,
  79. PHANTOM_SIZE)
  80. # В режиме загрузки (upload) генерируем только один фантом,
  81. # чтобы избежать создания нескольких файлов при проверке пайплайна.
  82. if mode == 'upload':
  83. break
  84. if __name__ == "__main__":
  85. downloaded_dataset = []
  86. flow_model(downloaded_dataset, mode='base')