compress_shape.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. from types import SimpleNamespace
  2. import numpy as np
  3. def compress_shape(
  4. decompressed_shape: np.ndarray, force_compression: bool = False
  5. ) -> SimpleNamespace:
  6. """
  7. Compress a gradient or pulse shape waveform using a run-length compression scheme on the derivative. This strategy
  8. encodes constant and linear waveforms with very few samples. A structure is returned with the fields:
  9. - num_samples - the number of samples in the uncompressed waveform
  10. - data - containing the compressed waveform
  11. See also `pypulseq.decompress_shape.py`.
  12. Parameters
  13. ----------
  14. decompressed_shape : numpy.ndarray
  15. Decompressed shape.
  16. force_compression: bool, default=False
  17. Boolean flag to indicate if compression is forced.
  18. Returns
  19. -------
  20. compressed_shape : SimpleNamespace
  21. A `SimpleNamespace` object containing the number of samples and the compressed data.
  22. """
  23. if np.any(~np.isfinite(decompressed_shape)):
  24. raise ValueError("compress_shape() received infinite samples.")
  25. if (
  26. not force_compression and len(decompressed_shape) <= 4
  27. ): # Avoid compressing very short shapes
  28. compressed_shape = SimpleNamespace()
  29. compressed_shape.num_samples = len(decompressed_shape)
  30. compressed_shape.data = decompressed_shape
  31. return compressed_shape
  32. # Single precision floating point has ~7.25 decimal places
  33. quant_factor = 1e-7
  34. decompressed_shape_scaled = decompressed_shape / quant_factor
  35. datq = np.round(
  36. np.insert(np.diff(decompressed_shape_scaled), 0, decompressed_shape_scaled[0])
  37. )
  38. qerr = decompressed_shape_scaled - np.cumsum(datq)
  39. qcor = np.insert(np.diff(np.round(qerr)), 0, 0)
  40. datd = datq + qcor
  41. mask_changes = np.insert(np.asarray(np.diff(datd) != 0, dtype=np.int32), 0, 1)
  42. # Elements without repetitions
  43. vals = datd[mask_changes.nonzero()[0]] * quant_factor
  44. # Indices of changes
  45. k = np.append(mask_changes, 1).nonzero()[0]
  46. # Number of repetitions
  47. n = np.diff(k)
  48. n_extra = (n - 2).astype(np.float32) # Cast as float for nan assignment to work
  49. vals2 = np.copy(vals)
  50. vals2[n_extra < 0] = np.nan
  51. n_extra[n_extra < 0] = np.nan
  52. v = np.stack((vals, vals2, n_extra))
  53. v = v.T[np.isfinite(v).T] # Use transposes to match Matlab's Fortran indexing order
  54. v[abs(v) < 1e-10] = 0
  55. compressed_shape = SimpleNamespace()
  56. compressed_shape.num_samples = len(decompressed_shape)
  57. # Decide whether compression makes sense, otherwise store the original
  58. if force_compression or compressed_shape.num_samples > len(v):
  59. compressed_shape.data = v
  60. else:
  61. compressed_shape.data = decompressed_shape
  62. return compressed_shape