split_gradient_at.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. from copy import deepcopy
  2. from types import SimpleNamespace
  3. from typing import Tuple, Union
  4. import numpy as np
  5. from seqgen.pypulseq import eps
  6. from seqgen.pypulseq.make_extended_trapezoid import make_extended_trapezoid
  7. from seqgen.pypulseq.opts import Opts
  8. def split_gradient_at(
  9. grad: SimpleNamespace, time_point: float, system: Opts = Opts()
  10. ) -> Union[SimpleNamespace, Tuple[SimpleNamespace, SimpleNamespace]]:
  11. """
  12. Splits a trapezoidal gradient into two extended trapezoids defined by the cut line. Returns the two gradient parts
  13. by cutting the original 'grad' at 'time_point'. For the input type 'trapezoid' the results are returned as extended
  14. trapezoids, for 'arb' as arbitrary gradient objects. The delays in the individual gradient events are adapted such
  15. that add_gradients(...) produces a gradient equivalent to 'grad'.
  16. See also:
  17. - `pypulseq.split_gradient()`
  18. - `pypulseq.make_extended_trapezoid()`
  19. - `pypulseq.make_trapezoid()`
  20. - `pypulseq.Sequence.sequence.Sequence.add_block()`
  21. - `pypulseq.opts.Opts`
  22. Parameters
  23. ----------
  24. grad : SimpleNamespace
  25. Gradient event to be split into two gradient events.
  26. time_point : float
  27. Time point at which `grad` will be split into two gradient waveforms.
  28. system : Opts, default=Opts()
  29. System limits.
  30. Returns
  31. -------
  32. grad1, grad2 : SimpleNamespace
  33. Gradient waveforms after splitting.
  34. Raises
  35. ------
  36. ValueError
  37. If non-gradient event is passed.
  38. """
  39. # copy() to emulate pass-by-value; otherwise passed grad is modified
  40. grad = deepcopy(grad)
  41. grad_raster_time = system.grad_raster_time
  42. time_index = np.round(time_point / grad_raster_time)
  43. # Work around floating-point arithmetic limitation
  44. time_point = np.round(time_index * grad_raster_time, 6)
  45. channel = grad.channel
  46. if grad.type == "grad":
  47. # Check if we have an arbitrary gradient or an extended trapezoid
  48. if np.abs(grad.tt[-1] - 0.5 * grad_raster_time) < 1e-10 and np.all(
  49. np.abs(grad.tt[1:] - grad.tt[:-1] - grad_raster_time) < 1e-10
  50. ):
  51. # Arbitrary gradient -- trivial conversion
  52. # If time point is out of range we have nothing to do
  53. if time_index == 0 or time_index >= len(grad.tt):
  54. return grad
  55. else:
  56. grad1 = grad
  57. grad2 = grad
  58. grad1.last = 0.5 * (
  59. grad.waveform[time_index - 1] + grad.waveform[time_index]
  60. )
  61. grad2.first = grad1.last
  62. grad2.delay = grad.delay + grad.t[time_index]
  63. grad1.t = grad.t[:time_index]
  64. grad1.waveform = grad.waveform[:time_index]
  65. grad2.t = grad.t[time_index:] - time_point
  66. grad2.waveform = grad.waveform[time_index:]
  67. return grad1, grad2
  68. else:
  69. # Extended trapezoid
  70. times = grad.tt
  71. amplitudes = grad.waveform
  72. elif grad.type == "trap":
  73. grad.delay = np.round(grad.delay / grad_raster_time) * grad_raster_time
  74. grad.rise_time = np.round(grad.rise_time / grad_raster_time) * grad_raster_time
  75. grad.flat_time = np.round(grad.flat_time / grad_raster_time) * grad_raster_time
  76. grad.fall_time = np.round(grad.fall_time / grad_raster_time) * grad_raster_time
  77. # Prepare the extended trapezoid structure
  78. if grad.flat_time == 0:
  79. times = [0, grad.rise_time, grad.rise_time + grad.fall_time]
  80. amplitudes = [0, grad.amplitude, 0]
  81. else:
  82. times = [
  83. 0,
  84. grad.rise_time,
  85. grad.rise_time + grad.flat_time,
  86. grad.rise_time + grad.flat_time + grad.fall_time,
  87. ]
  88. amplitudes = [0, grad.amplitude, grad.amplitude, 0]
  89. else:
  90. raise ValueError("Splitting of unsupported event.")
  91. # If the split line is behind the gradient, there is no second gradient to create
  92. if time_point >= grad.delay + times[-1]:
  93. raise ValueError(
  94. "Splitting of gradient at time point after the end of gradient."
  95. )
  96. # If the split line goes through the delay
  97. if time_point < grad.delay:
  98. times = np.insert(grad.delay + times, 0, 0)
  99. amplitudes = [0, amplitudes]
  100. grad.delay = 0
  101. else:
  102. time_point -= grad.delay
  103. amplitudes = np.array(amplitudes)
  104. times = np.array(times).round(6) # Work around floating-point arithmetic limitation
  105. # Sample at time point
  106. amp_tp = np.interp(x=time_point, xp=times, fp=amplitudes)
  107. t_eps = 1e-10
  108. times1 = np.append(times[np.where(times < time_point - t_eps)], time_point)
  109. amplitudes1 = np.append(amplitudes[np.where(times < time_point - t_eps)], amp_tp)
  110. times2 = np.insert(times[times > time_point + t_eps], 0, time_point) - time_point
  111. amplitudes2 = np.insert(amplitudes[times > time_point + t_eps], 0, amp_tp)
  112. # Recreate gradients
  113. grad1 = make_extended_trapezoid(
  114. channel=channel,
  115. system=system,
  116. times=times1,
  117. amplitudes=amplitudes1,
  118. skip_check=True,
  119. )
  120. grad1.delay = grad.delay
  121. grad2 = make_extended_trapezoid(
  122. channel=channel,
  123. system=system,
  124. times=times2,
  125. amplitudes=amplitudes2,
  126. skip_check=True,
  127. )
  128. grad2.delay = time_point
  129. return grad1, grad2