calc_grad_spectrum.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. from typing import Tuple, List, Union
  2. import numpy as np
  3. from scipy.signal import spectrogram
  4. from matplotlib import pyplot as plt
  5. def calculate_gradient_spectrum(
  6. obj,
  7. max_frequency: float = 2000,
  8. window_width: float = 0.05,
  9. frequency_oversampling: float = 3,
  10. time_range: Union[List[float], None] = None,
  11. plot: bool = True,
  12. combine_mode: str = 'max',
  13. use_derivative: bool = False,
  14. acoustic_resonances: List[dict] = [],
  15. ) -> Tuple[List[np.ndarray], np.ndarray, np.ndarray, np.ndarray]:
  16. """
  17. Calculates the gradient spectrum of the sequence. Returns a spectrogram
  18. for each gradient channel, as well as a root-sum-squares combined
  19. spectrogram.
  20. Works by splitting the sequence into windows that are 'window_width'
  21. long and calculating the fourier transform of each window. Windows
  22. overlap 50% with the previous and next window. When 'combine_mode' is
  23. not 'none', all windows are combined into one spectrogram.
  24. Parameters
  25. ----------
  26. max_frequency : float, optional
  27. Maximum frequency to include in spectrograms. The default is 2000.
  28. window_width : float, optional
  29. Window width (in seconds). The default is 0.05.
  30. frequency_oversampling : float, optional
  31. Oversampling in the frequency dimension, higher values make
  32. smoother spectrograms. The default is 3.
  33. time_range : List[float], optional
  34. Time range over which to calculate the spectrograms as a list of
  35. two timepoints (in seconds) (e.g. [1, 1.5])
  36. The default is None.
  37. plot : bool, optional
  38. Whether to plot the spectograms. The default is True.
  39. combine_mode : str, optional
  40. How to combine all windows into one spectrogram, options:
  41. 'max', 'mean', 'rss' (root-sum-of-squares), 'none' (no combination)
  42. The default is 'max'.
  43. use_derivative : bool, optional
  44. Whether the use the derivative of the gradient waveforms instead of the
  45. gradient waveforms for the gradient spectrum calculations. The default
  46. is False
  47. acoustic_resonances : List[dict], optional
  48. Acoustic resonances as a list of dictionaries with 'frequency' and
  49. 'bandwidth' elements. Only used when plot==True. The default is [].
  50. Returns
  51. -------
  52. spectrograms : List[np.ndarray]
  53. List of spectrograms per gradient channel.
  54. spectrogram_rss : np.ndarray
  55. Root-sum-of-squares combined spectrogram over all gradient channels.
  56. frequencies : np.ndarray
  57. Frequency axis of the spectrograms.
  58. times : np.ndarray
  59. Time axis of the spectrograms (only relevant when combine_mode == 'none').
  60. """
  61. dt = obj.system.grad_raster_time # time raster
  62. nwin = round(window_width / dt)
  63. nfft = round(frequency_oversampling*nwin)
  64. # Get gradients as piecewise-polynomials
  65. gw_pp = obj.get_gradients(time_range=time_range)
  66. ng = len(gw_pp)
  67. max_t = max(g.x[-1] for g in gw_pp if g is not None)
  68. # Determine sampling points
  69. if time_range == None:
  70. nt = int(np.ceil(max_t/dt))
  71. t = (np.arange(nt) + 0.5)*dt
  72. else:
  73. tmax = min(time_range[1], max_t) - max(time_range[0], 0)
  74. nt = int(np.ceil(tmax/dt))
  75. t = max(time_range[0], 0) + (np.arange(nt) + 0.5)*dt
  76. # Sample gradients
  77. gw = np.zeros((ng,t.shape[0]))
  78. for i in range(ng):
  79. if gw_pp[i] != None:
  80. gw[i] = gw_pp[i](t)
  81. if use_derivative:
  82. gw = np.diff(gw, axis=1)
  83. # Calculate spectrogram for each gradient channel
  84. spectrograms: List[np.ndarray] = []
  85. spectrogram_rss = 0
  86. for i in range(ng):
  87. # Use scipy to calculate the spectrograms
  88. freq, times, sxx = spectrogram(gw[i],
  89. fs=1/dt,
  90. mode='magnitude',
  91. nperseg=nwin,
  92. noverlap=nwin//2,
  93. nfft=nfft,
  94. detrend='constant',
  95. window=('tukey', 1))
  96. mask = freq<max_frequency
  97. # Accumulate spectrum for all gradient channels
  98. spectrogram_rss += sxx[mask]**2
  99. # Combine spectrogram over time axis
  100. if combine_mode == 'max':
  101. s = sxx[mask].max(axis=1)
  102. elif combine_mode == 'mean':
  103. s = sxx[mask].mean(axis=1)
  104. elif combine_mode == 'rss':
  105. s = np.sqrt((sxx[mask]**2).sum(axis=1))
  106. elif combine_mode == 'none':
  107. s = sxx[mask]
  108. else:
  109. raise ValueError(f'Unknown value for combine_mode: {combine_mode}, must be one of [max, mean, rss, none]')
  110. frequencies = freq[mask]
  111. spectrograms.append(s)
  112. # Root-sum-of-squares combined spectrogram for all gradient channels
  113. spectrogram_rss = np.sqrt(spectrogram_rss)
  114. if combine_mode == 'max':
  115. spectrogram_rss = spectrogram_rss.max(axis=1)
  116. elif combine_mode == 'mean':
  117. spectrogram_rss = spectrogram_rss.mean(axis=1)
  118. elif combine_mode == 'rss':
  119. spectrogram_rss = np.sqrt((spectrogram_rss**2).sum(axis=1))
  120. # Plot spectrograms and acoustic resonances if specified
  121. if plot:
  122. if combine_mode != 'none':
  123. plt.figure()
  124. plt.xlabel('Frequency (Hz)')
  125. # According to spectrogram documentation y unit is (Hz/m)^2 / Hz = Hz/m^2, is this meaningful?
  126. for s in spectrograms:
  127. plt.plot(frequencies, s)
  128. plt.plot(frequencies, spectrogram_rss)
  129. plt.legend(['x', 'y', 'z', 'rss'])
  130. for res in acoustic_resonances:
  131. plt.axvline(res['frequency'], color='k', linestyle='-')
  132. plt.axvline(res['frequency'] - res['bandwidth']/2, color='k', linestyle='--')
  133. plt.axvline(res['frequency'] + res['bandwidth']/2, color='k', linestyle='--')
  134. else:
  135. for s, c in zip(spectrograms, ['X', 'Y', 'Z']):
  136. plt.figure()
  137. plt.title(f'Spectrum {c}')
  138. plt.xlabel('Time (s)')
  139. plt.ylabel('Frequency (Hz)')
  140. plt.imshow(abs(s[::-1]), extent=(times[0], times[-1], frequencies[0], frequencies[-1]),
  141. aspect=(times[-1]-times[0])/(frequencies[-1]-frequencies[0]))
  142. for res in acoustic_resonances:
  143. plt.axhline(res['frequency'], color='r', linestyle='-')
  144. plt.axhline(res['frequency'] - res['bandwidth']/2, color='r', linestyle='--')
  145. plt.axhline(res['frequency'] + res['bandwidth']/2, color='r', linestyle='--')
  146. plt.figure()
  147. plt.title('Total spectrum')
  148. plt.xlabel('Time (s)')
  149. plt.ylabel('Frequency (Hz)')
  150. plt.imshow(abs(spectrogram_rss[::-1]), extent=(times[0], times[-1], frequencies[0], frequencies[-1]),
  151. aspect=(times[-1]-times[0])/(frequencies[-1]-frequencies[0]))
  152. for res in acoustic_resonances:
  153. plt.axhline(res['frequency'], color='r', linestyle='-')
  154. plt.axhline(res['frequency'] - res['bandwidth']/2, color='r', linestyle='--')
  155. plt.axhline(res['frequency'] + res['bandwidth']/2, color='r', linestyle='--')
  156. return spectrograms, spectrogram_rss, frequencies, times