| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188 |
- """
seq_interp FastAPI service.

Endpoints:
    GET  /health            — liveness probe
    POST /interpret/        — upload a .seq file and start the full pipeline in the
                              background; returns the task_id immediately
    GET  /status/           — current status of every submitted task
    GET  /result/{task_id}  — full result (xml_text, post_json, metadata, waveforms);
                              responds 202 while still processing and 500 if the
                              pipeline failed
- """
- from __future__ import annotations
- import asyncio
- import os
- import shutil
- from typing import Any
- from fastapi import FastAPI, File, HTTPException, UploadFile
- from src.config import config
- from src.hardware.constraints import HardwareConstraints
- from src.interfaces.pulseq_adapter import PulseqLoader
- from src.core.synchronizer import Synchronizer
- from src.interfaces.xml_generator import XMLGenerator
- from src.interfaces.rf_exporter import RFExporter
- from src.interfaces.gradient_exporter import GradientExporter
- from src.interfaces.picoscope_exporter import PicoScopeExporter
- from src.interfaces.post_request_generator import PostRequestGenerator
app = FastAPI(title="seq-interp", version="1.0.0")

# Filesystem locations, overridable through the project config.
UPLOAD_DIR = config.get("upload_dir", "data/input")    # uploaded .seq files are saved here
OUTPUT_DIR = config.get("output_dir", "data/output")   # one subdirectory per task_id

# Create both directories at import time so request handlers can assume they exist.
for _required_dir in (UPLOAD_DIR, OUTPUT_DIR):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir

# In-memory task registry: task_id -> {"status": str, "result": dict | None}.
# NOTE(review): process-local; nothing in this module removes entries.
_tasks: dict[str, dict] = {}

# Upper bound on samples per waveform array sent to clients (keeps JSON small).
_MAX_WAVEFORM_POINTS = 8000
- # ── helpers ────────────────────────────────────────────────────────────────
- def _downsample(arr, max_pts: int) -> list:
- """Convert numpy array to list, downsampling if too long."""
- try:
- import numpy as np
- a = np.asarray(arr, dtype=float).flatten()
- if len(a) > max_pts:
- idx = np.linspace(0, len(a) - 1, max_pts, dtype=int)
- a = a[idx]
- return [float(x) for x in a]
- except Exception:
- return []
def _extract_waveforms(seq_data: dict, sync_data: dict) -> dict:
    """Collect the plottable arrays for the GUI, each downsampled for transport.

    Gradient and RF arrays are read from *seq_data*; gate arrays and per-block
    durations from *sync_data*. A key absent from its source dict is simply
    left out of the returned mapping.
    """
    key_groups = (
        (seq_data, ("gx", "gy", "gz", "t_gx", "t_gy", "t_gz")),                   # gradients
        (seq_data, ("rf_amp", "rf_phase", "t_rf")),                               # RF
        (sync_data, ("gate_adc", "gate_rf", "gate_tr_switch", "blocks_duration")),  # sync gates
    )
    collected = {}
    for source, names in key_groups:
        for name in names:
            if name not in source:
                continue
            collected[name] = _downsample(source[name], _MAX_WAVEFORM_POINTS)
    return collected
def _read_text(path: str) -> str:
    """Read and return the whole UTF-8 text file at *path*."""
    with open(path, encoding="utf-8") as fh:
        return fh.read()


async def _run_pipeline(file_path: str, task_id: str) -> None:
    """Run the full interpretation pipeline for one uploaded file.

    Stages:
        1. load the .seq file,
        2. synchronize,
        3. generate ``sync_v2.xml``,
        4. run the RF / gradient / PicoScope exporters concurrently,
        5. build and write the POST payload,
        6. assemble the result served by ``/result/{task_id}``.
    Every stage is given ``OUTPUT_DIR/<task_id>`` as its output directory.

    Blocking work runs via ``asyncio.to_thread`` so the event loop stays
    responsive while a pipeline is in flight.

    Pipeline errors are not raised. The outcome is recorded in ``_tasks[task_id]``
    as ``{"status": "completed", "result": {...}}`` or
    ``{"status": "failed: <message>", "result": None}``; failures also log the
    traceback. Cancellation is recorded as failed and then re-raised.
    """
    out_dir = os.path.join(OUTPUT_DIR, task_id)
    try:
        # Inside the try: a failure here (e.g. permissions) must mark the task
        # failed instead of leaving it stuck in "processing" forever.
        os.makedirs(out_dir, exist_ok=True)

        hw = HardwareConstraints(json_path=config.get("hw_config_path"))
        hw_cfg = config.hw_config

        # 1. Parse the sequence file.
        loader = PulseqLoader(hw)
        seq_data = await asyncio.to_thread(loader.load, file_path)
        params = seq_data.get("params", {})

        # 2. Synchronize the parsed sequence.
        sync = Synchronizer(hw)
        sync_data = await asyncio.to_thread(sync.process, seq_data["sequence"])

        # 3. Sync XML; the generator also returns the ADC values/starts used below.
        xml_gen = XMLGenerator()
        xml_path = os.path.join(out_dir, "sync_v2.xml")
        adc_values, adc_starts = await asyncio.to_thread(
            xml_gen.generate, sync_data, xml_path, hw
        )
        xml_text = await asyncio.to_thread(_read_text, xml_path)

        # 4. Exporters are independent of each other, so run them concurrently.
        export_tasks = [
            asyncio.to_thread(RFExporter().export, seq_data, params, out_dir)
        ]
        # Gradient export only when all three gradient channels are present.
        if all(k in seq_data for k in ("gx", "gy", "gz")):
            export_tasks.append(
                asyncio.to_thread(GradientExporter().export, seq_data, params, out_dir)
            )
        iadc = hw_cfg.get("iadc", {})
        export_tasks.append(asyncio.to_thread(
            PicoScopeExporter().generate,
            adc_values, adc_starts, out_dir, hw,
            sampling_freq=iadc.get("srate", 8e6),
            num_channels=iadc.get("n_channels", 3),
        ))
        await asyncio.gather(*export_tasks)

        # 5. Build and persist the POST request payload (off the event loop).
        post_gen = PostRequestGenerator()
        post_payload = await asyncio.to_thread(
            post_gen.build,
            seq_data=seq_data,
            adc_values=adc_values,
            sequence_path=file_path,
            output_dir=out_dir,
            hw_cfg=hw_cfg,
            rf_raster_time=params.get("rf_raster_time", 1e-6),
        )
        await asyncio.to_thread(post_gen.write, post_payload, out_dir)

        # 6. Summarize for the /result endpoint.
        blocks = seq_data.get("blocks", [])
        total_s = sum(sync_data.get("blocks_duration", []))
        adc_block_count = sum(1 for b in blocks if b.get("has_adc"))
        waveforms = await asyncio.to_thread(_extract_waveforms, seq_data, sync_data)

        result: dict[str, Any] = {
            "task_id": task_id,
            "status": "completed",
            "output_dir": out_dir,
            "xml_text": xml_text,
            "post_json": post_payload,
            "metadata": {
                "block_count": len(blocks),
                "sync_block_count": sync_data.get("number_of_blocks", 0),
                "adc_count": adc_block_count,
                "adc_windows": len(adc_values),
                "total_duration_ms": round(total_s * 1e3, 4),
                "rf_raster_us": params.get("rf_raster_time", 1e-6) * 1e6,
                "grad_raster_us": params.get("grad_raster_time", 1e-5) * 1e6,
            },
            "waveforms": waveforms,
        }
        _tasks[task_id] = {"status": "completed", "result": result}
    except asyncio.CancelledError:
        # CancelledError is not an Exception subclass; record it so the task
        # does not appear to be processing forever, then let it propagate.
        _tasks[task_id] = {"status": "failed: cancelled", "result": None}
        raise
    except Exception as exc:
        import logging

        logging.getLogger(__name__).exception(
            "seq_interp pipeline failed for task %s", task_id
        )
        _tasks[task_id] = {"status": f"failed: {exc}", "result": None}
- # ── endpoints ────────────────────────────────────────────────────────────────
@app.get("/health")
def health():
    """Liveness probe; answers with a fixed ``ok`` status while the app is serving."""
    return dict(status="ok")
# Strong references to in-flight pipeline tasks. The event loop keeps only
# weak references to tasks made by asyncio.create_task, so an otherwise
# unreferenced task can be garbage-collected before it finishes.
_background_tasks: set[asyncio.Task] = set()


@app.post("/interpret/")
async def interpret_endpoint(file: UploadFile = File(...)):
    """Save an uploaded .seq file and start the interpretation pipeline in the background.

    The task id is the uploaded file name without its extension. The response
    is returned immediately; poll ``/status/`` or ``/result/{task_id}``.

    Raises:
        HTTPException(400): the upload carries no usable file name.
        HTTPException(409): a task with the same id is still processing
            (re-uploading would overwrite its input and race on its output dir).
    """
    # Keep only the final path component: the client-supplied name must not be
    # able to escape UPLOAD_DIR (e.g. "../../x.seq" or an absolute path).
    filename = os.path.basename(file.filename or "")
    if filename in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Upload has no valid file name")

    task_id = os.path.splitext(filename)[0]
    existing = _tasks.get(task_id)
    if existing is not None and existing["status"] == "processing":
        raise HTTPException(
            status_code=409, detail=f"Task '{task_id}' is already processing"
        )

    # No await between the 409 check and registering the task, so concurrent
    # requests on this event loop cannot interleave here.
    file_path = os.path.join(UPLOAD_DIR, filename)
    with open(file_path, "wb") as buf:
        shutil.copyfileobj(file.file, buf)

    _tasks[task_id] = {"status": "processing", "result": None}
    task = asyncio.create_task(_run_pipeline(file_path, task_id))
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return {"status": "accepted", "task_id": task_id,
            "message": f"Processing {filename}"}
@app.get("/status/")
def status_endpoint():
    """Report the current status string of every task submitted so far."""
    statuses = {}
    for task_id, entry in _tasks.items():
        statuses[task_id] = entry["status"]
    return {"tasks": statuses}
@app.get("/result/{task_id}")
def result_endpoint(task_id: str):
    """Return the stored result dict of a completed task.

    Responses:
        404 — unknown task_id.
        202 — task still processing (raised as HTTPException, body ``{"detail": ...}``).
        500 — task failed; ``detail`` is its ``"failed: ..."`` status string.
    """
    entry = _tasks.get(task_id)
    if entry is None:
        raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found")

    current_status = entry["status"]
    if current_status == "processing":
        raise HTTPException(status_code=202, detail="Still processing")

    payload = entry["result"]
    if payload is None:
        # Failed tasks carry no result; the status string holds the reason.
        raise HTTPException(status_code=500, detail=current_status)
    return payload
|