| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113 |
- import torch
- from torch.utils.data import DataLoader
- from flow_model.model.network.hf import HierarchyFlow
- from flow_model.model.utils.dataset import get_dataset
- from flow_model.utils import *
- from flow_model.model.trainers.MR_signal_model import *
- import yaml
- from easydict import EasyDict
- #import matplotlib.pyplot as plt
- import h5py
- from datetime import datetime
- #from cmap import Colormap
def flow_model(upload_cfg, mode='base', device_pref=None):
    """Run phantom generation with the pretrained HierarchyFlow model.

    Loads the model configuration and checkpoint, runs inference over the
    selected dataset and writes one HDF5 phantom file per batch into
    ``flow_model/phantoms_h5``.

    Parameters:
        upload_cfg: upload/dataset configuration (only used when mode='upload').
        mode: 'base' (built-in test dataset) or 'upload' (user-supplied data).
        device_pref: None | 'cpu' | 'cuda' — preferred device.
            - None: automatically use CUDA when available.
            - 'cuda': try CUDA; fall back to CPU when unavailable.
            - 'cpu': force CPU.
    """
    # 'cpu' forces CPU. Both None and 'cuda' resolve to "CUDA if available,
    # else CPU", so they share a single branch (the original had two
    # byte-identical branches).
    if device_pref == 'cpu':
        device = torch.device('cpu')
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE(review): FullLoader executes no arbitrary tags but is still wider
    # than safe_load; kept as-is because the config schema is not visible here.
    with open('./flow_model/configs/config.yaml') as f:
        cfg = EasyDict(yaml.load(f, Loader=yaml.FullLoader))

    model = HierarchyFlow(cfg.network.pad_size, cfg.network.in_channel,
                          cfg.network.out_channels, cfg.network.weight_type)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    model, optimizer, resumed_step = load_checkpoint(
        'flow_model/checkpoints/27900.ckpt.pth.tar', model, optimizer, device)

    if mode == 'base':
        test_dataset = get_dataset(cfg.dataset.test)
    else:
        test_dataset = get_dataset(upload_cfg.dataset.calc)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    model.eval()
    with torch.no_grad():
        for batch_id, batch in enumerate(test_loader):
            content_images = batch[0].to(device)
            style_images = batch[1].to(device)
            # batch[2] (names), batch[3] (non-normalized image) and batch[4]
            # (per-image max values) are not needed for phantom export, so
            # they are intentionally ignored here.
            outputs = model(content_images, style_images)
            outputs = torch.clamp(outputs, 0, 1)
            # Undo per-channel normalization: offsets [0, 0, 0], scales
            # [3.5, 0.45, 1.0] — presumably (T1, T2, PD); verify against
            # the training pipeline.
            denormalized_outputs = denormalize_per_image(
                outputs, [0.0, 0.0, 0.0], [3.5, 0.45, 1.0])
            denormalized_outputs = denormalized_outputs.cpu().numpy()
            # Crop a 250x250 region of interest from each channel.
            pd = denormalized_outputs[0, 2, :, :][3:253, 131:381]
            # Zero out background: pixels with low proton density carry no
            # tissue signal, and the same mask is applied to T1/T2 below.
            pd[pd < 0.19] = 0
            t1 = denormalized_outputs[0, 0, :, :][3:253, 131:381]
            t1[pd == 0] = 0
            t2 = denormalized_outputs[0, 1, :, :][3:253, 131:381]
            t2[pd == 0] = 0
            # Relaxation times converted from seconds to milliseconds.
            t1_for_ph = t1 * 1e3
            t2_for_ph = t2 * 1e3
            PHANTOM_SIZE = (250, 250)
            OFFSET = np.array([0.0, 0.0, 0.0])
            RESOLUTION = np.array([1.0, 1.0, 1.0])
            zer = np.zeros(t2.shape)
            # T2* map is approximated by a copy of T2 — TODO confirm intended.
            t2_s_for_ph = t2_for_ph.copy()
            # Phantom filename follows phantom_YYYYMMDD_HHMMSS_microsec.h5 so
            # repeated runs never collide.
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
            phantom_fname = f'phantom_{timestamp}.h5'
            create_h5_file([pd, t1_for_ph, t2_for_ph, t2_s_for_ph, zer],
                           phantom_fname,
                           'flow_model/phantoms_h5',
                           OFFSET,
                           RESOLUTION,
                           PHANTOM_SIZE)
            # In 'upload' mode generate only one phantom to avoid creating
            # multiple files while checking the pipeline.
            if mode == 'upload':
                break
if __name__ == "__main__":
    # In 'base' mode the uploaded-dataset argument is never read, so an
    # empty placeholder list is sufficient.
    flow_model([], mode='base')
|