from typing import Tuple

import torch
import torch.nn as nn

# Module-level default device. Kept for callers that import it; the model
# below does not depend on it and follows its input's device instead.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def denormalize_per_image(img_tensor, min_vals, max_vals):
    """Undo per-channel min-max normalisation for the first three channels.

    Each channel ``c`` in ``0..2`` is mapped back as
    ``img[:, c] * (max_vals[c] - min_vals[c]) + min_vals[c]``.

    Args:
        img_tensor: Tensor indexed as ``img[:, c]``; presumably
            ``(batch, channels, H, W)`` -- TODO confirm with callers.
        min_vals: Indexable holding at least three per-channel minimums
            (e.g. a list of floats or a 1-D tensor).
        max_vals: Indexable holding at least three per-channel maximums.

    Returns:
        A new tensor with the same shape, dtype and device as
        ``img_tensor``. Channels with index greater than 2 are left at zero.
    """
    img_denorm = torch.zeros_like(img_tensor)
    # The input is only read, never written, so no defensive clone is needed.
    for ch in range(3):
        lo, hi = min_vals[ch], max_vals[ch]
        img_denorm[:, ch] = img_tensor[:, ch] * (hi - lo) + lo
    return img_denorm


class MRSignalModel(nn.Module):
    """Synthesise three weighted images from quantitative T1/T2/PD maps.

    For each output contrast the per-pixel signal is::

        scale * PD * exp(-TE / T2) * (1 - exp(-(TR - ETL * ESP) / T1))

    where ``EPS`` is added to T1 and T2 to avoid division by zero. The
    (TE, TR, ESP) values per contrast live in ``SEQUENCE_PARAMS``; their
    units are presumably seconds -- TODO confirm against the training data.
    """

    # (TE, TR, ESP) for each output channel, in output order.
    SEQUENCE_PARAMS: Tuple[Tuple[float, float, float], ...] = (
        (0.00765, 0.7, 0.00765),  # T1-weighted
        (0.0765, 6.0, 0.00765),  # T2-weighted
        (0.00669, 6.0, 0.00669),  # PD-weighted
    )
    # Multiplier on ESP in the recovery term; presumably the echo train length.
    ETL = 10.0
    # Added to T1/T2 before dividing so zero-valued maps do not produce inf/NaN.
    EPS = 1e-8

    def __init__(self, scale: float = 1600.0) -> None:
        """Create the model.

        Args:
            scale: Global intensity multiplier applied to every output
                contrast (defaults to the previously hard-coded 1600).
        """
        super().__init__()
        self.scale = scale

    def _weighting(self, denom_t1, denom_t2, te, tr, esp):
        """Return ``exp(-te/T2) * (1 - exp(-(tr - ETL*esp)/T1))`` per pixel."""
        # Build the constants on the input's device/dtype. Hard-wiring them to
        # the module-level ``device`` fails when the input lives elsewhere
        # (e.g. a CPU tensor on a CUDA-capable host).
        te_t, tr_t, esp_t = (denom_t2.new_tensor(v) for v in (te, tr, esp))
        decay = torch.exp(-te_t / denom_t2)
        recovery = 1 - torch.exp(-(tr_t - self.ETL * esp_t) / denom_t1)
        return decay * recovery

    def forward(self, q_map, max_vals_img=None):
        """Return the stacked (T1w, T2w, PDw) synthetic images.

        Args:
            q_map: 4-D tensor; channel 0 is read as T1, channel 1 as T2 and
                channel 2 as PD.
            max_vals_img: Ignored. Accepted only so existing call sites that
                pass it keep working; output scaling uses ``self.scale``.

        Returns:
            Tensor concatenated along dim 1 with three channels ordered
            T1w, T2w, PDw, on ``q_map``'s device.
        """
        t1 = q_map[:, 0:1, :, :]
        t2 = q_map[:, 1:2, :, :]
        pd = q_map[:, 2:3, :, :]

        denom_t1 = t1 + self.EPS
        denom_t2 = t2 + self.EPS

        weighted = [
            (pd * self._weighting(denom_t1, denom_t2, te, tr, esp)) * self.scale
            for te, tr, esp in self.SEQUENCE_PARAMS
        ]
        return torch.cat(weighted, 1)