import numpy as np
import os
import torch
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from flow_model.model.losses import VGGLoss
from flow_model.model.network.hf import HierarchyFlow
from flow_model.model.utils.dataset import get_dataset
from flow_model.model.utils.sampler import DistributedGivenIterationSampler, DistributedTestSampler
from tensorboardX import SummaryWriter
import logging
from flow_model.model.utils.log_helper import init_log
from flow_model.model.trainers.MR_signal_model import *
import yaml
from easydict import EasyDict
import matplotlib.pyplot as plt
import h5py
from cmap import Colormap
def transform(image):
    """Rescale a 2-D NumPy image to [-1, 1] and return it as a (1, 1, H, W) tensor."""
    max_val = np.max(image)
    min_val = np.min(image)
    # Add a small epsilon to avoid division by zero on constant images.
    scale = 2.0 / (max_val - min_val + 1e-8)
    offset = -1.0 - min_val * scale
    # Apply the scaling.
    image_scaled = image * scale + offset
    return torch.from_numpy(image_scaled).unsqueeze(0).unsqueeze(0)
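
# Usage sketch for transform (hypothetical: `slice_2d` stands in for one 2-D
# NumPy array, e.g. a phantom slice loaded elsewhere in this script):
#     x = transform(slice_2d)   # tensor of shape (1, 1, H, W)
#     # values now lie in [-1, 1]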
def load_checkpoint(checkpoint_fpath, model, optimizer, location):
    """Restore model and optimizer state from a checkpoint; return the saved training step."""
    checkpoint = torch.load(checkpoint_fpath, map_location=location)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, checkpoint['step']
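
# Usage sketch for load_checkpoint (hedged: the HierarchyFlow constructor
# arguments come from the YAML config elsewhere in this script; the path,
# config access, and learning rate below are placeholders):
#     model = HierarchyFlow(**config.net)   # hypothetical config access
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
#     model, optimizer, step = load_checkpoint('ckpt.pth.tar', model, optimizer, 'cpu')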
def denormalize_per_image(img_tensor, min_vals, max_vals):
    """Undo per-channel min-max normalization: map each of the 3 channels from [0, 1] back to [min, max]."""
    img_denorm = torch.zeros_like(img_tensor)
    img_denorm[:, 0] = img_tensor[:, 0] * (max_vals[0] - min_vals[0]) + min_vals[0]
    img_denorm[:, 1] = img_tensor[:, 1] * (max_vals[1] - min_vals[1]) + min_vals[1]
    img_denorm[:, 2] = img_tensor[:, 2] * (max_vals[2] - min_vals[2]) + min_vals[2]
    return img_denorm
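
# Usage sketch for denormalize_per_image (assumed shapes only: `batch` is an
# (N, 3, H, W) tensor normalized to [0, 1]; the per-channel bounds below are
# illustrative, not taken from the dataset):
#     min_vals, max_vals = [0.0, 0.0, 0.0], [1.0, 4.5, 2.2]
#     batch_denorm = denormalize_per_image(batch, min_vals, max_vals)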
def create_h5_file(list_of_phantoms, phantom_name, dir_for_save, offset, resolution, ph_size):
    """Stack five 2-D maps into an (H, W, 5) HDF5 dataset with offset/resolution metadata."""
    offset = np.expand_dims(offset, 1)
    resolution = np.expand_dims(resolution, 1)
    rescaled_phantoms = []
    for i in range(len(list_of_phantoms)):
        if (i == 0) or (i == 4):
            # Channels 0 and 4 are stored as-is.
            rescaled_phantoms.append(list_of_phantoms[i])
        else:
            # r_ph = list_of_phantoms[i] * 1e3
            r_ph = list_of_phantoms[i]
            # Take the reciprocal of positive entries only; np.divide with
            # `where` leaves zeros at zero and avoids the divide-by-zero
            # warning that 1/r_ph inside np.where would raise.
            r_ph = np.divide(1.0, r_ph, out=np.zeros_like(r_ph, dtype=np.float64), where=r_ph > 0)
            rescaled_phantoms.append(r_ph)
    five_phant = np.array(rescaled_phantoms).transpose(1, 2, 0)
    with h5py.File(f'{dir_for_save}/{phantom_name}', 'w') as f:
        grp = f.create_group("sample")
        dset = grp.create_dataset('data', (ph_size[0], ph_size[1], 5), dtype="f8")
        dset[:, :, :] = five_phant[:, :, :]
        dset_1 = grp.create_dataset('offset', (3, 1), dtype="f8")
        dset_1[:, :] = offset[:, :]
        dset_2 = grp.create_dataset('resolution', (3, 1), dtype="f8")
        dset_2[:, :] = resolution[:, :]
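
# Usage sketch for create_h5_file (all values illustrative): five equally
# shaped 2-D maps are stacked into one (H, W, 5) dataset; channels 1-3 are
# stored as reciprocals (e.g. relaxation rates -> times, an assumption based
# on the MR_signal_model import), channels 0 and 4 as-is.
#     maps = [np.random.rand(256, 256) for _ in range(5)]
#     create_h5_file(maps, 'phantom.h5', './out',
#                    offset=np.zeros(3), resolution=np.full(3, 1e-3),
#                    ph_size=(256, 256))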