api.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
"""
seq_interp FastAPI service.

Endpoints:
    GET  /health              - liveness probe
    POST /interpret/          - upload a .seq file, start the full pipeline in
                                the background, return a task_id
    GET  /status/             - status of all tasks known to this process
    GET  /result/{task_id}    - full result: xml_text, post_json, metadata,
                                waveforms
    GET  /artifacts/{task_id} - zip archive of the task's output files
"""
  9. from __future__ import annotations
  10. import asyncio
  11. import io
  12. import json
  13. import os
  14. import shutil
  15. import uuid
  16. import zipfile
  17. from typing import Any
  18. from fastapi import FastAPI, File, Form, HTTPException, UploadFile
  19. from fastapi.responses import Response
  20. from src.config import config
  21. from src.hardware.constraints import HardwareConstraints
  22. from src.interfaces.pulseq_adapter import PulseqLoader
  23. from src.core.synchronizer import Synchronizer
  24. from src.interfaces.xml_generator import XMLGenerator
  25. from src.interfaces.rf_exporter import RFExporter
  26. from src.interfaces.gradient_exporter import GradientExporter
  27. from src.interfaces.picoscope_exporter import PicoScopeExporter
  28. from src.interfaces.post_request_generator import PostRequestGenerator
  29. app = FastAPI(title="seq-interp", version="1.0.0")
  30. UPLOAD_DIR = config.get("upload_dir", "data/input")
  31. OUTPUT_DIR = config.get("output_dir", "data/output")
  32. os.makedirs(UPLOAD_DIR, exist_ok=True)
  33. os.makedirs(OUTPUT_DIR, exist_ok=True)
  34. # in-memory cache: task_id -> {"status": str, "result": dict | None}
  35. _tasks: dict[str, dict] = {}
  36. _MAX_WAVEFORM_POINTS = 8000 # downsample for JSON transport
  37. _STATUS_FILE = "_status.json"
  38. _RESULT_FILE = "_result.json"
  39. # -- disk persistence --------------------------------------------------------
  40. def _task_dir(task_id: str) -> str:
  41. return os.path.join(OUTPUT_DIR, task_id)
  42. def _persist_task(task_id: str, entry: dict) -> None:
  43. """Write task status (and result if present) to disk."""
  44. d = _task_dir(task_id)
  45. os.makedirs(d, exist_ok=True)
  46. with open(os.path.join(d, _STATUS_FILE), "w", encoding="utf-8") as f:
  47. json.dump({"status": entry["status"]}, f)
  48. if entry.get("result") is not None:
  49. with open(os.path.join(d, _RESULT_FILE), "w", encoding="utf-8") as f:
  50. json.dump(entry["result"], f, ensure_ascii=False)
  51. def _load_task_from_disk(task_id: str) -> dict | None:
  52. """Restore task entry from disk (used after container restart)."""
  53. d = _task_dir(task_id)
  54. status_path = os.path.join(d, _STATUS_FILE)
  55. result_path = os.path.join(d, _RESULT_FILE)
  56. if not os.path.exists(status_path):
  57. return None
  58. with open(status_path, encoding="utf-8") as f:
  59. status = json.load(f).get("status", "unknown")
  60. result = None
  61. if os.path.exists(result_path):
  62. with open(result_path, encoding="utf-8") as f:
  63. result = json.load(f)
  64. return {"status": status, "result": result}
  65. # -- helpers ----------------------------------------------------------------
  66. def _downsample(arr, max_pts: int) -> list:
  67. """Convert numpy array to list, downsampling if too long."""
  68. try:
  69. import numpy as np
  70. a = np.asarray(arr, dtype=float).flatten()
  71. if len(a) > max_pts:
  72. idx = np.linspace(0, len(a) - 1, max_pts, dtype=int)
  73. a = a[idx]
  74. return [float(x) for x in a]
  75. except Exception:
  76. return []
  77. def _extract_waveforms(seq_data: dict, sync_data: dict) -> dict:
  78. """Extract waveform arrays from seq_data for GUI display."""
  79. try:
  80. import numpy as _np
  81. except ImportError:
  82. _np = None
  83. waveforms: dict[str, list] = {}
  84. # Gradients
  85. for key in ("gx", "gy", "gz", "t_gx", "t_gy", "t_gz"):
  86. if key in seq_data:
  87. waveforms[key] = _downsample(seq_data[key], _MAX_WAVEFORM_POINTS)
  88. # RF — seq_data stores a complex array under "rf" / "t_rf".
  89. # Split into amplitude and phase for JSON transport (complex is not serialisable).
  90. if "rf" in seq_data and "t_rf" in seq_data and _np is not None:
  91. rf = _np.asarray(seq_data["rf"])
  92. waveforms["rf_amp"] = _downsample(_np.abs(rf), _MAX_WAVEFORM_POINTS)
  93. waveforms["rf_phase"] = _downsample(_np.angle(rf), _MAX_WAVEFORM_POINTS)
  94. waveforms["t_rf"] = _downsample(seq_data["t_rf"], _MAX_WAVEFORM_POINTS)
  95. # Sync gate arrays
  96. for key in ("gate_adc", "gate_rf", "gate_tr_switch", "blocks_duration"):
  97. if key in sync_data:
  98. waveforms[key] = _downsample(sync_data[key], _MAX_WAVEFORM_POINTS)
  99. return waveforms
  100. async def _run_pipeline(file_path: str, task_id: str,
  101. hw_overrides: dict | None = None) -> None:
  102. """Run the full interpretation pipeline and store result in _tasks."""
  103. out_dir = os.path.join(OUTPUT_DIR, task_id)
  104. os.makedirs(out_dir, exist_ok=True)
  105. try:
  106. hw = HardwareConstraints(json_path=config.get("hw_config_path"))
  107. # GUI-supplied hardware overrides (delay values, raster, and the
  108. # TR/RF/START enable flags) are applied on top of the service config,
  109. # so the GUI controls work without a local interpreter copy.
  110. for key, value in (hw_overrides or {}).items():
  111. setattr(hw, key, value)
  112. hw_cfg = config.hw_config
  113. loader = PulseqLoader(hw)
  114. seq_data = await asyncio.to_thread(loader.load, file_path)
  115. params = seq_data.get("params", {})
  116. sync = Synchronizer(hw)
  117. sync_data = await asyncio.to_thread(sync.process, seq_data["sequence"])
  118. xml_gen = XMLGenerator()
  119. xml_path = os.path.join(out_dir, "sync_v2.xml")
  120. adc_values, adc_starts = await asyncio.to_thread(
  121. xml_gen.generate, sync_data, xml_path, hw
  122. )
  123. with open(xml_path, encoding="utf-8") as fh:
  124. xml_text = fh.read()
  125. export_tasks = [
  126. asyncio.to_thread(RFExporter().export, seq_data, params, out_dir)
  127. ]
  128. if all(k in seq_data for k in ("gx", "gy", "gz")):
  129. export_tasks.append(
  130. asyncio.to_thread(GradientExporter().export, seq_data, params, out_dir)
  131. )
  132. iadc = hw_cfg.get("iadc", {})
  133. export_tasks.append(asyncio.to_thread(
  134. PicoScopeExporter().generate,
  135. adc_values, adc_starts, out_dir, hw,
  136. sampling_freq=iadc.get("srate", 8e6),
  137. num_channels=iadc.get("n_channels", 3),
  138. ))
  139. await asyncio.gather(*export_tasks)
  140. post_gen = PostRequestGenerator()
  141. post_payload = post_gen.build(
  142. seq_data=seq_data,
  143. adc_values=adc_values,
  144. sequence_path=file_path,
  145. output_dir=out_dir,
  146. hw_cfg=hw_cfg,
  147. rf_raster_time=params.get("rf_raster_time", 1e-6),
  148. )
  149. post_gen.write(post_payload, out_dir)
  150. blocks = seq_data.get("blocks", [])
  151. total_s = sum(sync_data.get("blocks_duration", []))
  152. adc_blocks = [b for b in blocks if b.get("has_adc")]
  153. # Serialize only the fields needed by the GUI block-row builder
  154. blocks_slim = [
  155. {"has_adc": bool(b.get("has_adc")),
  156. "type": list(b.get("type", []))}
  157. for b in blocks
  158. ]
  159. result: dict[str, Any] = {
  160. "task_id": task_id,
  161. "status": "completed",
  162. "output_dir": out_dir,
  163. "xml_text": xml_text,
  164. "post_json": post_payload,
  165. "metadata": {
  166. "block_count": len(blocks),
  167. "sync_block_count": sync_data.get("number_of_blocks", 0),
  168. "adc_count": len(adc_blocks),
  169. "adc_windows": len(adc_values),
  170. "total_duration_ms": round(total_s * 1e3, 4),
  171. "rf_raster_us": params.get("rf_raster_time", 1e-6) * 1e6,
  172. "grad_raster_us": params.get("grad_raster_time", 1e-5) * 1e6,
  173. },
  174. "blocks": blocks_slim,
  175. "waveforms": _extract_waveforms(seq_data, sync_data),
  176. }
  177. entry = {"status": "completed", "result": result}
  178. _tasks[task_id] = entry
  179. _persist_task(task_id, entry)
  180. except Exception as exc:
  181. entry = {"status": f"failed: {exc}", "result": None}
  182. _tasks[task_id] = entry
  183. _persist_task(task_id, entry)
  184. # -- endpoints ----------------------------------------------------------------
  185. @app.get("/health")
  186. def health():
  187. return {"status": "ok"}
  188. @app.post("/interpret/")
  189. async def interpret_endpoint(
  190. file: UploadFile = File(...),
  191. hw_overrides: str | None = Form(None),
  192. ):
  193. """
  194. Upload a .seq file and run the full interpretation pipeline.
  195. hw_overrides : optional JSON object string of HardwareConstraints attribute
  196. overrides (e.g. delay values and TR/RF/START enable flags),
  197. applied on top of the service-side hw config.
  198. """
  199. overrides: dict = {}
  200. if hw_overrides:
  201. try:
  202. overrides = json.loads(hw_overrides)
  203. if not isinstance(overrides, dict):
  204. raise ValueError("hw_overrides must be a JSON object")
  205. except (ValueError, json.JSONDecodeError) as exc:
  206. raise HTTPException(status_code=422,
  207. detail=f"Invalid hw_overrides: {exc}")
  208. task_id = str(uuid.uuid4())
  209. # Store the file under the task_id so parallel uploads never collide
  210. task_upload_dir = os.path.join(UPLOAD_DIR, task_id)
  211. os.makedirs(task_upload_dir, exist_ok=True)
  212. file_path = os.path.join(task_upload_dir, file.filename)
  213. with open(file_path, "wb") as buf:
  214. shutil.copyfileobj(file.file, buf)
  215. entry: dict = {"status": "processing", "result": None}
  216. _tasks[task_id] = entry
  217. _persist_task(task_id, entry)
  218. asyncio.create_task(_run_pipeline(file_path, task_id, overrides))
  219. return {"status": "accepted", "task_id": task_id,
  220. "message": f"Processing {file.filename}"}
  221. @app.get("/status/")
  222. def status_endpoint():
  223. """Return the status of all submitted tasks."""
  224. return {"tasks": {tid: v["status"] for tid, v in _tasks.items()}}
  225. @app.get("/artifacts/{task_id}")
  226. def artifacts_endpoint(task_id: str):
  227. """
  228. Download all output artifacts for a completed task as a zip
  229. (sync_v2.xml, rf_*.bin, gx/gy/gz.txt, picoscope_params.xml,
  230. post_request.json). Used by the GUI's Export action in pure-client mode.
  231. """
  232. task_dir = _task_dir(task_id)
  233. if not os.path.isdir(task_dir):
  234. raise HTTPException(status_code=404,
  235. detail=f"No artifacts for task '{task_id}'")
  236. buf = io.BytesIO()
  237. with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf:
  238. for name in os.listdir(task_dir):
  239. if name in (_STATUS_FILE, _RESULT_FILE):
  240. continue
  241. full = os.path.join(task_dir, name)
  242. if os.path.isfile(full):
  243. zf.write(full, arcname=name)
  244. buf.seek(0)
  245. return Response(
  246. content=buf.getvalue(),
  247. media_type="application/zip",
  248. headers={"Content-Disposition": f'attachment; filename="{task_id}.zip"'},
  249. )
  250. @app.get("/result/{task_id}")
  251. def result_endpoint(task_id: str):
  252. """Return full interpretation result (xml_text, post_json, metadata, waveforms)."""
  253. entry = _tasks.get(task_id)
  254. if entry is None:
  255. # Container may have restarted — try restoring from disk
  256. entry = _load_task_from_disk(task_id)
  257. if entry is None:
  258. raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found")
  259. _tasks[task_id] = entry # restore to memory cache
  260. if entry["status"] == "processing":
  261. raise HTTPException(status_code=202, detail="Still processing")
  262. if entry["result"] is None:
  263. raise HTTPException(status_code=500, detail=entry["status"])
  264. return entry["result"]