client_reco_test.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. import argparse
  2. import io
  3. import mimetypes
  4. import os
  5. import time
  6. import zipfile
  7. import requests
  8. def upload(base_url: str, path: str) -> str:
  9. url = f"{base_url}/upload"
  10. filename = os.path.basename(path)
  11. mime = mimetypes.guess_type(filename)[0] or "application/octet-stream"
  12. with open(path, "rb") as f:
  13. files = {"file": (filename, f, mime)}
  14. r = requests.post(url, files=files, timeout=60)
  15. r.raise_for_status()
  16. file_id = r.json()["file_id"]
  17. return file_id
  18. def start_session(
  19. base_url: str,
  20. file_raw_id: str,
  21. file_json_id: str,
  22. sequence_name: str,
  23. digit: str,
  24. phase_shift: bool,
  25. file_order_id: str | None = None,
  26. ) -> str:
  27. url = f"{base_url}/h5/reconstruct"
  28. payload = {
  29. "file_raw_id": file_raw_id,
  30. "file_json_id": file_json_id,
  31. "file_order_id": file_order_id, # можно None
  32. "sequence_name": sequence_name,
  33. "digit": digit,
  34. "phase_shift": phase_shift,
  35. }
  36. r = requests.post(url, json=payload, timeout=30)
  37. r.raise_for_status()
  38. return r.json()["session_id"]
  39. def poll_status(base_url: str, session_id: str, interval_s: float = 0.5) -> dict:
  40. url = f"{base_url}/sessions/{session_id}"
  41. last = None
  42. while True:
  43. r = requests.get(url, timeout=15)
  44. r.raise_for_status()
  45. st = r.json()
  46. if st != last:
  47. p = int(st.get("progress", 0) * 100)
  48. msg = st.get("message", "")
  49. print(f"\r[{st['status']:<7}] {p:3d}% {msg:40s}", end="", flush=True)
  50. last = st
  51. if st["status"] in ("done", "error"):
  52. print()
  53. return st
  54. time.sleep(interval_s)
  55. def download_and_extract(base_url: str, session_id: str, out_dir: str) -> str:
  56. os.makedirs(out_dir, exist_ok=True)
  57. url = f"{base_url}/sessions/{session_id}/archive.zip"
  58. r = requests.get(url, timeout=120)
  59. r.raise_for_status()
  60. buf = io.BytesIO(r.content)
  61. with zipfile.ZipFile(buf) as zf:
  62. zf.extractall(out_dir)
  63. return out_dir
  64. def main():
  65. ap = argparse.ArgumentParser(description="Client for MRI Reconstruction API")
  66. ap.add_argument("--url", default="http://127.0.0.1:8000", help="base URL of the API")
  67. ap.add_argument("--raw", required=True, help="path to raw .mat/.h5")
  68. ap.add_argument("--params", required=True, help="path to params.json")
  69. ap.add_argument("--order", help="path to order.json (for nonlinear/radial)")
  70. ap.add_argument("--seq", required=True,
  71. choices=["linear_decart", "nonlinear_decart(tse)", "linear_epi", "radial_propeller"])
  72. ap.add_argument("--digit", required=True, choices=["2d", "3d"])
  73. ap.add_argument("--phase-shift", action="store_true", help="use RF_spoil from JSON (otherwise 0)")
  74. ap.add_argument("--out", default="results", help="output directory for images")
  75. args = ap.parse_args()
  76. base_url = args.url.rstrip("/")
  77. print("Uploading files...")
  78. raw_id = upload(base_url, args.raw)
  79. json_id = upload(base_url, args.params)
  80. order_id = upload(base_url, args.order) if args.order else None
  81. print(f"RAW id: {raw_id}")
  82. print(f"JSON id: {json_id}")
  83. if order_id:
  84. print(f"ORDER id: {order_id}")
  85. print("Starting session...")
  86. session_id = start_session(
  87. base_url=base_url,
  88. file_raw_id=raw_id,
  89. file_json_id=json_id,
  90. file_order_id=order_id,
  91. sequence_name=args.seq,
  92. digit=args.digit,
  93. phase_shift=bool(args.phase_shift),
  94. )
  95. print(f"Session: {session_id}")
  96. print("Waiting for completion...")
  97. st = poll_status(base_url, session_id)
  98. if st["status"] == "error":
  99. print("\n--- ERROR TRACEBACK ---")
  100. print(st.get("error_traceback", ""))
  101. raise SystemExit(1)
  102. out_dir = os.path.join(args.out, session_id)
  103. download_and_extract(base_url, session_id, out_dir)
  104. r = requests.get(f"{base_url}/sessions/{session_id}/files", timeout=30)
  105. r.raise_for_status()
  106. files = r.json().get("files", [])
  107. print(f"\nSaved to: {os.path.abspath(out_dir)}")
  108. if files:
  109. print("Files:")
  110. for name in files:
  111. print(" ", name)
  112. if __name__ == "__main__":
  113. main()