""" seq_interp FastAPI service. Endpoints: GET /health - liveness probe POST /interpret/ - upload .seq file, run full pipeline, return task_id GET /status/ - status of all submitted tasks GET /result/{task_id} - full result: xml_text, post_json, metadata, waveforms """ from __future__ import annotations import asyncio import os import shutil from typing import Any from fastapi import FastAPI, File, HTTPException, UploadFile from src.config import config from src.hardware.constraints import HardwareConstraints from src.interfaces.pulseq_adapter import PulseqLoader from src.core.synchronizer import Synchronizer from src.interfaces.xml_generator import XMLGenerator from src.interfaces.rf_exporter import RFExporter from src.interfaces.gradient_exporter import GradientExporter from src.interfaces.picoscope_exporter import PicoScopeExporter from src.interfaces.post_request_generator import PostRequestGenerator app = FastAPI(title="seq-interp", version="1.0.0") UPLOAD_DIR = config.get("upload_dir", "data/input") OUTPUT_DIR = config.get("output_dir", "data/output") os.makedirs(UPLOAD_DIR, exist_ok=True) os.makedirs(OUTPUT_DIR, exist_ok=True) # task_id -> {"status": str, "result": dict | None} _tasks: dict[str, dict] = {} _MAX_WAVEFORM_POINTS = 8000 # downsample for JSON transport # -- helpers ---------------------------------------------------------------- def _downsample(arr, max_pts: int) -> list: """Convert numpy array to list, downsampling if too long.""" try: import numpy as np a = np.asarray(arr, dtype=float).flatten() if len(a) > max_pts: idx = np.linspace(0, len(a) - 1, max_pts, dtype=int) a = a[idx] return [float(x) for x in a] except Exception: return [] def _extract_waveforms(seq_data: dict, sync_data: dict) -> dict: """Extract waveform arrays from seq_data for GUI display.""" waveforms: dict[str, list] = {} for key in ("gx", "gy", "gz", "t_gx", "t_gy", "t_gz"): if key in seq_data: waveforms[key] = _downsample(seq_data[key], _MAX_WAVEFORM_POINTS) # RF for key in ("rf_amp", "rf_phase", "t_rf"): if key in seq_data: waveforms[key] = _downsample(seq_data[key], _MAX_WAVEFORM_POINTS) # Sync gate arrays for key in ("gate_adc", "gate_rf", "gate_tr_switch", "blocks_duration"): if key in sync_data: waveforms[key] = _downsample(sync_data[key], _MAX_WAVEFORM_POINTS) return waveforms async def _run_pipeline(file_path: str, task_id: str) -> None: """Run the full interpretation pipeline and store result in _tasks.""" out_dir = os.path.join(OUTPUT_DIR, task_id) os.makedirs(out_dir, exist_ok=True) try: hw = HardwareConstraints(json_path=config.get("hw_config_path")) hw_cfg = config.hw_config loader = PulseqLoader(hw) seq_data = await asyncio.to_thread(loader.load, file_path) params = seq_data.get("params", {}) sync = Synchronizer(hw) sync_data = await asyncio.to_thread(sync.process, seq_data["sequence"]) xml_gen = XMLGenerator() xml_path = os.path.join(out_dir, "sync_v2.xml") adc_values, adc_starts = await asyncio.to_thread( xml_gen.generate, sync_data, xml_path, hw ) with open(xml_path, encoding="utf-8") as fh: xml_text = fh.read() export_tasks = [ asyncio.to_thread(RFExporter().export, seq_data, params, out_dir) ] if all(k in seq_data for k in ("gx", "gy", "gz")): export_tasks.append( asyncio.to_thread(GradientExporter().export, seq_data, params, out_dir) ) iadc = hw_cfg.get("iadc", {}) export_tasks.append(asyncio.to_thread( PicoScopeExporter().generate, adc_values, adc_starts, out_dir, hw, sampling_freq=iadc.get("srate", 8e6), num_channels=iadc.get("n_channels", 3), )) await asyncio.gather(*export_tasks) post_gen = PostRequestGenerator() post_payload = post_gen.build( seq_data=seq_data, adc_values=adc_values, sequence_path=file_path, output_dir=out_dir, hw_cfg=hw_cfg, rf_raster_time=params.get("rf_raster_time", 1e-6), ) post_gen.write(post_payload, out_dir) blocks = seq_data.get("blocks", []) total_s = sum(sync_data.get("blocks_duration", [])) adc_blocks = [b for b in blocks if b.get("has_adc")] result: dict[str, Any] = { "task_id": task_id, "status": "completed", "output_dir": out_dir, "xml_text": xml_text, "post_json": post_payload, "metadata": { "block_count": len(blocks), "sync_block_count": sync_data.get("number_of_blocks", 0), "adc_count": len(adc_blocks), "adc_windows": len(adc_values), "total_duration_ms": round(total_s * 1e3, 4), "rf_raster_us": params.get("rf_raster_time", 1e-6) * 1e6, "grad_raster_us": params.get("grad_raster_time", 1e-5) * 1e6, }, "waveforms": _extract_waveforms(seq_data, sync_data), } _tasks[task_id] = {"status": "completed", "result": result} except Exception as exc: _tasks[task_id] = {"status": f"failed: {exc}", "result": None} # -- endpoints ---------------------------------------------------------------- @app.get("/health") def health(): return {"status": "ok"} @app.post("/interpret/") async def interpret_endpoint(file: UploadFile = File(...)): """Upload a .seq file and run the full interpretation pipeline.""" file_path = os.path.join(UPLOAD_DIR, file.filename) with open(file_path, "wb") as buf: shutil.copyfileobj(file.file, buf) task_id = os.path.splitext(file.filename)[0] _tasks[task_id] = {"status": "processing", "result": None} asyncio.create_task(_run_pipeline(file_path, task_id)) return {"status": "accepted", "task_id": task_id, "message": f"Processing {file.filename}"} @app.get("/status/") def status_endpoint(): """Return the status of all submitted tasks.""" return {"tasks": {tid: v["status"] for tid, v in _tasks.items()}} @app.get("/result/{task_id}") def result_endpoint(task_id: str): """Return full interpretation result (xml_text, post_json, metadata, waveforms).""" entry = _tasks.get(task_id) if entry is None: raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found") if entry["status"] == "processing": raise HTTPException(status_code=202, detail="Still processing") if entry["result"] is None: raise HTTPException(status_code=500, detail=entry["status"]) return entry["result"]