import torch from torch.utils.data import DataLoader from flow_model.model.network.hf import HierarchyFlow from flow_model.model.utils.dataset import get_dataset from flow_model.utils import * from flow_model.model.trainers.MR_signal_model import * import yaml from easydict import EasyDict #import matplotlib.pyplot as plt import h5py from datetime import datetime #from cmap import Colormap def flow_model(upload_cfg, mode='base', device_pref=None): """Запускает генерацию фантома. Parameters: upload_cfg: конфигурация загрузки/датасета. mode: 'base' или 'upload'. device_pref: None|'cpu'|'cuda' — предпочтительное устройство. - None: автоматически выбрать CUDA, если доступна. - 'cuda': пытаться использовать CUDA, если недоступна — фоллбек на CPU. - 'cpu': принудительно CPU. """ if device_pref == 'cpu': device = torch.device('cpu') elif device_pref == 'cuda': device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') else: device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') with open('./flow_model/configs/config.yaml') as f: cfg = yaml.load(f, Loader=yaml.FullLoader) cfg = EasyDict(cfg) model = HierarchyFlow(cfg.network.pad_size, cfg.network.in_channel, cfg.network.out_channels, cfg.network.weight_type) model = model.to(device) optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr) model, optimizer, resumed_step = load_checkpoint('flow_model/checkpoints/27900.ckpt.pth.tar', model, optimizer, device) if mode == 'base': test_dataset = get_dataset(cfg.dataset.test) else: test_dataset = get_dataset(upload_cfg.dataset.calc) test_loader = DataLoader( test_dataset, batch_size=1, shuffle=False) model.eval() with torch.no_grad(): for batch_id, batch in enumerate(test_loader): content_images = batch[0].to(device) style_images = batch[1].to(device) names = batch[2] not_normalized_image = batch[3].cpu() max_vals = batch[4].cpu() outputs = model(content_images, style_images) outputs = torch.clamp(outputs, 0, 1) denormalized_outputs = denormalize_per_image(outputs, [0.0, 0.0, 0.0], [3.5, 0.45, 1.0]) denormalized_outputs = denormalized_outputs.cpu().numpy() pd = denormalized_outputs[0, 2, :, :][3:253, 131:381] pd[pd < 0.19] = 0 t1 = denormalized_outputs[0, 0, :, :][3:253, 131:381] t1[pd == 0] = 0 t2 = denormalized_outputs[0, 1, :, :][3:253, 131:381] t2[pd == 0] = 0 t1_for_ph = t1 * 1e3 t2_for_ph = t2 * 1e3 PHANTOM_SIZE = (250, 250) OFFSET = np.array([0.0, 0.0, 0.0]) RESOLUTION = np.array([1.0, 1.0, 1.0]) zer = np.zeros(t2.shape) t2_s_for_ph = t2_for_ph.copy() # Имя файла фантома по шаблону phantom_YYYYMMDD_HHMMSS_microsec.h5 timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f') phantom_fname = f'phantom_{timestamp}.h5' create_h5_file([pd, t1_for_ph, t2_for_ph, t2_s_for_ph, zer], phantom_fname, 'flow_model/phantoms_h5', OFFSET, RESOLUTION, PHANTOM_SIZE) # В режиме загрузки (upload) генерируем только один фантом, # чтобы избежать создания нескольких файлов при проверке пайплайна. if mode == 'upload': break if __name__ == "__main__": downloaded_dataset = [] flow_model(downloaded_dataset, mode='base')