api.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. """
  2. seq_interp FastAPI service.
  3. Endpoints:
  4. GET /health - liveness probe
  5. POST /interpret/ - upload .seq file, run full pipeline, return task_id
  6. GET /status/ - status of all submitted tasks
  7. GET /result/{task_id} - full result: xml_text, post_json, metadata, waveforms
  8. """
  9. from __future__ import annotations
  10. import asyncio
  11. import json
  12. import os
  13. import shutil
  14. import uuid
  15. from typing import Any
  16. from fastapi import FastAPI, File, HTTPException, UploadFile
  17. from src.config import config
  18. from src.hardware.constraints import HardwareConstraints
  19. from src.interfaces.pulseq_adapter import PulseqLoader
  20. from src.core.synchronizer import Synchronizer
  21. from src.interfaces.xml_generator import XMLGenerator
  22. from src.interfaces.rf_exporter import RFExporter
  23. from src.interfaces.gradient_exporter import GradientExporter
  24. from src.interfaces.picoscope_exporter import PicoScopeExporter
  25. from src.interfaces.post_request_generator import PostRequestGenerator
  # ASGI application object; the route decorators below register against it.
  app = FastAPI(title="seq-interp", version="1.0.0")

  # Upper bound on samples per waveform returned over JSON.
  _MAX_WAVEFORM_POINTS = 8000
  # File names used inside each per-task output directory.
  _STATUS_FILE = "_status.json"
  _RESULT_FILE = "_result.json"

  # Upload / output locations come from the project config, with defaults.
  UPLOAD_DIR = config.get("upload_dir", "data/input")
  OUTPUT_DIR = config.get("output_dir", "data/output")
  os.makedirs(UPLOAD_DIR, exist_ok=True)
  os.makedirs(OUTPUT_DIR, exist_ok=True)

  # In-memory task cache: task_id -> {"status": str, "result": dict | None}
  _tasks: dict[str, dict] = {}
  36. # -- disk persistence --------------------------------------------------------
  def _task_dir(task_id: str) -> str:
      """Return the path of the directory for *task_id* under ``OUTPUT_DIR``."""
      task_path = os.path.join(OUTPUT_DIR, task_id)
      return task_path
  def _persist_task(task_id: str, entry: dict) -> None:
      """Write a task's status (and result, if present) to its task directory.

      The result file is written *before* the status file, and each file is
      written to a temporary sibling and then moved into place with
      ``os.replace``. A reader therefore never sees a truncated JSON file,
      and a persisted status is never newer than the result that goes with it.

      Args:
          task_id: Identifier of the task; selects the directory via ``_task_dir``.
          entry: Task entry with a ``"status"`` key and an optional ``"result"``
              key (``None`` or absent means "no result to write").

      Raises:
          OSError: If the directory or files cannot be written.
          TypeError, ValueError: If the payload is not JSON-serialisable.
      """
      target_dir = _task_dir(task_id)
      os.makedirs(target_dir, exist_ok=True)

      def _write_json_atomic(name: str, payload, **dump_kwargs) -> None:
          # Dump to "<name>.tmp" first so a failure never clobbers the old file.
          final_path = os.path.join(target_dir, name)
          tmp_path = final_path + ".tmp"
          try:
              with open(tmp_path, "w", encoding="utf-8") as f:
                  json.dump(payload, f, **dump_kwargs)
              os.replace(tmp_path, final_path)
          except BaseException:
              # Best-effort cleanup of the partial temp file, then re-raise.
              try:
                  os.remove(tmp_path)
              except OSError:
                  pass
              raise

      result = entry.get("result")
      if result is not None:
          _write_json_atomic(_RESULT_FILE, result, ensure_ascii=False)
      _write_json_atomic(_STATUS_FILE, {"status": entry["status"]})
  def _load_task_from_disk(task_id: str) -> dict | None:
      """Rebuild a task entry from the JSON files in its task directory.

      Called when a task is missing from the in-memory cache (for example
      after a container restart).

      Returns:
          ``None`` when no status file exists for *task_id*; otherwise a dict
          ``{"status": ..., "result": ...}``. ``status`` falls back to
          ``"unknown"`` if the status file lacks that key, and ``result`` is
          ``None`` when no result file is present.
      """
      base = _task_dir(task_id)
      status_file = os.path.join(base, _STATUS_FILE)
      if not os.path.exists(status_file):
          return None

      with open(status_file, encoding="utf-8") as fh:
          persisted = json.load(fh)
      restored = {"status": persisted.get("status", "unknown"), "result": None}

      result_file = os.path.join(base, _RESULT_FILE)
      if os.path.exists(result_file):
          with open(result_file, encoding="utf-8") as fh:
              restored["result"] = json.load(fh)
      return restored
  62. # -- helpers ----------------------------------------------------------------
  63. def _downsample(arr, max_pts: int) -> list:
  64. """Convert numpy array to list, downsampling if too long."""
  65. try:
  66. import numpy as np
  67. a = np.asarray(arr, dtype=float).flatten()
  68. if len(a) > max_pts:
  69. idx = np.linspace(0, len(a) - 1, max_pts, dtype=int)
  70. a = a[idx]
  71. return [float(x) for x in a]
  72. except Exception:
  73. return []
  def _extract_waveforms(seq_data: dict, sync_data: dict) -> dict:
      """Collect the waveform arrays that are present, downsampled for transport.

      Keys are copied only when they exist in their source mapping; each value
      is passed through ``_downsample`` with ``_MAX_WAVEFORM_POINTS``.
      """
      # (source mapping, keys to copy), in output order: gradients, RF, sync gates.
      sources = (
          (seq_data, ("gx", "gy", "gz", "t_gx", "t_gy", "t_gz")),
          (seq_data, ("rf_amp", "rf_phase", "t_rf")),
          (sync_data, ("gate_adc", "gate_rf", "gate_tr_switch", "blocks_duration")),
      )
      collected: dict[str, list] = {}
      for source, names in sources:
          for name in names:
              if name not in source:
                  continue
              collected[name] = _downsample(source[name], _MAX_WAVEFORM_POINTS)
      return collected
  async def _run_pipeline(file_path: str, task_id: str) -> None:
      """Run the full interpretation pipeline for one uploaded file.

      Steps, in order: load the sequence with ``PulseqLoader``, process it with
      ``Synchronizer``, generate the sync XML, run the RF / gradient / PicoScope
      exporters concurrently, build and write the POST payload, then assemble
      the result dict (XML text, payload, metadata, downsampled waveforms).

      The outcome is stored in ``_tasks[task_id]`` and persisted with
      ``_persist_task``. Pipeline errors are not propagated: they are logged
      with their traceback and recorded as status ``"failed: <message>"``.

      Args:
          file_path: Path of the uploaded sequence file.
          task_id: Identifier of the task; also names the output directory.
      """
      import logging  # local import: the module's import block is left untouched

      # Same directory that _persist_task / _load_task_from_disk use.
      out_dir = _task_dir(task_id)
      os.makedirs(out_dir, exist_ok=True)
      try:
          hw = HardwareConstraints(json_path=config.get("hw_config_path"))
          hw_cfg = config.hw_config

          # Loading and synchronisation run in worker threads.
          loader = PulseqLoader(hw)
          seq_data = await asyncio.to_thread(loader.load, file_path)
          params = seq_data.get("params", {})
          rf_raster_time = params.get("rf_raster_time", 1e-6)

          sync = Synchronizer(hw)
          sync_data = await asyncio.to_thread(sync.process, seq_data["sequence"])

          xml_gen = XMLGenerator()
          xml_path = os.path.join(out_dir, "sync_v2.xml")
          adc_values, adc_starts = await asyncio.to_thread(
              xml_gen.generate, sync_data, xml_path, hw
          )
          with open(xml_path, encoding="utf-8") as fh:
              xml_text = fh.read()

          # Exporters are awaited together.
          export_tasks = [
              asyncio.to_thread(RFExporter().export, seq_data, params, out_dir)
          ]
          # Gradient export only when all three axes are present.
          if all(k in seq_data for k in ("gx", "gy", "gz")):
              export_tasks.append(
                  asyncio.to_thread(GradientExporter().export, seq_data, params, out_dir)
              )
          iadc = hw_cfg.get("iadc", {})
          export_tasks.append(asyncio.to_thread(
              PicoScopeExporter().generate,
              adc_values, adc_starts, out_dir, hw,
              sampling_freq=iadc.get("srate", 8e6),
              num_channels=iadc.get("n_channels", 3),
          ))
          await asyncio.gather(*export_tasks)

          post_gen = PostRequestGenerator()
          post_payload = post_gen.build(
              seq_data=seq_data,
              adc_values=adc_values,
              sequence_path=file_path,
              output_dir=out_dir,
              hw_cfg=hw_cfg,
              rf_raster_time=rf_raster_time,
          )
          post_gen.write(post_payload, out_dir)

          blocks = seq_data.get("blocks", [])
          total_s = sum(sync_data.get("blocks_duration", []))
          adc_blocks = [b for b in blocks if b.get("has_adc")]
          result: dict[str, Any] = {
              "task_id": task_id,
              "status": "completed",
              "output_dir": out_dir,
              "xml_text": xml_text,
              "post_json": post_payload,
              "metadata": {
                  "block_count": len(blocks),
                  "sync_block_count": sync_data.get("number_of_blocks", 0),
                  "adc_count": len(adc_blocks),
                  "adc_windows": len(adc_values),
                  "total_duration_ms": round(total_s * 1e3, 4),
                  "rf_raster_us": rf_raster_time * 1e6,
                  "grad_raster_us": params.get("grad_raster_time", 1e-5) * 1e6,
              },
              "waveforms": _extract_waveforms(seq_data, sync_data),
          }
          entry = {"status": "completed", "result": result}
          _tasks[task_id] = entry
          _persist_task(task_id, entry)
      except Exception as exc:
          # Top-level boundary of a background task: keep the traceback in the
          # log, since only str(exc) survives in the task status.
          logging.getLogger(__name__).exception(
              "Pipeline failed for task %s", task_id
          )
          entry = {"status": f"failed: {exc}", "result": None}
          _tasks[task_id] = entry
          _persist_task(task_id, entry)
  160. # -- endpoints ----------------------------------------------------------------
  @app.get("/health")
  def health():
      """Liveness probe: always reports ``{"status": "ok"}``."""
      payload = {"status": "ok"}
      return payload
  # Strong references to in-flight pipeline tasks. The event loop keeps only weak
  # references to tasks, so a task nobody else references may be garbage-collected
  # before it finishes.
  _background_tasks: set[asyncio.Task] = set()


  @app.post("/interpret/")
  async def interpret_endpoint(file: UploadFile = File(...)):
      """Accept an uploaded .seq file and start the interpretation pipeline.

      The upload is stored in its own ``UPLOAD_DIR/<task_id>/`` directory so
      parallel uploads never collide. The task is registered as "processing"
      in memory and on disk, and the pipeline is scheduled as a background task.

      Args:
          file: The uploaded file. Only the final path component of its
              client-supplied name is used when storing it.

      Returns:
          ``{"status": "accepted", "task_id": ..., "message": ...}``.
      """
      task_id = str(uuid.uuid4())
      task_upload_dir = os.path.join(UPLOAD_DIR, task_id)
      os.makedirs(task_upload_dir, exist_ok=True)

      # The filename is client-controlled: drop any directory part (either
      # separator style) so it cannot escape task_upload_dir, and fall back to
      # a fixed name when nothing usable is left.
      raw_name = (file.filename or "").replace("\\", "/")
      safe_name = os.path.basename(raw_name)
      if safe_name in ("", ".", ".."):
          safe_name = "upload.seq"
      file_path = os.path.join(task_upload_dir, safe_name)

      with open(file_path, "wb") as buf:
          # Copy in a worker thread so a large upload does not stall the event loop.
          await asyncio.to_thread(shutil.copyfileobj, file.file, buf)

      entry: dict = {"status": "processing", "result": None}
      _tasks[task_id] = entry
      _persist_task(task_id, entry)

      # Hold a reference until the task is done, then let it go.
      pipeline_task = asyncio.create_task(_run_pipeline(file_path, task_id))
      _background_tasks.add(pipeline_task)
      pipeline_task.add_done_callback(_background_tasks.discard)

      return {"status": "accepted", "task_id": task_id,
              "message": f"Processing {file.filename}"}
  @app.get("/status/")
  def status_endpoint():
      """Report the status string of every task held in the in-memory cache."""
      statuses = {}
      for tid, entry in _tasks.items():
          statuses[tid] = entry["status"]
      return {"tasks": statuses}
  @app.get("/result/{task_id}")
  def result_endpoint(task_id: str):
      """Return the full interpretation result for *task_id*.

      The in-memory cache is consulted first. On a miss (e.g. after a container
      restart) the entry is restored from disk and cached again.

      Returns:
          The stored result dict (xml_text, post_json, metadata, waveforms).

      Raises:
          HTTPException: 404 if the task is unknown, 202 while its status is
              "processing", 500 (with the status string as detail) when the
              task has no result.
      """
      entry = _tasks.get(task_id)
      if entry is None:
          not_found = HTTPException(
              status_code=404, detail=f"Task '{task_id}' not found"
          )
          # task_id comes from the URL and is about to name a directory on disk.
          # Task ids are created with uuid4, so anything that does not parse as
          # a UUID (e.g. "..") cannot be one of ours.
          try:
              uuid.UUID(task_id)
          except ValueError:
              raise not_found from None

          # Container may have restarted: try restoring from disk.
          entry = _load_task_from_disk(task_id)
          if entry is None:
              raise not_found
          _tasks[task_id] = entry  # restore to memory cache

      # NOTE(review): a "processing" status restored from disk may belong to a
      # pipeline that died with the previous process; confirm whether such
      # tasks should be reported as failed instead.
      if entry["status"] == "processing":
          raise HTTPException(status_code=202, detail="Still processing")
      if entry["result"] is None:
          raise HTTPException(status_code=500, detail=entry["status"])
      return entry["result"]