api.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  """
  FastAPI application for remote MRI sequence interpretation.

  Endpoints:
      POST /interpret/ — upload a .seq file, start the full interpretation
          pipeline as a background task, and return its task_id immediately.
      GET /status/ — return the current status of every submitted task
          (processing, completed, or failed).
  """
  7. import asyncio
  8. import os
  9. import shutil
  10. from fastapi import FastAPI, File, UploadFile
  11. from src.config import config
  12. from src.hardware.constraints import HardwareConstraints
  13. from src.interfaces.pulseq_adapter import PulseqLoader
  14. from src.core.synchronizer import Synchronizer
  15. from src.interfaces.xml_generator import XMLGenerator
  16. from src.interfaces.rf_exporter import RFExporter
  17. from src.interfaces.gradient_exporter import GradientExporter
  18. from src.interfaces.picoscope_exporter import PicoScopeExporter
  19. from src.interfaces.post_request_generator import PostRequestGenerator
  app = FastAPI(title="LF-MRI Sequence Interpreter API")

  # Working directories. Config values override the defaults "uploads" and
  # "output"; both directories are created at import time if missing.
  UPLOAD_DIR = config.get("upload_dir", "uploads")
  OUTPUT_DIR = config.get("output_dir", "output")
  for _required_dir in (UPLOAD_DIR, OUTPUT_DIR):
      os.makedirs(_required_dir, exist_ok=True)
  del _required_dir

  # Status registry held only in process memory: task_id -> status string.
  _tasks: dict[str, str] = {}
  async def _run_pipeline(seq_path: str, task_id: str) -> None:
      """Run the full interpretation pipeline for one uploaded sequence file.

      Blocking stages are offloaded with ``asyncio.to_thread`` so the event
      loop stays responsive while the pipeline runs:

      1. Load the .seq file with ``PulseqLoader`` and synchronize it.
      2. Generate ``sync_v2.xml`` (the generator also returns ADC values
         and start times).
      3. Run the RF, gradient (only when "gx", "gy" and "gz" are all present
         in the loaded data) and PicoScope exporters concurrently.
      4. Build and write the POST request payload.

      The outcome is recorded in ``_tasks[task_id]`` as either
      ``"Completed → <out_dir>"`` or ``"Failed: <exception message>"``.
      Any ``Exception`` raised by a stage is caught, logged with its
      traceback, and recorded rather than propagated.

      Args:
          seq_path: Path of the uploaded .seq file on disk.
          task_id: Key for the status entry; also the name of the per-task
              subdirectory created under ``OUTPUT_DIR``.
      """
      import logging

      try:
          hw = HardwareConstraints()
          loader = PulseqLoader(hw)
          seq_data = await asyncio.to_thread(loader.load, seq_path)
          synchronizer = Synchronizer(hw)
          sync_data = await asyncio.to_thread(synchronizer.process, seq_data["sequence"])

          out_dir = os.path.join(OUTPUT_DIR, task_id)
          os.makedirs(out_dir, exist_ok=True)
          hw_cfg = config.hw_config

          xml_gen = XMLGenerator()
          adc_values, adc_starts = await asyncio.to_thread(
              xml_gen.generate, sync_data, os.path.join(out_dir, "sync_v2.xml"), hw
          )

          exports = [
              asyncio.to_thread(
                  RFExporter().export, seq_data, seq_data.get("params", {}), out_dir
              )
          ]
          if all(k in seq_data for k in ("gx", "gy", "gz")):
              exports.append(
                  asyncio.to_thread(
                      GradientExporter().export, seq_data, seq_data.get("params", {}), out_dir
                  )
              )
          iadc = hw_cfg.get("iadc", {})
          exports.append(
              asyncio.to_thread(
                  PicoScopeExporter().generate,
                  adc_values, adc_starts, out_dir, hw,
                  sampling_freq=iadc.get("srate", 8e6),
                  num_channels=iadc.get("n_channels", 3),
              )
          )
          # Wait for every exporter thread to finish before recording an
          # outcome. A plain gather() would report the first failure while
          # the remaining exporters were still writing into out_dir.
          results = await asyncio.gather(*exports, return_exceptions=True)
          first_error = next((r for r in results if isinstance(r, BaseException)), None)
          if first_error is not None:
              raise first_error

          post_gen = PostRequestGenerator()
          # build/write may touch the filesystem; keep them off the event loop
          # like the other stages.
          post_payload = await asyncio.to_thread(
              post_gen.build,
              seq_data=seq_data,
              adc_values=adc_values,
              sequence_path=seq_path,
              output_dir=out_dir,
              hw_cfg=hw_cfg,
              rf_raster_time=seq_data.get("params", {}).get("rf_raster_time", 1e-6),
          )
          await asyncio.to_thread(post_gen.write, post_payload, out_dir)
          _tasks[task_id] = f"Completed → {out_dir}"
      except Exception as exc:
          # Top-level boundary of a background task: nothing awaits it, so keep
          # the traceback in the log instead of only the one-line status.
          logging.getLogger(__name__).exception(
              "Pipeline failed for task %s (%s)", task_id, seq_path
          )
          _tasks[task_id] = f"Failed: {exc}"
  # Strong references to in-flight pipeline tasks. The event loop keeps only
  # weak references to tasks, so a fire-and-forget task with no other
  # reference can be garbage-collected before it finishes.
  _background_tasks: set[asyncio.Task] = set()


  @app.post("/interpret/")
  async def interpret_endpoint(file: UploadFile = File(...)):
      """Upload a .seq file and run the full interpretation pipeline.

      The upload is saved under ``UPLOAD_DIR`` and the pipeline is scheduled
      as a background task; the response is returned immediately. Poll
      ``GET /status/`` for the outcome under the returned ``task_id``.

      The task id is the uploaded file name without its extension. A
      re-upload with the same name therefore reuses the same upload path,
      status entry and per-task output directory.

      Raises:
          HTTPException: 400 if the upload carries no usable file name.
      """
      from fastapi import HTTPException

      # The client controls ``file.filename``. Keep only the final path
      # component so names such as "../../x.seq" or "/etc/x.seq" cannot
      # escape UPLOAD_DIR (or OUTPUT_DIR, via the derived task_id).
      filename = os.path.basename((file.filename or "").replace("\\", "/"))
      if filename in ("", ".", ".."):
          raise HTTPException(status_code=400, detail="Uploaded file must have a valid file name")
      file_path = os.path.join(UPLOAD_DIR, filename)

      def _save_upload() -> None:
          """Copy the uploaded stream to ``file_path``."""
          with open(file_path, "wb") as buf:
              shutil.copyfileobj(file.file, buf)

      # Copying may block on disk I/O; keep it off the event loop.
      await asyncio.to_thread(_save_upload)

      task_id = os.path.splitext(filename)[0]
      _tasks[task_id] = "Processing"
      task = asyncio.create_task(_run_pipeline(file_path, task_id))
      _background_tasks.add(task)
      task.add_done_callback(_background_tasks.discard)
      return {"status": "accepted", "task_id": task_id,
              "message": f"Processing {filename}"}
  @app.get("/status/")
  async def status_endpoint():
      """Report every task submitted so far together with its current status.

      The values are the status strings written by the interpretation
      endpoint and pipeline ("Processing", "Completed → ...", "Failed: ...").
      """
      return dict(tasks=_tasks)