MR_signal_model.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. from typing import Tuple
  2. import torch
  3. import torch.nn as nn
  # Default compute device: the current CUDA device when CUDA is available,
  # otherwise the CPU.
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
  5. #def denormalize_per_image(img_tensor, min_vals, max_vals):
  6. #min_t = torch.tensor(min_vals, dtype=img_tensor.dtype, device=img_tensor.device).view(3, 1, 1)
  7. #max_t = torch.tensor(max_vals, dtype=img_tensor.dtype, device=img_tensor.device).view(3, 1, 1)
  8. #img_denorm = img_tensor * (max_t - min_t) + min_t
  9. #return img_denorm
  def denormalize_per_image(img_tensor, min_vals, max_vals):
      """Map per-channel min-max normalized values back to their original range.

      Channel ``c`` (dim 1 of ``img_tensor``) is transformed as
      ``x * (max_vals[c] - min_vals[c]) + min_vals[c]``, the algebraic inverse of
      ``(x - min) / (max - min)``.

      Args:
          img_tensor: tensor whose dim 1 indexes channels, e.g. (B, C, H, W).
          min_vals: per-channel minima, indexable by channel (>= C entries).
          max_vals: per-channel maxima, indexable by channel (>= C entries).

      Returns:
          A new tensor with the same shape, dtype and device as ``img_tensor``.
          ``img_tensor`` itself is not modified.

      Raises:
          IndexError: if ``min_vals``/``max_vals`` have fewer entries than
              ``img_tensor`` has channels. Previously the extra channels were
              silently left as zeros.
      """
      # Writing into a zeros_like buffer keeps img_tensor's dtype/device even if
      # the min/max values are of a different type.
      img_denorm = torch.zeros_like(img_tensor)
      # img_tensor is only read, so no defensive clone is needed.
      for c in range(img_tensor.shape[1]):
          lo, hi = min_vals[c], max_vals[c]
          img_denorm[:, c] = img_tensor[:, c] * (hi - lo) + lo
      return img_denorm
  class MRSignalModel(nn.Module):
      """Synthesize weighted MR images from quantitative T1 / T2 / PD maps.

      For every acquisition triple (TE, TR, ESP) the rendered signal is::

          S = PD * exp(-TE / (T2 + eps))
                 * (1 - exp(-(TR - etl * ESP) / (T1 + eps))) * scale

      The defaults reproduce the original three output channels (named T1-,
      T2- and PD-weighted), ``etl = 10``, ``eps = 1e-8`` and a fixed output
      scale of 1600.

      NOTE(review): the timing constants look like seconds, so the T1/T2 maps
      are presumably expected in seconds as well. Confirm against the data
      pipeline.
      NOTE(review): T1/T2 are not clamped. Values at or below ``-eps``
      (e.g. from an unconstrained network head) can produce inf/NaN.
      """

      # One (TE, TR, ESP) triple per output channel, in output-channel order.
      DEFAULT_SEQUENCES: Tuple[Tuple[float, float, float], ...] = (
          (0.00765, 0.7, 0.00765),  # T1-weighted
          (0.0765, 6.0, 0.00765),   # T2-weighted
          (0.00669, 6.0, 0.00669),  # PD-weighted
      )

      def __init__(
          self,
          scale: float = 1600.0,
          etl: float = 10.0,
          sequences: Tuple[Tuple[float, float, float], ...] = DEFAULT_SEQUENCES,
          eps: float = 1e-8,
      ) -> None:
          """Configure sequence timings and output scaling.

          Args:
              scale: multiplier applied to every synthesized image.
              etl: factor multiplying ESP in the recovery term (named ``etl``,
                  presumably the echo-train length).
              sequences: one (TE, TR, ESP) triple per output channel.
              eps: added to T1 and T2 before dividing, to avoid division by 0.

          Raises:
              ValueError: if ``sequences`` is empty or an entry is not a triple.
          """
          seqs = tuple(tuple(float(v) for v in seq) for seq in sequences)
          if not seqs:
              raise ValueError("sequences must contain at least one (TE, TR, ESP) triple")
          if any(len(seq) != 3 for seq in seqs):
              raise ValueError("each sequence must be a (TE, TR, ESP) triple")
          super().__init__()
          self.scale = float(scale)
          self.etl = float(etl)
          self.eps = float(eps)
          self.sequences = seqs

      def forward(self, q_map: torch.Tensor, max_vals_img=None) -> torch.Tensor:
          """Render one weighted image per configured sequence.

          Args:
              q_map: tensor with T1 in channel 0, T2 in channel 1 and PD in
                  channel 2 along dim 1, e.g. (B, 3, H, W).
              max_vals_img: unused; kept only for call compatibility.
                  NOTE(review): comments in the previous revision considered
                  scaling by ``max_vals_img[:, c]`` instead of ``scale``.

          Returns:
              Tensor of shape (B, len(sequences), ...) on ``q_map``'s device.
          """
          t1 = q_map[:, 0:1]
          t2 = q_map[:, 1:2]
          pd = q_map[:, 2:3]
          denom_t1 = t1 + self.eps
          denom_t2 = t2 + self.eps
          # The timing values are plain Python floats, so they broadcast onto
          # q_map's own device and dtype. The previous per-call constant tensors
          # were pinned to the module-global device, which failed when q_map
          # lived elsewhere and forced a host->device copy on every call.
          images = []
          for te, tr, esp in self.sequences:
              t2_decay = torch.exp(-te / denom_t2)
              t1_recovery = 1 - torch.exp(-(tr - self.etl * esp) / denom_t1)
              images.append((pd * (t2_decay * t1_recovery)) * self.scale)
          return torch.cat(images, dim=1)