block.py

from types import SimpleNamespace
from typing import List, Tuple, Union

import numpy as np

from seqgen.pypulseq.block_to_events import block_to_events
from seqgen.pypulseq.compress_shape import compress_shape
from seqgen.pypulseq.decompress_shape import decompress_shape
from seqgen.pypulseq.event_lib import EventLibrary
from seqgen.pypulseq.supported_labels_rf_use import get_supported_labels


def set_block(self, block_index: int, *args: SimpleNamespace) -> None:
    """
    Replace the block at `block_index` with a new block provided as a block structure, add a sequence block, or
    create a new block from events and store it at the position specified by `block_index`. The block or events are
    provided in uncompressed form and will be stored in the compressed, non-redundant internal libraries.

    See also:
    - `pypulseq.Sequence.sequence.Sequence.get_block()`
    - `pypulseq.Sequence.sequence.Sequence.add_block()`

    Parameters
    ----------
    block_index : int
        Index at which the block is replaced.
    args : SimpleNamespace
        Block or events to be replaced/added or created at `block_index`.

    Raises
    ------
    ValueError
        If a trigger event of an unsupported control event type is passed.
        If a delay is set for a gradient event that starts with a non-zero amplitude.
    RuntimeError
        If two consecutive gradients do not have the same amplitude at the connection point.
        If the first gradient in the block does not start at 0.
        If a gradient that does not end at zero is not aligned to the block boundary.
    """
    events = block_to_events(*args)
    self.block_events[block_index] = np.zeros(7, dtype=np.int32)
    duration = 0

    check_g = {}  # Key-value mapping of index and pairs of gradients/times
    extensions = []

    for event in events:
        if not isinstance(event, float):  # If event is not a block duration
            if event.type == "rf":
                if hasattr(event, "id"):
                    rf_id = event.id
                else:
                    rf_id, _ = register_rf_event(self, event)

                self.block_events[block_index][1] = rf_id
                duration = max(
                    duration, event.shape_dur + event.delay + event.ringdown_time
                )
            elif event.type == "grad":
                channel_num = ["x", "y", "z"].index(event.channel)
                idx = 2 + channel_num

                grad_start = (
                    event.delay
                    + np.floor(event.tt[0] / self.grad_raster_time + 1e-10)
                    * self.grad_raster_time
                )
                grad_duration = (
                    event.delay
                    + np.ceil(event.tt[-1] / self.grad_raster_time - 1e-10)
                    * self.grad_raster_time
                )

                check_g[channel_num] = SimpleNamespace()
                check_g[channel_num].idx = idx
                check_g[channel_num].start = np.array((grad_start, event.first))
                check_g[channel_num].stop = np.array((grad_duration, event.last))

                if hasattr(event, "id"):
                    grad_id = event.id
                else:
                    grad_id, _ = register_grad_event(self, event)

                self.block_events[block_index][idx] = grad_id
                duration = np.max([duration, grad_duration])
            elif event.type == "trap":
                channel_num = ["x", "y", "z"].index(event.channel)
                idx = 2 + channel_num

                check_g[channel_num] = SimpleNamespace()
                check_g[channel_num].idx = idx
                check_g[channel_num].start = np.array((0, 0))
                check_g[channel_num].stop = np.array(
                    (
                        event.delay
                        + event.rise_time
                        + event.fall_time
                        + event.flat_time,
                        0,
                    )
                )

                if hasattr(event, "id"):
                    trap_id = event.id
                else:
                    trap_id = register_grad_event(self, event)

                self.block_events[block_index][idx] = trap_id
                duration = np.max(
                    [
                        duration,
                        event.delay
                        + event.rise_time
                        + event.flat_time
                        + event.fall_time,
                    ]
                )
            elif event.type == "adc":
                if hasattr(event, "id"):
                    adc_id = event.id
                else:
                    adc_id = register_adc_event(self, event)

                self.block_events[block_index][5] = adc_id
                duration = np.max(
                    [
                        duration,
                        event.delay + event.num_samples * event.dwell + event.dead_time,
                    ]
                )
            elif event.type == "delay":
                duration = np.max([duration, event.delay])
            elif event.type in ["output", "trigger"]:
                if hasattr(event, "id"):
                    event_id = event.id
                else:
                    event_id = register_control_event(self, event)

                ext = {"type": self.get_extension_type_ID("TRIGGERS"), "ref": event_id}
                extensions.append(ext)
                duration = np.max([duration, event.delay + event.duration])
            elif event.type in ["labelset", "labelinc"]:
                if hasattr(event, "id"):
                    label_id = event.id
                else:
                    label_id = register_label_event(self, event)

                ext = {
                    "type": self.get_extension_type_ID(event.type.upper()),
                    "ref": label_id,
                }
                extensions.append(ext)

    # =========
    # ADD EXTENSIONS
    # =========
    if len(extensions) > 0:
  137. """
  138. Add extensions now... but it's tricky actually we need to check whether the exactly the same list of extensions
  139. already exists, otherwise we have to create a new one... ooops, we have a potential problem with the key
  140. mapping then... The trick is that we rely on the sorting of the extension IDs and then we can always find the
  141. last one in the list by setting the reference to the next to 0 and then proceed with the other elements.
  142. """
        sort_idx = np.argsort([e["ref"] for e in extensions])
        extensions = np.take(extensions, sort_idx)
        all_found = True
        extension_id = 0
        for i in range(len(extensions)):
            data = [extensions[i]["type"], extensions[i]["ref"], extension_id]
            extension_id, found = self.extensions_library.find(data)
            all_found = all_found and found
            if not found:
                break

        if not all_found:
            # Add the list
            extension_id = 0
            for i in range(len(extensions)):
                data = [extensions[i]["type"], extensions[i]["ref"], extension_id]
                extension_id, found = self.extensions_library.find(data)
                if not found:
                    self.extensions_library.insert(extension_id, data)

        # Now we add the ID
        self.block_events[block_index][6] = extension_id

    # =========
    # PERFORM GRADIENT CHECKS
    # =========
    for grad_to_check in check_g.values():
        if (
            abs(grad_to_check.start[1])
            > self.system.max_slew * self.system.grad_raster_time
        ):
            if grad_to_check.start[0] != 0:
                raise ValueError(
                    "No delay allowed for gradients which start with a non-zero amplitude"
                )

            if block_index > 1:
                prev_id = self.block_events[block_index - 1][grad_to_check.idx]
                if prev_id != 0:
                    prev_lib = self.grad_library.get(prev_id)
                    prev_data = prev_lib["data"]
                    prev_type = prev_lib["type"]
                    if prev_type == "t":
                        raise RuntimeError(
                            "Two consecutive gradients need to have the same amplitude at the connection point"
                        )
                    elif prev_type == "g":
                        last = prev_data[5]
                        if (
                            abs(last - grad_to_check.start[1])
                            > self.system.max_slew * self.system.grad_raster_time
                        ):
                            raise RuntimeError(
                                "Two consecutive gradients need to have the same amplitude at the connection point"
                            )
            else:
                raise RuntimeError(
                    "First gradient in the first block has to start at 0."
                )

        if (
            grad_to_check.stop[1] > self.system.max_slew * self.system.grad_raster_time
            and abs(grad_to_check.stop[0] - duration) > 1e-7
        ):
            raise RuntimeError(
                "A gradient that doesn't end at zero needs to be aligned to the block boundary."
            )

    self.block_durations.append(float(duration))
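
# Usage sketch (illustrative, not part of this module): set_block() is normally reached through
# Sequence.add_block(), which appends events at the next free index. Assuming the usual PyPulseq
# event constructors are available in this package (an assumption, not verified here), a block
# could be assembled roughly like this:
#
#     rf = make_sinc_pulse(flip_angle=np.pi / 2, duration=1e-3, system=system)
#     gz = make_trapezoid(channel="z", area=1000, duration=1e-3, system=system)
#     adc = make_adc(num_samples=256, dwell=10e-6, system=system)
#     seq.add_block(rf, gz)    # internally: set_block(seq, next_index, rf, gz)
#     seq.add_block(gz, adc)
#
# Each call compresses the events into the internal libraries and stores their IDs in
# seq.block_events[index] in the order [delay, rf, gx, gy, gz, adc, extensions].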


def get_block(self, block_index: int) -> SimpleNamespace:
    """
    Return the PyPulseq block at position `block_index` in `self.block_events`.

    The block is created from the sequence data with all events and shapes decompressed.

    Parameters
    ----------
    block_index : int
        Index of the PyPulseq block to be retrieved from `self.block_events`.

    Returns
    -------
    block : SimpleNamespace
        PyPulseq block at position `block_index` in `self.block_events`.

    Raises
    ------
    ValueError
        If a trigger event of an unsupported control type is encountered.
    RuntimeError
        If an extension of an unknown extension ID is encountered.
    """
    block = SimpleNamespace()
    attrs = ["block_duration", "rf", "gx", "gy", "gz", "adc"]
    values = [None] * len(attrs)
    for att, val in zip(attrs, values):
        setattr(block, att, val)
    event_ind = self.block_events[block_index]

    if event_ind[0] > 0:  # Delay
        delay = SimpleNamespace()
        delay.type = "delay"
        delay.delay = self.delay_library.data[event_ind[0]][0]
        block.delay = delay

    if event_ind[1] > 0:  # RF
        if len(self.rf_library.type) >= event_ind[1]:
            block.rf = self.rf_from_lib_data(
                self.rf_library.data[event_ind[1]], self.rf_library.type[event_ind[1]]
            )
        else:
            block.rf = self.rf_from_lib_data(
                self.rf_library.data[event_ind[1]]
            )  # Undefined type/use

    # Gradients
    grad_channels = ["gx", "gy", "gz"]
    for i in range(len(grad_channels)):
        if event_ind[2 + i] > 0:
            grad, compressed = SimpleNamespace(), SimpleNamespace()
            grad_type = self.grad_library.type[event_ind[2 + i]]
            lib_data = self.grad_library.data[event_ind[2 + i]]
            grad.type = "trap" if grad_type == "t" else "grad"
            grad.channel = grad_channels[i][1]
            if grad.type == "grad":
                amplitude = lib_data[0]
                shape_id = lib_data[1]
                time_id = lib_data[2]
                delay = lib_data[3]

                shape_data = self.shape_library.data[shape_id]
                compressed.num_samples = shape_data[0]
                compressed.data = shape_data[1:]
                g = decompress_shape(compressed)
                grad.waveform = amplitude * g

                if time_id == 0:
                    grad.tt = (np.arange(1, len(g) + 1) - 0.5) * self.grad_raster_time
                    t_end = len(g) * self.grad_raster_time
                else:
                    t_shape_data = self.shape_library.data[time_id]
                    compressed.num_samples = t_shape_data[0]
                    compressed.data = t_shape_data[1:]
                    grad.tt = decompress_shape(compressed) * self.grad_raster_time
                    assert len(grad.waveform) == len(grad.tt)
                    t_end = grad.tt[-1]

                grad.shape_id = shape_id
                grad.time_id = time_id
                grad.delay = delay
                grad.shape_dur = t_end
                if len(lib_data) > 5:
                    grad.first = lib_data[4]
                    grad.last = lib_data[5]
            else:
                grad.amplitude = lib_data[0]
                grad.rise_time = lib_data[1]
                grad.flat_time = lib_data[2]
                grad.fall_time = lib_data[3]
                grad.delay = lib_data[4]
                grad.area = grad.amplitude * (
                    grad.flat_time + grad.rise_time / 2 + grad.fall_time / 2
                )
                grad.flat_area = grad.amplitude * grad.flat_time

            setattr(block, grad_channels[i], grad)

    # ADC
    if event_ind[5] > 0:
        lib_data = self.adc_library.data[event_ind[5]]
        if len(lib_data) < 6:
            lib_data = np.append(lib_data, 0)

        adc = SimpleNamespace()
        (
            adc.num_samples,
            adc.dwell,
            adc.delay,
            adc.freq_offset,
            adc.phase_offset,
            adc.dead_time,
        ) = [lib_data[x] for x in range(6)]
        adc.num_samples = int(adc.num_samples)
        adc.type = "adc"
        block.adc = adc

    # Triggers
    if event_ind[6] > 0:
        # We have extensions - triggers, labels, etc.
        next_ext_id = event_ind[6]
        while next_ext_id != 0:
            ext_data = self.extensions_library.data[next_ext_id]
            # Format: ext_type, ext_id, next_ext_id
            ext_type = self.get_extension_type_string(ext_data[0])

            if ext_type == "TRIGGERS":
                trigger_types = ["output", "trigger"]
                data = self.trigger_library.data[ext_data[1]]
                trigger = SimpleNamespace()
                trigger.type = trigger_types[int(data[0]) - 1]
                if data[0] == 1:
                    trigger_channels = ["osc0", "osc1", "ext1"]
                    trigger.channel = trigger_channels[int(data[1]) - 1]
                elif data[0] == 2:
                    trigger_channels = ["physio1", "physio2"]
                    trigger.channel = trigger_channels[int(data[1]) - 1]
                else:
                    raise ValueError("Unsupported trigger event type")

                trigger.delay = data[2]
                trigger.duration = data[3]
                # Allow for multiple triggers per block
                if hasattr(block, "trigger"):
                    block.trigger[len(block.trigger)] = trigger
                else:
                    block.trigger = {0: trigger}
            elif ext_type in ["LABELSET", "LABELINC"]:
                label = SimpleNamespace()
                label.type = ext_type.lower()
                supported_labels = get_supported_labels()
                if ext_type == "LABELSET":
                    data = self.label_set_library.data[ext_data[1]]
                else:
                    data = self.label_inc_library.data[ext_data[1]]
                label.label = supported_labels[int(data[1] - 1)]
                label.value = data[0]
                # Allow for multiple labels per block
                if hasattr(block, "label"):
                    block.label[len(block.label)] = label
                else:
                    block.label = {0: label}
            else:
                raise RuntimeError(f"Unknown extension ID {ext_data[0]}")

            next_ext_id = ext_data[2]

    block.block_duration = self.block_durations[block_index - 1]

    return block
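
# Usage sketch (illustrative): get_block() decompresses all stored events again, so the returned
# SimpleNamespace can be inspected freely. Block indices are 1-based (block_durations is accessed
# with block_index - 1 above); the loop below assumes contiguous indices, which is a hypothetical
# convenience, not something this module enforces.
#
#     for idx in range(1, len(seq.block_events) + 1):
#         block = get_block(seq, idx)
#         if block.adc is not None:
#             print(idx, block.block_duration, block.adc.num_samples)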


def register_adc_event(self, event: SimpleNamespace) -> int:
    """
    Parameters
    ----------
    event : SimpleNamespace
        ADC event to be registered.

    Returns
    -------
    int
        ID of registered ADC event.
    """
    data = np.array(
        [
            event.num_samples,
            event.dwell,
            np.max([event.delay, event.dead_time]),
            event.freq_offset,
            event.phase_offset,
            event.dead_time,
        ]
    )
    adc_id, _ = self.adc_library.find_or_insert(new_data=data)
    return adc_id


def register_control_event(self, event: SimpleNamespace) -> int:
    """
    Parameters
    ----------
    event : SimpleNamespace
        Control event to be registered.

    Returns
    -------
    int
        ID of registered control event.
    """
    event_type = ["output", "trigger"].index(event.type)
    if event_type == 0:
        # Trigger codes supported by the Siemens interpreter as of May 2019
        event_channel = ["osc0", "osc1", "ext1"].index(event.channel)
    elif event_type == 1:
        # Trigger codes supported by the Siemens interpreter as of June 2019
        event_channel = ["physio1", "physio2"].index(event.channel)
    else:
        raise ValueError("Unsupported control event type")

    data = [event_type + 1, event_channel + 1, event.delay, event.duration]
    control_id, _ = self.trigger_library.find_or_insert(new_data=data)
    return control_id
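
# Encoding example (grounded in register_control_event above): a "trigger" event on channel
# "physio1" with delay 100e-6 and duration 1e-3 is stored as
#     data = [2, 1, 0.0001, 0.001]
# i.e. a 1-based type code ("output" -> 1, "trigger" -> 2) and a 1-based channel code, presumably
# so that 0 remains available as an "unset" value in the exported sequence representation.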


def register_grad_event(
    self, event: SimpleNamespace
) -> Union[int, Tuple[int, List[int]]]:
    """
    Parameters
    ----------
    event : SimpleNamespace
        Gradient event to be registered.

    Returns
    -------
    int, [int, ...]
        For arbitrary gradient events: ID of registered gradient event, list of shape IDs.
    int
        For trapezoid gradient events: ID of registered gradient event.
    """
    may_exist = True
    if event.type == "grad":
        amplitude = np.abs(event.waveform).max()
        if amplitude > 0:
            fnz = event.waveform[np.nonzero(event.waveform)[0][0]]
            amplitude *= (
                np.sign(fnz) if fnz != 0 else 1
            )  # Workaround for np.sign(0) = 0

        if hasattr(event, "shape_IDs"):
            shape_IDs = event.shape_IDs
        else:
            shape_IDs = [0, 0]

            if amplitude != 0:
                g = event.waveform / amplitude
            else:
                g = event.waveform
            c_shape = compress_shape(g)
            s_data = np.insert(c_shape.data, 0, c_shape.num_samples)
            shape_IDs[0], found = self.shape_library.find_or_insert(s_data)
            may_exist = may_exist & found

            c_time = compress_shape(event.tt / self.grad_raster_time)
            if not (
                len(c_time.data) == 4
                and np.all(c_time.data == [0.5, 1, 1, c_time.num_samples - 3])
            ):
                t_data = np.insert(c_time.data, 0, c_time.num_samples)
                shape_IDs[1], found = self.shape_library.find_or_insert(t_data)
                may_exist = may_exist & found

        data = [amplitude, *shape_IDs, event.delay, event.first, event.last]
    elif event.type == "trap":
        data = np.array(
            [
                event.amplitude,
                event.rise_time,
                event.flat_time,
                event.fall_time,
                event.delay,
            ]
        )
    else:
        raise ValueError("Unknown gradient type passed to register_grad_event()")

    if may_exist:
        grad_id, _ = self.grad_library.find_or_insert(
            new_data=data, data_type=event.type[0]
        )
    else:
        grad_id = self.grad_library.insert(0, data, event.type[0])

    if event.type == "grad":
        return grad_id, shape_IDs
    elif event.type == "trap":
        return grad_id
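
# Encoding example (illustrative): an arbitrary-waveform gradient is stored as
#     data = [amplitude, amp_shape_id, time_shape_id, delay, first, last]
# where the waveform is normalized by its signed peak before compression, so identical shapes
# played at different amplitudes share a single shape-library entry, and time_shape_id stays 0
# when the samples lie on the regular gradient raster. A minimal sketch, assuming `seq` is a
# Sequence and the event fields below are filled in by hand (hypothetical values):
#
#     grad = SimpleNamespace(type="grad", channel="x", delay=0.0,
#                            waveform=np.array([0.0, 10.0, 20.0, 10.0, 0.0]),
#                            tt=(np.arange(5) + 0.5) * seq.grad_raster_time,
#                            first=0.0, last=0.0)
#     grad_id, shape_ids = register_grad_event(seq, grad)   # trapezoids return only grad_id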


def register_label_event(self, event: SimpleNamespace) -> int:
    """
    Parameters
    ----------
    event : SimpleNamespace
        Label event to be registered.

    Returns
    -------
    int
        ID of registered label event.
    """
    label_id = get_supported_labels().index(event.label) + 1
    data = [event.value, label_id]
    if event.type == "labelset":
        label_id, _ = self.label_set_library.find_or_insert(new_data=data)
    elif event.type == "labelinc":
        label_id, _ = self.label_inc_library.find_or_insert(new_data=data)
    else:
        raise ValueError("Unsupported label type passed to register_label_event()")
    return label_id
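
# Encoding example (grounded in register_label_event above): labels are stored as
#     data = [value, label_id]
# where label_id is the 1-based position of the label name in get_supported_labels(). For
# instance, a labelset event with value 2 whose label sits at position 3 in that tuple becomes
# data = [2, 3]; get_block() reverses this with supported_labels[int(data[1] - 1)].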


def register_rf_event(self, event: SimpleNamespace) -> Tuple[int, List[int]]:
    """
    Parameters
    ----------
    event : SimpleNamespace
        RF event to be registered.

    Returns
    -------
    int, [int, ...]
        ID of registered RF event, list of shape IDs.
    """
    mag = np.abs(event.signal)
    amplitude = np.max(mag)
    mag /= amplitude
    # The following line is a workaround for NumPy division returning NaN in mathematical
    # edge cases (e.g. division by 0)
    mag[np.isnan(mag)] = 0
    phase = np.angle(event.signal)
    phase[phase < 0] += 2 * np.pi
    phase /= 2 * np.pi
    may_exist = True

    if hasattr(event, "shape_IDs"):
        shape_IDs = event.shape_IDs
    else:
        shape_IDs = [0, 0, 0]

        mag_shape = compress_shape(mag)
        data = np.insert(mag_shape.data, 0, mag_shape.num_samples)
        shape_IDs[0], found = self.shape_library.find_or_insert(data)
        may_exist = may_exist & found

        phase_shape = compress_shape(phase)
        data = np.insert(phase_shape.data, 0, phase_shape.num_samples)
        shape_IDs[1], found = self.shape_library.find_or_insert(data)
        may_exist = may_exist & found

        time_shape = compress_shape(
            event.t / self.rf_raster_time
        )  # Time shape is stored in units of RF raster
        if len(time_shape.data) == 4 and np.all(
            time_shape.data == [0.5, 1, 1, time_shape.num_samples - 3]
        ):
            shape_IDs[2] = 0
        else:
            data = [time_shape.num_samples, *time_shape.data]
            shape_IDs[2], found = self.shape_library.find_or_insert(data)
            may_exist = may_exist & found

    use = "u"  # Undefined
    if hasattr(event, "use"):
        if event.use in [
            "excitation",
            "refocusing",
            "inversion",
            "saturation",
            "preparation",
        ]:
            use = event.use[0]
        else:
            use = "u"

    data = np.array(
        [amplitude, *shape_IDs, event.delay, event.freq_offset, event.phase_offset]
    )

    if may_exist:
        rf_id, _ = self.rf_library.find_or_insert(new_data=data, data_type=use)
    else:
        rf_id = self.rf_library.insert(key_id=0, new_data=data, data_type=use)

    return rf_id, shape_IDs
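
# Usage sketch (illustrative): an RF event is split into a normalized magnitude shape, a phase
# shape wrapped to [0, 1), and an optional time shape (left as ID 0 when the samples lie on the
# regular RF raster). The peak amplitude is kept in the library entry
#     data = [amplitude, mag_id, phase_id, time_id, delay, freq_offset, phase_offset]
# with the single-letter use code ('e', 'r', 'i', 's', 'p', or 'u') stored as the data type.
# A hypothetical call, assuming `rf` was built by one of the pulse constructors:
#
#     rf_id, shape_ids = register_rf_event(seq, rf)
#     seq.block_events[block_index][1] = rf_id    # this is what set_block() does internally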