recon_jemris.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
# Converts JEMRIS simulation outputs (signals.h5 files) into data, or saves them as .npy or .mat files
  2. # Gehua Tong
  3. # March 06, 2020
  4. import h5py
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. def recon_jemris(file, dims):
  8. """Reads JEMRIS's signals.h5 output, reconstructs it (Cartesian only for now) using the dimensions specified,
  9. and returns both the complex k-space and image matrix AND magnitude images
  10. Inputs
  11. ------
  12. file : str
  13. Path to signals.h5
  14. dims : array_like
  15. Dimensions for reconstruction
  16. [Nro], [Nro, Nline], or [Nro, Nline, Nslice]
  17. Returns
  18. -------
  19. kspace : np.ndarray
  20. Complex k-space
  21. imspace : np.ndarray
  22. Complex image space
  23. images : np.ndarray
  24. Real, channel-combined images
  25. """
  26. Mxy_out, M_vec_out, times_out = read_jemris_output(file)
  27. kspace, imspace = recon_jemris_output(Mxy_out, dims)
  28. images = save_recon_images(imspace)# TODO save as png (use previous code!)
  29. return kspace, imspace, images
  30. def read_jemris_output(file):
  31. """Reads and parses JEMRIS's signals.h5 output
  32. Inputs
  33. ------
  34. file : str
  35. Path to signals.h5
  36. Returns
  37. -------
  38. Mxy_out : np.ndarray
  39. Complex representation of transverse magnetization sampled during readout
  40. Matrix dimensions : (total # readouts) x (# channels)
  41. M_vec_out : np.ndarray
  42. 3D representation of magnetization vector (Mx, My, Mz) sampled during readout
  43. Matrix dimensions : (total # readouts) x 3 x (# channels)
  44. times_out : np.ndarray
  45. Timing vector for all readout points
  46. """
  47. # 1. Read simulated data
  48. f = h5py.File(file,'r')
  49. signal = f['signal']
  50. channels = signal['channels']
  51. # 2. Initialize output array
  52. Nch = len(channels.keys())
  53. Nro_tot = channels[list(channels.keys())[0]].shape[0]
  54. M_vec_out = np.zeros((Nro_tot,3,Nch))
  55. Mxy_out = np.zeros((Nro_tot,Nch), dtype=complex)
  56. times_out = np.array(signal['times'])
  57. # 3. Read each channel and store in array
  58. for ch, key in enumerate(list(channels.keys())):
  59. one_ch_data = np.array(channels[key])
  60. M_vec_out[:,:,ch] = one_ch_data
  61. Mxy_out[:,ch] = one_ch_data[:,0] + 1j*one_ch_data[:,1]
  62. return Mxy_out, M_vec_out, times_out
  63. def recon_jemris_output(Mxy_out, dims):
  64. """Cartesian reconstruction of JEMRIS simulation output
  65. # (No EPI/interleave reordering)
  66. Inputs
  67. ------
  68. Mxy_out : np.ndarray
  69. Complex Nt x Nch array where Nt is the total number of data points and Nch is the number of channels
  70. dims : array_like
  71. [Nro], [Nro, Nline], or [Nro, Nline, Nslice]
  72. Returns
  73. -------
  74. kspace : np.ndarray
  75. Complex k-space matrix
  76. imspace : np.ndarray
  77. Complex image space matrix
  78. """
  79. Nt, Nch = Mxy_out.shape
  80. print(Nt)
  81. if Nt != np.prod(dims):
  82. raise ValueError("The dimensions provided do not match the total number of samples.")
  83. Nro = dims[0]
  84. Nline = 1
  85. Nslice = 1
  86. ld = len(dims)
  87. if ld >= 1:
  88. Nro = dims[0]
  89. if ld >= 2:
  90. Nline = dims[1]
  91. if ld == 3:
  92. Nslice = dims[2]
  93. if ld > 3:
  94. raise ValueError("dims should have at 1-3 numbers : Nro, (Nline), and (Nslice)")
  95. kspace = np.zeros((Nro, Nline, Nslice, Nch),dtype=complex)
  96. imspace = np.zeros((Nro, Nline, Nslice, Nch),dtype=complex)
  97. np.reshape(Mxy_out, (Nro, Nline, Nslice))
  98. for ch in range(Nch):
  99. kspace[:,:,:,ch] = np.reshape(Mxy_out[:, ch], (Nro, Nline, Nslice), order='F')
  100. for sl in range(Nslice):
  101. imspace[:,:,sl,ch] = np.fft.fftshift(np.fft.ifft2(kspace[:,:,sl,ch]))
  102. return kspace, imspace
  103. def save_recon_images(imspace, method='sum_squares'):
  104. """For now, this method combines channels and returns the image matrix
  105. (Future, for GUI use: add options to save as separate image files / mat / etc. in a directory)
  106. Inputs
  107. ------
  108. imspace : np.ndarray
  109. Complex image space. The last dimension must be # Channels.
  110. method : str, optional
  111. Method used for combining channels
  112. Either 'sum_squares' (default, sum of squares) or 'sum_abs' (sum of absolute values)
  113. Returns
  114. -------
  115. images : np.ndarray
  116. Real, channel_combined image matrix
  117. """
  118. if method == 'sum_squares':
  119. images = np.sum(np.square(np.absolute(imspace)),axis=-1)
  120. elif method == 'sum_abs':
  121. images = np.sum(np.absolute(imspace), axis=-1)
  122. else:
  123. raise ValueError("Method not recognized. Must be either sum_squares or sum_abs")
  124. return images
  125. if __name__ == '__main__':
  126. Mxy_out, M_vec_out, times_out = read_jemris_output('sim/test0405/signals.h5')
  127. kk, im = recon_jemris_output(Mxy_out, dims=[15,15])
  128. images = save_recon_images(im)
  129. plt.figure(1)
  130. plt.subplot(121)
  131. plt.imshow(np.absolute(kk[:,:,0,0]))
  132. plt.title("k-space")
  133. plt.gray()
  134. plt.subplot(122)
  135. print(images)
  136. plt.imshow(np.squeeze(images[:,:,0]))
  137. plt.title("Image space")
  138. plt.gray()
  139. plt.show()