| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135 |
import argparse
import io
import mimetypes
import os
import tempfile
import time
import zipfile

import requests
def upload(base_url: str, path: str) -> str:
    """Upload one local file to the API and return the id the server assigns.

    The file is sent as the multipart form field ``file``. Its content type is
    guessed from the file name, with ``application/octet-stream`` as fallback.

    Args:
        base_url: API root without a trailing slash.
        path: Local path of the file to upload.

    Returns:
        The ``file_id`` value from the server's JSON response.

    Raises:
        OSError: If ``path`` cannot be opened.
        requests.HTTPError: If the upload endpoint returns an error status.
    """
    name = os.path.basename(path)
    guessed, _encoding = mimetypes.guess_type(name)
    content_type = guessed if guessed else "application/octet-stream"
    endpoint = f"{base_url}/upload"
    with open(path, "rb") as stream:
        resp = requests.post(
            endpoint,
            files={"file": (name, stream, content_type)},
            timeout=60,
        )
    resp.raise_for_status()
    return resp.json()["file_id"]
def start_session(
    base_url: str,
    file_raw_id: str,
    file_json_id: str,
    sequence_name: str,
    digit: str,
    phase_shift: bool,
    file_order_id: str | None = None,
) -> str:
    """Ask the API to start a reconstruction job and return its session id.

    Args:
        base_url: API root without a trailing slash.
        file_raw_id: Server id of the uploaded raw data file.
        file_json_id: Server id of the uploaded parameter JSON.
        sequence_name: Sequence type string forwarded to the server as-is.
        digit: Dimensionality string forwarded to the server as-is.
        phase_shift: Flag forwarded to the server as-is.
        file_order_id: Server id of the uploaded order JSON, or ``None``.

    Returns:
        The ``session_id`` value from the server's JSON response.

    Raises:
        requests.HTTPError: If the reconstruct endpoint returns an error status.
    """
    body = dict(
        file_raw_id=file_raw_id,
        file_json_id=file_json_id,
        # May be None; it is still sent, and serialized as JSON null.
        file_order_id=file_order_id,
        sequence_name=sequence_name,
        digit=digit,
        phase_shift=phase_shift,
    )
    resp = requests.post(f"{base_url}/h5/reconstruct", json=body, timeout=30)
    resp.raise_for_status()
    reply = resp.json()
    return reply["session_id"]
def poll_status(
    base_url: str,
    session_id: str,
    interval_s: float = 0.5,
    timeout_s: float | None = None,
) -> dict:
    """Poll a session until it finishes, drawing a one-line progress display.

    Repeatedly GETs ``{base_url}/sessions/{session_id}``. Whenever the returned
    payload changes, it redraws a single console line (using ``\\r``) showing
    the status, the progress percentage and the server message.

    Args:
        base_url: API root without a trailing slash.
        session_id: Session id returned by the reconstruct endpoint.
        interval_s: Delay in seconds between polls.
        timeout_s: Optional upper bound in seconds on the total wait.
            ``None`` (the default) waits indefinitely.

    Returns:
        The final status payload, whose ``status`` is ``"done"`` or ``"error"``.

    Raises:
        requests.HTTPError: If the status endpoint returns an error status.
        TimeoutError: If ``timeout_s`` elapses before the session finishes.
    """
    url = f"{base_url}/sessions/{session_id}"
    deadline = None if timeout_s is None else time.monotonic() + timeout_s
    last = None
    while True:
        r = requests.get(url, timeout=15)
        r.raise_for_status()
        st = r.json()
        if st != last:
            # Treat explicit nulls as "no value". Otherwise `None * 100` or
            # format(None, "40s") would raise TypeError.
            p = int((st.get("progress") or 0) * 100)
            # Truncate as well as pad. "\r" only rewinds the cursor, so a
            # longer earlier message would leave stale characters behind.
            msg = str(st.get("message") or "")[:40]
            print(f"\r[{st['status']:<7}] {p:3d}% {msg:40s}", end="", flush=True)
            last = st
        if st["status"] in ("done", "error"):
            print()
            return st
        # Check the deadline only after the terminal-status check, so a
        # finished session is never reported as a timeout.
        if deadline is not None and time.monotonic() >= deadline:
            print()
            raise TimeoutError(
                f"session {session_id} did not finish within {timeout_s} s"
            )
        time.sleep(interval_s)
def download_and_extract(base_url: str, session_id: str, out_dir: str) -> str:
    """Download a session's result archive and unpack it into ``out_dir``.

    The ZIP is streamed in 1 MiB chunks into an anonymous temporary file rather
    than held in memory as one ``bytes`` object. Large results therefore do not
    inflate the client's memory use.

    Args:
        base_url: API root without a trailing slash.
        session_id: Session whose ``archive.zip`` should be fetched.
        out_dir: Destination directory, created if it does not exist.

    Returns:
        ``out_dir``, unchanged.

    Raises:
        requests.HTTPError: If the archive endpoint returns an error status.
        zipfile.BadZipFile: If the downloaded payload is not a valid ZIP.
    """
    os.makedirs(out_dir, exist_ok=True)
    url = f"{base_url}/sessions/{session_id}/archive.zip"
    with tempfile.TemporaryFile() as tmp:
        # Close the HTTP response as soon as the body is on disk, before the
        # (possibly slow) extraction starts.
        with requests.get(url, stream=True, timeout=120) as r:
            r.raise_for_status()
            for chunk in r.iter_content(chunk_size=1 << 20):
                tmp.write(chunk)
        tmp.seek(0)
        with zipfile.ZipFile(tmp) as zf:
            zf.extractall(out_dir)
    return out_dir
def _run_pipeline(base_url: str, args: argparse.Namespace) -> None:
    """Run the networked part of the CLI: upload, reconstruct, poll, download."""
    print("Uploading files...")
    raw_id = upload(base_url, args.raw)
    json_id = upload(base_url, args.params)
    order_id = upload(base_url, args.order) if args.order else None
    print(f"RAW id: {raw_id}")
    print(f"JSON id: {json_id}")
    if order_id:
        print(f"ORDER id: {order_id}")

    print("Starting session...")
    session_id = start_session(
        base_url=base_url,
        file_raw_id=raw_id,
        file_json_id=json_id,
        file_order_id=order_id,
        sequence_name=args.seq,
        digit=args.digit,
        phase_shift=bool(args.phase_shift),
    )
    print(f"Session: {session_id}")

    print("Waiting for completion...")
    st = poll_status(base_url, session_id)
    if st["status"] == "error":
        print("\n--- ERROR TRACEBACK ---")
        # The key may be absent or null; never print the literal "None".
        print(st.get("error_traceback") or "")
        raise SystemExit(1)

    out_dir = os.path.join(args.out, session_id)
    download_and_extract(base_url, session_id, out_dir)

    r = requests.get(f"{base_url}/sessions/{session_id}/files", timeout=30)
    r.raise_for_status()
    files = r.json().get("files", [])
    print(f"\nSaved to: {os.path.abspath(out_dir)}")
    if files:
        print("Files:")
        for name in files:
            print(" ", name)


def main():
    """Command-line entry point for the MRI Reconstruction API client.

    The run proceeds in order:

    1. Upload the raw data, the parameter JSON and the optional order JSON.
    2. Start a reconstruction session.
    3. Poll the session until it finishes.
    4. Unpack the result archive into ``<out>/<session_id>``.
    5. List the files the server reports for the session.

    Exit statuses:

    * 2 (via argparse) if a given input path is not an existing file.
    * 1 if the session ends with status ``"error"``.
    * 1 if an HTTP or network request fails.
    """
    ap = argparse.ArgumentParser(description="Client for MRI Reconstruction API")
    ap.add_argument("--url", default="http://127.0.0.1:8000", help="base URL of the API")
    ap.add_argument("--raw", required=True, help="path to raw .mat/.h5")
    ap.add_argument("--params", required=True, help="path to params.json")
    ap.add_argument("--order", help="path to order.json (for nonlinear/radial)")
    ap.add_argument("--seq", required=True,
                    choices=["linear_decart", "nonlinear_decart(tse)", "linear_epi", "radial_propeller"])
    ap.add_argument("--digit", required=True, choices=["2d", "3d"])
    ap.add_argument("--phase-shift", action="store_true", help="use RF_spoil from JSON (otherwise 0)")
    ap.add_argument("--out", default="results", help="output directory for images")
    args = ap.parse_args()

    # Validate every local input before the first upload. Otherwise a typo
    # in --params surfaces only after the (possibly large) raw file was sent.
    inputs = (("--raw", args.raw), ("--params", args.params), ("--order", args.order))
    for flag, path in inputs:
        if path and not os.path.isfile(path):
            ap.error(f"{flag}: file not found: {path}")

    base_url = args.url.rstrip("/")
    try:
        _run_pipeline(base_url, args)
    except requests.RequestException as exc:
        # Top-level boundary: report HTTP/network failures readably, and show
        # the server's response body (if any), which often explains the error.
        detail = ""
        if exc.response is not None and exc.response.text:
            detail = "\n" + exc.response.text[:2000]
        ap.exit(1, f"\nRequest failed: {exc}{detail}\n")


if __name__ == "__main__":
    main()
|