| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778 |
- from typing import Tuple
- import torch
- import torch.nn as nn
# Default torch device for this module: CUDA when it is available, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
- #def denormalize_per_image(img_tensor, min_vals, max_vals):
- #min_t = torch.tensor(min_vals, dtype=img_tensor.dtype, device=img_tensor.device).view(3, 1, 1)
- #max_t = torch.tensor(max_vals, dtype=img_tensor.dtype, device=img_tensor.device).view(3, 1, 1)
- #img_denorm = img_tensor * (max_t - min_t) + min_t
- #return img_denorm
def denormalize_per_image(img_tensor, min_vals, max_vals):
    """Undo per-channel min-max normalisation for the first three channels.

    Channel ``c`` (``c`` in 0, 1, 2) is mapped back as
    ``img_tensor[:, c] * (max_vals[c] - min_vals[c]) + min_vals[c]``.

    Args:
        img_tensor: Tensor indexed as ``img_tensor[:, c]``, i.e. with the
            channel on dim 1 (presumably ``(N, C, H, W)`` -- confirm with
            callers). It must have at least 3 channels and is not modified.
        min_vals: Indexable holding at least 3 per-channel minima
            (presumably Python scalars).
        max_vals: Indexable holding at least 3 per-channel maxima.

    Returns:
        A new tensor with the same shape, dtype and device as ``img_tensor``.
        Only channels 0-2 are filled; any further channels are left as zeros.
    """
    img_denorm = torch.zeros_like(img_tensor)
    # img_denorm is a fresh buffer that never aliases img_tensor, so we can
    # read from the input directly -- no defensive clone of the whole tensor.
    for c in range(3):
        lo, hi = min_vals[c], max_vals[c]
        img_denorm[:, c] = img_tensor[:, c] * (hi - lo) + lo
    return img_denorm
class MRSignalModel(nn.Module):
    """Synthesise three weighted images from a quantitative T1/T2/PD map.

    For each of three fixed acquisitions (TE, TR, ESP), ``forward`` evaluates

        S = PD * exp(-TE / T2) * (1 - exp(-(TR - ETL * ESP) / T1)) * SCALE

    pixel-wise and concatenates the three results on dim 1. The module has no
    learnable parameters.
    """

    # (TE, TR, ESP) for the T1-, T2- and PD-weighted outputs, in output order.
    # NOTE(review): the values look like seconds -- confirm against the protocol.
    _SEQUENCE_PARAMS = (
        (0.00765, 0.7, 0.00765),
        (0.0765, 6.0, 0.00765),
        (0.00669, 6.0, 0.00669),
    )
    _ETL = 10.0  # presumably the echo train length; multiplies ESP
    _EPS = 1e-8  # added to T1/T2 so the divisions stay finite where a map is 0
    # Single output intensity scale for all three contrasts. Other scales
    # (1500-2060, 55000, per-channel max_vals_img) have been tried before.
    _SIGNAL_SCALE = 1600

    def __init__(self) -> None:
        super().__init__()

    @staticmethod
    def _const(value, device):
        """Return ``value`` as a (1, 1) tensor of the default dtype on ``device``."""
        return torch.tensor(value, device=device).unsqueeze(-1).unsqueeze(-1)

    def forward(self, q_map, max_vals_img=None):
        """Return the T1-, T2- and PD-weighted images concatenated on dim 1.

        Args:
            q_map: 4-D tensor indexed as ``q_map[:, channel, :, :]`` with T1,
                T2 and PD in channels 0, 1 and 2 (presumably
                ``(N, C, H, W)``).
            max_vals_img: Unused; accepted so existing callers that pass it
                keep working.

        Returns:
            Tensor of shape ``(N, 3, H, W)`` (for ``q_map`` with at least three
            channels), holding the T1-, T2- and PD-weighted images in order.
        """
        t1 = q_map[:, 0:1, :, :]
        t2 = q_map[:, 1:2, :, :]
        pd = q_map[:, 2:3, :, :]
        denom_t1 = t1 + self._EPS
        denom_t2 = t2 + self._EPS

        weighted = []
        for te, tr, esp in self._SEQUENCE_PARAMS:
            # Build constants on q_map's device (not a module-global device)
            # so the model works for whichever device the input lives on.
            te_t, tr_t, esp_t = (
                self._const(v, q_map.device) for v in (te, tr, esp)
            )
            recovery = 1 - torch.exp(-(tr_t - self._ETL * esp_t) / denom_t1)
            w = torch.exp(-te_t / denom_t2) * recovery
            weighted.append((pd * w) * self._SIGNAL_SCALE)
        return torch.cat(weighted, 1)
|