""" seq_interp FastAPI service. Endpoints: GET /health - liveness probe POST /interpret/ - upload .seq file, run full pipeline, return task_id GET /status/ - status of all submitted tasks GET /result/{task_id} - full result: xml_text, post_json, metadata, waveforms """ from __future__ import annotations import asyncio import json import os import shutil import uuid from typing import Any from fastapi import FastAPI, File, HTTPException, UploadFile from src.config import config from src.hardware.constraints import HardwareConstraints from src.interfaces.pulseq_adapter import PulseqLoader from src.core.synchronizer import Synchronizer from src.interfaces.xml_generator import XMLGenerator from src.interfaces.rf_exporter import RFExporter from src.interfaces.gradient_exporter import GradientExporter from src.interfaces.picoscope_exporter import PicoScopeExporter from src.interfaces.post_request_generator import PostRequestGenerator app = FastAPI(title="seq-interp", version="1.0.0") UPLOAD_DIR = config.get("upload_dir", "data/input") OUTPUT_DIR = config.get("output_dir", "data/output") os.makedirs(UPLOAD_DIR, exist_ok=True) os.makedirs(OUTPUT_DIR, exist_ok=True) # in-memory cache: task_id -> {"status": str, "result": dict | None} _tasks: dict[str, dict] = {} _MAX_WAVEFORM_POINTS = 8000 # downsample for JSON transport _STATUS_FILE = "_status.json" _RESULT_FILE = "_result.json" # -- disk persistence -------------------------------------------------------- def _task_dir(task_id: str) -> str: return os.path.join(OUTPUT_DIR, task_id) def _persist_task(task_id: str, entry: dict) -> None: """Write task status (and result if present) to disk.""" d = _task_dir(task_id) os.makedirs(d, exist_ok=True) with open(os.path.join(d, _STATUS_FILE), "w", encoding="utf-8") as f: json.dump({"status": entry["status"]}, f) if entry.get("result") is not None: with open(os.path.join(d, _RESULT_FILE), "w", encoding="utf-8") as f: json.dump(entry["result"], f, ensure_ascii=False) def _load_task_from_disk(task_id: str) -> dict | None: """Restore task entry from disk (used after container restart).""" d = _task_dir(task_id) status_path = os.path.join(d, _STATUS_FILE) result_path = os.path.join(d, _RESULT_FILE) if not os.path.exists(status_path): return None with open(status_path, encoding="utf-8") as f: status = json.load(f).get("status", "unknown") result = None if os.path.exists(result_path): with open(result_path, encoding="utf-8") as f: result = json.load(f) return {"status": status, "result": result} # -- helpers ---------------------------------------------------------------- def _downsample(arr, max_pts: int) -> list: """Convert numpy array to list, downsampling if too long.""" try: import numpy as np a = np.asarray(arr, dtype=float).flatten() if len(a) > max_pts: idx = np.linspace(0, len(a) - 1, max_pts, dtype=int) a = a[idx] return [float(x) for x in a] except Exception: return [] def _extract_waveforms(seq_data: dict, sync_data: dict) -> dict: """Extract waveform arrays from seq_data for GUI display.""" try: import numpy as _np except ImportError: _np = None waveforms: dict[str, list] = {} # Gradients for key in ("gx", "gy", "gz", "t_gx", "t_gy", "t_gz"): if key in seq_data: waveforms[key] = _downsample(seq_data[key], _MAX_WAVEFORM_POINTS) # RF — seq_data stores a complex array under "rf" / "t_rf". # Split into amplitude and phase for JSON transport (complex is not serialisable). if "rf" in seq_data and "t_rf" in seq_data and _np is not None: rf = _np.asarray(seq_data["rf"]) waveforms["rf_amp"] = _downsample(_np.abs(rf), _MAX_WAVEFORM_POINTS) waveforms["rf_phase"] = _downsample(_np.angle(rf), _MAX_WAVEFORM_POINTS) waveforms["t_rf"] = _downsample(seq_data["t_rf"], _MAX_WAVEFORM_POINTS) # Sync gate arrays for key in ("gate_adc", "gate_rf", "gate_tr_switch", "blocks_duration"): if key in sync_data: waveforms[key] = _downsample(sync_data[key], _MAX_WAVEFORM_POINTS) return waveforms async def _run_pipeline(file_path: str, task_id: str) -> None: """Run the full interpretation pipeline and store result in _tasks.""" out_dir = os.path.join(OUTPUT_DIR, task_id) os.makedirs(out_dir, exist_ok=True) try: hw = HardwareConstraints(json_path=config.get("hw_config_path")) hw_cfg = config.hw_config loader = PulseqLoader(hw) seq_data = await asyncio.to_thread(loader.load, file_path) params = seq_data.get("params", {}) sync = Synchronizer(hw) sync_data = await asyncio.to_thread(sync.process, seq_data["sequence"]) xml_gen = XMLGenerator() xml_path = os.path.join(out_dir, "sync_v2.xml") adc_values, adc_starts = await asyncio.to_thread( xml_gen.generate, sync_data, xml_path, hw ) with open(xml_path, encoding="utf-8") as fh: xml_text = fh.read() export_tasks = [ asyncio.to_thread(RFExporter().export, seq_data, params, out_dir) ] if all(k in seq_data for k in ("gx", "gy", "gz")): export_tasks.append( asyncio.to_thread(GradientExporter().export, seq_data, params, out_dir) ) iadc = hw_cfg.get("iadc", {}) export_tasks.append(asyncio.to_thread( PicoScopeExporter().generate, adc_values, adc_starts, out_dir, hw, sampling_freq=iadc.get("srate", 8e6), num_channels=iadc.get("n_channels", 3), )) await asyncio.gather(*export_tasks) post_gen = PostRequestGenerator() post_payload = post_gen.build( seq_data=seq_data, adc_values=adc_values, sequence_path=file_path, output_dir=out_dir, hw_cfg=hw_cfg, rf_raster_time=params.get("rf_raster_time", 1e-6), ) post_gen.write(post_payload, out_dir) blocks = seq_data.get("blocks", []) total_s = sum(sync_data.get("blocks_duration", [])) adc_blocks = [b for b in blocks if b.get("has_adc")] result: dict[str, Any] = { "task_id": task_id, "status": "completed", "output_dir": out_dir, "xml_text": xml_text, "post_json": post_payload, "metadata": { "block_count": len(blocks), "sync_block_count": sync_data.get("number_of_blocks", 0), "adc_count": len(adc_blocks), "adc_windows": len(adc_values), "total_duration_ms": round(total_s * 1e3, 4), "rf_raster_us": params.get("rf_raster_time", 1e-6) * 1e6, "grad_raster_us": params.get("grad_raster_time", 1e-5) * 1e6, }, "waveforms": _extract_waveforms(seq_data, sync_data), } entry = {"status": "completed", "result": result} _tasks[task_id] = entry _persist_task(task_id, entry) except Exception as exc: entry = {"status": f"failed: {exc}", "result": None} _tasks[task_id] = entry _persist_task(task_id, entry) # -- endpoints ---------------------------------------------------------------- @app.get("/health") def health(): return {"status": "ok"} @app.post("/interpret/") async def interpret_endpoint(file: UploadFile = File(...)): """Upload a .seq file and run the full interpretation pipeline.""" task_id = str(uuid.uuid4()) # Store the file under the task_id so parallel uploads never collide task_upload_dir = os.path.join(UPLOAD_DIR, task_id) os.makedirs(task_upload_dir, exist_ok=True) file_path = os.path.join(task_upload_dir, file.filename) with open(file_path, "wb") as buf: shutil.copyfileobj(file.file, buf) entry: dict = {"status": "processing", "result": None} _tasks[task_id] = entry _persist_task(task_id, entry) asyncio.create_task(_run_pipeline(file_path, task_id)) return {"status": "accepted", "task_id": task_id, "message": f"Processing {file.filename}"} @app.get("/status/") def status_endpoint(): """Return the status of all submitted tasks.""" return {"tasks": {tid: v["status"] for tid, v in _tasks.items()}} @app.get("/result/{task_id}") def result_endpoint(task_id: str): """Return full interpretation result (xml_text, post_json, metadata, waveforms).""" entry = _tasks.get(task_id) if entry is None: # Container may have restarted — try restoring from disk entry = _load_task_from_disk(task_id) if entry is None: raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found") _tasks[task_id] = entry # restore to memory cache if entry["status"] == "processing": raise HTTPException(status_code=202, detail="Still processing") if entry["result"] is None: raise HTTPException(status_code=500, detail=entry["status"]) return entry["result"]