# scale_grad.py
  1. from copy import copy
  2. from types import SimpleNamespace
  3. def scale_grad(grad: SimpleNamespace, scale: float) -> SimpleNamespace:
  4. """
  5. Scales the gradient with the scalar.
  6. Parameters
  7. ----------
  8. grad : SimpleNamespace
  9. Gradient event to be scaled.
  10. scale : float
  11. Scaling factor.
  12. Returns
  13. -------
  14. grad : SimpleNamespace
  15. Scaled gradient.
  16. """
  17. # copy() to emulate pass-by-value; otherwise passed grad event is modified
  18. scaled_grad = copy(grad)
  19. if scaled_grad.type == "trap":
  20. scaled_grad.amplitude = scaled_grad.amplitude * scale
  21. scaled_grad.area = scaled_grad.area * scale
  22. scaled_grad.flat_area = scaled_grad.flat_area * scale
  23. else:
  24. scaled_grad.waveform = scaled_grad.waveform * scale
  25. scaled_grad.first = scaled_grad.first * scale
  26. scaled_grad.last = scaled_grad.last * scale
  27. if hasattr(scaled_grad, "id"):
  28. delattr(scaled_grad, "id")
  29. return scaled_grad