api.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. """
  2. seq_interp FastAPI service.
  3. Endpoints:
  4. GET /health - liveness probe
  5. POST /interpret/ - upload .seq file, run full pipeline, return task_id
  6. GET /status/ - status of all submitted tasks
  7. GET /result/{task_id} - full result: xml_text, post_json, metadata, waveforms
  8. """
from __future__ import annotations

import asyncio
import logging
import os
import shutil
from typing import Any

from fastapi import FastAPI, File, HTTPException, UploadFile

from src.config import config
from src.hardware.constraints import HardwareConstraints
from src.interfaces.pulseq_adapter import PulseqLoader
from src.core.synchronizer import Synchronizer
from src.interfaces.xml_generator import XMLGenerator
from src.interfaces.rf_exporter import RFExporter
from src.interfaces.gradient_exporter import GradientExporter
from src.interfaces.picoscope_exporter import PicoScopeExporter
from src.interfaces.post_request_generator import PostRequestGenerator
  24. app = FastAPI(title="seq-interp", version="1.0.0")
  25. UPLOAD_DIR = config.get("upload_dir", "data/input")
  26. OUTPUT_DIR = config.get("output_dir", "data/output")
  27. os.makedirs(UPLOAD_DIR, exist_ok=True)
  28. os.makedirs(OUTPUT_DIR, exist_ok=True)
  29. # task_id -> {"status": str, "result": dict | None}
  30. _tasks: dict[str, dict] = {}
  31. _MAX_WAVEFORM_POINTS = 8000 # downsample for JSON transport
  32. # -- helpers ----------------------------------------------------------------
  33. def _downsample(arr, max_pts: int) -> list:
  34. """Convert numpy array to list, downsampling if too long."""
  35. try:
  36. import numpy as np
  37. a = np.asarray(arr, dtype=float).flatten()
  38. if len(a) > max_pts:
  39. idx = np.linspace(0, len(a) - 1, max_pts, dtype=int)
  40. a = a[idx]
  41. return [float(x) for x in a]
  42. except Exception:
  43. return []
  44. def _extract_waveforms(seq_data: dict, sync_data: dict) -> dict:
  45. """Extract waveform arrays from seq_data for GUI display."""
  46. waveforms: dict[str, list] = {}
  47. for key in ("gx", "gy", "gz", "t_gx", "t_gy", "t_gz"):
  48. if key in seq_data:
  49. waveforms[key] = _downsample(seq_data[key], _MAX_WAVEFORM_POINTS)
  50. # RF
  51. for key in ("rf_amp", "rf_phase", "t_rf"):
  52. if key in seq_data:
  53. waveforms[key] = _downsample(seq_data[key], _MAX_WAVEFORM_POINTS)
  54. # Sync gate arrays
  55. for key in ("gate_adc", "gate_rf", "gate_tr_switch", "blocks_duration"):
  56. if key in sync_data:
  57. waveforms[key] = _downsample(sync_data[key], _MAX_WAVEFORM_POINTS)
  58. return waveforms
  59. async def _run_pipeline(file_path: str, task_id: str) -> None:
  60. """Run the full interpretation pipeline and store result in _tasks."""
  61. out_dir = os.path.join(OUTPUT_DIR, task_id)
  62. os.makedirs(out_dir, exist_ok=True)
  63. try:
  64. hw = HardwareConstraints(json_path=config.get("hw_config_path"))
  65. hw_cfg = config.hw_config
  66. loader = PulseqLoader(hw)
  67. seq_data = await asyncio.to_thread(loader.load, file_path)
  68. params = seq_data.get("params", {})
  69. sync = Synchronizer(hw)
  70. sync_data = await asyncio.to_thread(sync.process, seq_data["sequence"])
  71. xml_gen = XMLGenerator()
  72. xml_path = os.path.join(out_dir, "sync_v2.xml")
  73. adc_values, adc_starts = await asyncio.to_thread(
  74. xml_gen.generate, sync_data, xml_path, hw
  75. )
  76. with open(xml_path, encoding="utf-8") as fh:
  77. xml_text = fh.read()
  78. export_tasks = [
  79. asyncio.to_thread(RFExporter().export, seq_data, params, out_dir)
  80. ]
  81. if all(k in seq_data for k in ("gx", "gy", "gz")):
  82. export_tasks.append(
  83. asyncio.to_thread(GradientExporter().export, seq_data, params, out_dir)
  84. )
  85. iadc = hw_cfg.get("iadc", {})
  86. export_tasks.append(asyncio.to_thread(
  87. PicoScopeExporter().generate,
  88. adc_values, adc_starts, out_dir, hw,
  89. sampling_freq=iadc.get("srate", 8e6),
  90. num_channels=iadc.get("n_channels", 3),
  91. ))
  92. await asyncio.gather(*export_tasks)
  93. post_gen = PostRequestGenerator()
  94. post_payload = post_gen.build(
  95. seq_data=seq_data,
  96. adc_values=adc_values,
  97. sequence_path=file_path,
  98. output_dir=out_dir,
  99. hw_cfg=hw_cfg,
  100. rf_raster_time=params.get("rf_raster_time", 1e-6),
  101. )
  102. post_gen.write(post_payload, out_dir)
  103. blocks = seq_data.get("blocks", [])
  104. total_s = sum(sync_data.get("blocks_duration", []))
  105. adc_blocks = [b for b in blocks if b.get("has_adc")]
  106. result: dict[str, Any] = {
  107. "task_id": task_id,
  108. "status": "completed",
  109. "output_dir": out_dir,
  110. "xml_text": xml_text,
  111. "post_json": post_payload,
  112. "metadata": {
  113. "block_count": len(blocks),
  114. "sync_block_count": sync_data.get("number_of_blocks", 0),
  115. "adc_count": len(adc_blocks),
  116. "adc_windows": len(adc_values),
  117. "total_duration_ms": round(total_s * 1e3, 4),
  118. "rf_raster_us": params.get("rf_raster_time", 1e-6) * 1e6,
  119. "grad_raster_us": params.get("grad_raster_time", 1e-5) * 1e6,
  120. },
  121. "waveforms": _extract_waveforms(seq_data, sync_data),
  122. }
  123. _tasks[task_id] = {"status": "completed", "result": result}
  124. except Exception as exc:
  125. _tasks[task_id] = {"status": f"failed: {exc}", "result": None}
  126. # -- endpoints ----------------------------------------------------------------
  127. @app.get("/health")
  128. def health():
  129. return {"status": "ok"}
  130. @app.post("/interpret/")
  131. async def interpret_endpoint(file: UploadFile = File(...)):
  132. """Upload a .seq file and run the full interpretation pipeline."""
  133. file_path = os.path.join(UPLOAD_DIR, file.filename)
  134. with open(file_path, "wb") as buf:
  135. shutil.copyfileobj(file.file, buf)
  136. task_id = os.path.splitext(file.filename)[0]
  137. _tasks[task_id] = {"status": "processing", "result": None}
  138. asyncio.create_task(_run_pipeline(file_path, task_id))
  139. return {"status": "accepted", "task_id": task_id,
  140. "message": f"Processing {file.filename}"}
  141. @app.get("/status/")
  142. def status_endpoint():
  143. """Return the status of all submitted tasks."""
  144. return {"tasks": {tid: v["status"] for tid, v in _tasks.items()}}
  145. @app.get("/result/{task_id}")
  146. def result_endpoint(task_id: str):
  147. """Return full interpretation result (xml_text, post_json, metadata, waveforms)."""
  148. entry = _tasks.get(task_id)
  149. if entry is None:
  150. raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found")
  151. if entry["status"] == "processing":
  152. raise HTTPException(status_code=202, detail="Still processing")
  153. if entry["result"] is None:
  154. raise HTTPException(status_code=500, detail=entry["status"])
  155. return entry["result"]