api.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. """
  2. seq_interp FastAPI service.
  3. Endpoints:
  4. GET /health - liveness probe
  5. POST /interpret/ - upload .seq file, run full pipeline, return task_id
  6. GET /status/ - status of all submitted tasks
  7. GET /result/{task_id} - full result: xml_text, post_json, metadata, waveforms
  8. """
  9. from __future__ import annotations
  10. import asyncio
  11. import json
  12. import os
  13. import shutil
  14. import uuid
  15. from typing import Any
  16. from fastapi import FastAPI, File, HTTPException, UploadFile
  17. from src.config import config
  18. from src.hardware.constraints import HardwareConstraints
  19. from src.interfaces.pulseq_adapter import PulseqLoader
  20. from src.core.synchronizer import Synchronizer
  21. from src.interfaces.xml_generator import XMLGenerator
  22. from src.interfaces.rf_exporter import RFExporter
  23. from src.interfaces.gradient_exporter import GradientExporter
  24. from src.interfaces.picoscope_exporter import PicoScopeExporter
  25. from src.interfaces.post_request_generator import PostRequestGenerator
  26. app = FastAPI(title="seq-interp", version="1.0.0")
  27. UPLOAD_DIR = config.get("upload_dir", "data/input")
  28. OUTPUT_DIR = config.get("output_dir", "data/output")
  29. os.makedirs(UPLOAD_DIR, exist_ok=True)
  30. os.makedirs(OUTPUT_DIR, exist_ok=True)
  31. # in-memory cache: task_id -> {"status": str, "result": dict | None}
  32. _tasks: dict[str, dict] = {}
  33. _MAX_WAVEFORM_POINTS = 8000 # downsample for JSON transport
  34. _STATUS_FILE = "_status.json"
  35. _RESULT_FILE = "_result.json"
  36. # -- disk persistence --------------------------------------------------------
  37. def _task_dir(task_id: str) -> str:
  38. return os.path.join(OUTPUT_DIR, task_id)
  39. def _persist_task(task_id: str, entry: dict) -> None:
  40. """Write task status (and result if present) to disk."""
  41. d = _task_dir(task_id)
  42. os.makedirs(d, exist_ok=True)
  43. with open(os.path.join(d, _STATUS_FILE), "w", encoding="utf-8") as f:
  44. json.dump({"status": entry["status"]}, f)
  45. if entry.get("result") is not None:
  46. with open(os.path.join(d, _RESULT_FILE), "w", encoding="utf-8") as f:
  47. json.dump(entry["result"], f, ensure_ascii=False)
  48. def _load_task_from_disk(task_id: str) -> dict | None:
  49. """Restore task entry from disk (used after container restart)."""
  50. d = _task_dir(task_id)
  51. status_path = os.path.join(d, _STATUS_FILE)
  52. result_path = os.path.join(d, _RESULT_FILE)
  53. if not os.path.exists(status_path):
  54. return None
  55. with open(status_path, encoding="utf-8") as f:
  56. status = json.load(f).get("status", "unknown")
  57. result = None
  58. if os.path.exists(result_path):
  59. with open(result_path, encoding="utf-8") as f:
  60. result = json.load(f)
  61. return {"status": status, "result": result}
  62. # -- helpers ----------------------------------------------------------------
  63. def _downsample(arr, max_pts: int) -> list:
  64. """Convert numpy array to list, downsampling if too long."""
  65. try:
  66. import numpy as np
  67. a = np.asarray(arr, dtype=float).flatten()
  68. if len(a) > max_pts:
  69. idx = np.linspace(0, len(a) - 1, max_pts, dtype=int)
  70. a = a[idx]
  71. return [float(x) for x in a]
  72. except Exception:
  73. return []
  74. def _extract_waveforms(seq_data: dict, sync_data: dict) -> dict:
  75. """Extract waveform arrays from seq_data for GUI display."""
  76. try:
  77. import numpy as _np
  78. except ImportError:
  79. _np = None
  80. waveforms: dict[str, list] = {}
  81. # Gradients
  82. for key in ("gx", "gy", "gz", "t_gx", "t_gy", "t_gz"):
  83. if key in seq_data:
  84. waveforms[key] = _downsample(seq_data[key], _MAX_WAVEFORM_POINTS)
  85. # RF — seq_data stores a complex array under "rf" / "t_rf".
  86. # Split into amplitude and phase for JSON transport (complex is not serialisable).
  87. if "rf" in seq_data and "t_rf" in seq_data and _np is not None:
  88. rf = _np.asarray(seq_data["rf"])
  89. waveforms["rf_amp"] = _downsample(_np.abs(rf), _MAX_WAVEFORM_POINTS)
  90. waveforms["rf_phase"] = _downsample(_np.angle(rf), _MAX_WAVEFORM_POINTS)
  91. waveforms["t_rf"] = _downsample(seq_data["t_rf"], _MAX_WAVEFORM_POINTS)
  92. # Sync gate arrays
  93. for key in ("gate_adc", "gate_rf", "gate_tr_switch", "blocks_duration"):
  94. if key in sync_data:
  95. waveforms[key] = _downsample(sync_data[key], _MAX_WAVEFORM_POINTS)
  96. return waveforms
  97. async def _run_pipeline(file_path: str, task_id: str) -> None:
  98. """Run the full interpretation pipeline and store result in _tasks."""
  99. out_dir = os.path.join(OUTPUT_DIR, task_id)
  100. os.makedirs(out_dir, exist_ok=True)
  101. try:
  102. hw = HardwareConstraints(json_path=config.get("hw_config_path"))
  103. hw_cfg = config.hw_config
  104. loader = PulseqLoader(hw)
  105. seq_data = await asyncio.to_thread(loader.load, file_path)
  106. params = seq_data.get("params", {})
  107. sync = Synchronizer(hw)
  108. sync_data = await asyncio.to_thread(sync.process, seq_data["sequence"])
  109. xml_gen = XMLGenerator()
  110. xml_path = os.path.join(out_dir, "sync_v2.xml")
  111. adc_values, adc_starts = await asyncio.to_thread(
  112. xml_gen.generate, sync_data, xml_path, hw
  113. )
  114. with open(xml_path, encoding="utf-8") as fh:
  115. xml_text = fh.read()
  116. export_tasks = [
  117. asyncio.to_thread(RFExporter().export, seq_data, params, out_dir)
  118. ]
  119. if all(k in seq_data for k in ("gx", "gy", "gz")):
  120. export_tasks.append(
  121. asyncio.to_thread(GradientExporter().export, seq_data, params, out_dir)
  122. )
  123. iadc = hw_cfg.get("iadc", {})
  124. export_tasks.append(asyncio.to_thread(
  125. PicoScopeExporter().generate,
  126. adc_values, adc_starts, out_dir, hw,
  127. sampling_freq=iadc.get("srate", 8e6),
  128. num_channels=iadc.get("n_channels", 3),
  129. ))
  130. await asyncio.gather(*export_tasks)
  131. post_gen = PostRequestGenerator()
  132. post_payload = post_gen.build(
  133. seq_data=seq_data,
  134. adc_values=adc_values,
  135. sequence_path=file_path,
  136. output_dir=out_dir,
  137. hw_cfg=hw_cfg,
  138. rf_raster_time=params.get("rf_raster_time", 1e-6),
  139. )
  140. post_gen.write(post_payload, out_dir)
  141. blocks = seq_data.get("blocks", [])
  142. total_s = sum(sync_data.get("blocks_duration", []))
  143. adc_blocks = [b for b in blocks if b.get("has_adc")]
  144. result: dict[str, Any] = {
  145. "task_id": task_id,
  146. "status": "completed",
  147. "output_dir": out_dir,
  148. "xml_text": xml_text,
  149. "post_json": post_payload,
  150. "metadata": {
  151. "block_count": len(blocks),
  152. "sync_block_count": sync_data.get("number_of_blocks", 0),
  153. "adc_count": len(adc_blocks),
  154. "adc_windows": len(adc_values),
  155. "total_duration_ms": round(total_s * 1e3, 4),
  156. "rf_raster_us": params.get("rf_raster_time", 1e-6) * 1e6,
  157. "grad_raster_us": params.get("grad_raster_time", 1e-5) * 1e6,
  158. },
  159. "waveforms": _extract_waveforms(seq_data, sync_data),
  160. }
  161. entry = {"status": "completed", "result": result}
  162. _tasks[task_id] = entry
  163. _persist_task(task_id, entry)
  164. except Exception as exc:
  165. entry = {"status": f"failed: {exc}", "result": None}
  166. _tasks[task_id] = entry
  167. _persist_task(task_id, entry)
  168. # -- endpoints ----------------------------------------------------------------
  169. @app.get("/health")
  170. def health():
  171. return {"status": "ok"}
  172. @app.post("/interpret/")
  173. async def interpret_endpoint(file: UploadFile = File(...)):
  174. """Upload a .seq file and run the full interpretation pipeline."""
  175. task_id = str(uuid.uuid4())
  176. # Store the file under the task_id so parallel uploads never collide
  177. task_upload_dir = os.path.join(UPLOAD_DIR, task_id)
  178. os.makedirs(task_upload_dir, exist_ok=True)
  179. file_path = os.path.join(task_upload_dir, file.filename)
  180. with open(file_path, "wb") as buf:
  181. shutil.copyfileobj(file.file, buf)
  182. entry: dict = {"status": "processing", "result": None}
  183. _tasks[task_id] = entry
  184. _persist_task(task_id, entry)
  185. asyncio.create_task(_run_pipeline(file_path, task_id))
  186. return {"status": "accepted", "task_id": task_id,
  187. "message": f"Processing {file.filename}"}
  188. @app.get("/status/")
  189. def status_endpoint():
  190. """Return the status of all submitted tasks."""
  191. return {"tasks": {tid: v["status"] for tid, v in _tasks.items()}}
  192. @app.get("/result/{task_id}")
  193. def result_endpoint(task_id: str):
  194. """Return full interpretation result (xml_text, post_json, metadata, waveforms)."""
  195. entry = _tasks.get(task_id)
  196. if entry is None:
  197. # Container may have restarted — try restoring from disk
  198. entry = _load_task_from_disk(task_id)
  199. if entry is None:
  200. raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found")
  201. _tasks[task_id] = entry # restore to memory cache
  202. if entry["status"] == "processing":
  203. raise HTTPException(status_code=202, detail="Still processing")
  204. if entry["result"] is None:
  205. raise HTTPException(status_code=500, detail=entry["status"])
  206. return entry["result"]