rotate.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. from types import SimpleNamespace
  2. from typing import List
  3. import numpy as np
  4. from seqgen.pypulseq.add_gradients import add_gradients
  5. from seqgen.pypulseq.scale_grad import scale_grad
  6. def __get_grad_abs_mag(grad: SimpleNamespace) -> np.ndarray:
  7. if grad.type == "trap":
  8. return np.abs(grad.amplitude)
  9. return np.max(np.abs(grad.waveform))
  10. def rotate(
  11. *args: SimpleNamespace,
  12. angle: float,
  13. axis: str,
  14. ) -> List[SimpleNamespace]:
  15. """
  16. Rotates the corresponding gradient(s) about the given axis by the specified amount. Gradients parallel to the
  17. rotation axis and non-gradient(s) are not affected. Possible rotation axes are 'x', 'y' or 'z'.
  18. See also `pypulseq.Sequence.sequence.add_block()`.
  19. Parameters
  20. ----------
  21. axis : str
  22. Axis about which the gradient(s) will be rotated.
  23. angle : float
  24. Angle by which the gradient(s) will be rotated.
  25. args : SimpleNamespace
  26. Gradient(s).
  27. Returns
  28. -------
  29. rotated_grads : [SimpleNamespace]
  30. Rotated gradient(s).
  31. """
  32. axes = ["x", "y", "z"]
  33. # Cycle through the objects and rotate gradients non-parallel to the given rotation axis. Rotated gradients
  34. # assigned to the same axis are then added together.
  35. # First create indexes of the objects to be bypassed or rotated
  36. i_rotate1 = []
  37. i_rotate2 = []
  38. i_bypass = []
  39. axes.remove(axis)
  40. axes_to_rotate = axes
  41. if len(axes_to_rotate) != 2:
  42. raise ValueError("Incorrect axes specification.")
  43. for i in range(len(args)):
  44. event = args[i]
  45. if (event.type != "grad" and event.type != "trap") or event.channel == axis:
  46. i_bypass.append(i)
  47. else:
  48. if event.channel == axes_to_rotate[0]:
  49. i_rotate1.append(i)
  50. else:
  51. if event.channel == axes_to_rotate[1]:
  52. i_rotate2.append(i)
  53. else:
  54. i_bypass.append(i) # Should never happen
  55. # Now every gradient to be rotated generates two new gradients: one on the original axis and one on the other from
  56. # the axes_to_rotate list
  57. rotated1 = []
  58. rotated2 = []
  59. max_mag = 0 # Measure of relevant amplitude
  60. for i in range(len(i_rotate1)):
  61. g = args[i_rotate1[i]]
  62. max_mag = np.max((max_mag, __get_grad_abs_mag(g)))
  63. rotated1.append(scale_grad(grad=g, scale=np.cos(angle)))
  64. g = scale_grad(grad=g, scale=np.sin(angle))
  65. g.channel = axes_to_rotate[1]
  66. rotated2.append(g)
  67. for i in range(len(i_rotate2)):
  68. g = args[i_rotate2[i]]
  69. max_mag = np.max((max_mag, __get_grad_abs_mag(g)))
  70. rotated2.append(scale_grad(grad=g, scale=np.cos(angle)))
  71. g = scale_grad(grad=g, scale=-np.sin(angle))
  72. g.channel = axes_to_rotate[1]
  73. rotated1.append(g)
  74. # Eliminate zero-amplitude gradients
  75. threshold = 1e-6 * max_mag
  76. for i in range(len(rotated1) - 1, -1, -1):
  77. if __get_grad_abs_mag(rotated1[i]) < threshold:
  78. rotated1.pop(i)
  79. for i in range(len(rotated2) - 1, -1, -1):
  80. if __get_grad_abs_mag(rotated2[i]) < threshold:
  81. rotated2.pop(i)
  82. # Add gradients on the corresponding axis together
  83. g = []
  84. if len(rotated1) > 1:
  85. g.append(add_gradients(grads=rotated1))
  86. else:
  87. if len(rotated1) != 0:
  88. g.append(rotated1[0])
  89. if len(rotated2) > 1:
  90. g.append(add_gradients(grads=rotated2))
  91. else:
  92. if len(rotated2) != 0:
  93. g.append(rotated2[0])
  94. # Eliminate zero amplitude gradients
  95. for i in range(len(g) - 1, -1, -1):
  96. if __get_grad_abs_mag(g[i]) < threshold:
  97. g.pop(i)
  98. # Export
  99. bypass = np.take(args, i_bypass)
  100. rotated_grads = [*bypass, *g]
  101. return rotated_grads