| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138 |
- import io, os, json, zipfile, traceback
- from fastapi import APIRouter, BackgroundTasks, HTTPException
- from fastapi.responses import FileResponse, StreamingResponse
- import matplotlib.pyplot as plt
- from ..core import EXECUTOR
- from ..schemas import StartSessionRequest, SessionStatus
- from ..session import Session, SESSIONS
- from ..storage import FILES
- from ..utils import SavefigRedirect, collect_outputs
- from ..recon_app import ReconstructionApp, ReconstructionH5
router = APIRouter()


def _register_session(req: StartSessionRequest) -> Session:
    """Validate the file ids referenced by *req* and register a new Session.

    The new session is stored in SESSIONS under its ``session_id``.

    Raises:
        HTTPException: 400 if ``file_raw_id`` or ``file_json_id`` is not in
            FILES, or if a non-empty ``file_order_id`` is not in FILES.
    """
    if req.file_raw_id not in FILES:
        raise HTTPException(400, "file_raw_id не найден")
    if req.file_json_id not in FILES:
        raise HTTPException(400, "file_json_id не найден")
    if req.file_order_id and req.file_order_id not in FILES:
        raise HTTPException(400, "file_order_id не найден")
    session = Session(
        FILES[req.file_raw_id],
        FILES[req.file_json_id],
        FILES.get(req.file_order_id) if req.file_order_id else None,
        req.sequence_name,
        req.digit,
        req.phase_shift,
    )
    SESSIONS[session.session_id] = session
    return session


def _run_reconstruction(session: Session, app_cls: type) -> None:
    """Run one reconstruction job for *session* using *app_cls*.

    ``app_cls`` is instantiated with ``name``/``digit``/``shift`` taken from
    the session and must provide ``start_reconstruction``. Progress is
    reported through ``session.set``. Any exception is recorded in
    ``session.error_traceback`` and marks the session "error" instead of
    propagating. In every case a ``status.json`` snapshot is written into
    ``session.work_dir``.
    """
    # NOTE(review): SavefigRedirect presumably patches the process-global
    # plt.savefig; the original is restored again in `finally` as a safety net.
    # Patching a module global is not isolated between jobs if EXECUTOR runs
    # several of them concurrently — confirm EXECUTOR's worker count.
    savefig_orig = plt.savefig
    try:
        session.set(status="running", progress=0.1, message="init")
        with SavefigRedirect(session.work_dir):
            app_reco = app_cls(
                name=session.sequence_name,
                digit=session.digit,
                shift=session.phase_shift,
            )
            session.set(progress=0.2, message="read + prepare")
            app_reco.start_reconstruction(
                path_raw_data=session.file_raw,
                path_np_data_json=session.file_json,
                # Without a separate order file, the data JSON doubles as the order source.
                path_order_json=session.file_order if session.file_order else session.file_json,
            )
        session.set(progress=0.95, message="collect results")
        session.result_files = collect_outputs(session.work_dir)
        session.set(status="done", progress=1.0, message="done")
    except Exception:
        # Top-level boundary of a background job: the future returned by
        # EXECUTOR.submit is discarded, so the failure is recorded on the
        # session where status endpoints can report it.
        session.error_traceback = traceback.format_exc()
        session.set(status="error", message="error")
    finally:
        plt.savefig = savefig_orig
        status_path = os.path.join(session.work_dir, "status.json")
        with open(status_path, "w", encoding="utf-8") as f:
            json.dump(session.to_status().model_dump(), f, ensure_ascii=False, indent=2)


def _start_reconstruction_job(
    req: StartSessionRequest, background: BackgroundTasks, app_cls: type
) -> SessionStatus:
    """Register a session for *req* and schedule *app_cls* on EXECUTOR.

    Submission to EXECUTOR happens from a FastAPI background task; the
    returned value is the session's status at registration time.
    """
    session = _register_session(req)
    background.add_task(EXECUTOR.submit, lambda: _run_reconstruction(session, app_cls))
    return session.to_status()


@router.post("/reconstruct", response_model=SessionStatus)
def reconstruct(req: StartSessionRequest, background: BackgroundTasks):
    """Start an asynchronous reconstruction with ReconstructionApp."""
    return _start_reconstruction_job(req, background, ReconstructionApp)


@router.post("/h5/reconstruct", response_model=SessionStatus)
def reconstructFromH5(req: StartSessionRequest, background: BackgroundTasks):
    """Start an asynchronous reconstruction with ReconstructionH5."""
    return _start_reconstruction_job(req, background, ReconstructionH5)
@router.get("/sessions/{session_id}", response_model=SessionStatus)
def get_status(session_id: str):
    """Return the current status snapshot of session *session_id*.

    Raises:
        HTTPException: 404 if no session with that id is registered.
    """
    found = SESSIONS.get(session_id)
    if found:
        return found.to_status()
    raise HTTPException(404, "не найден")
@router.get("/sessions/{session_id}/files")
def list_files(session_id: str):
    """List base names of the session's result files that currently exist on disk.

    Raises:
        HTTPException: 404 if no session with that id is registered.
    """
    current = SESSIONS.get(session_id)
    if not current:
        raise HTTPException(404, "не найден")
    existing = filter(os.path.isfile, current.result_files)
    return {"files": list(map(os.path.basename, existing))}
@router.get("/sessions/{session_id}/files/{name}")
def download_file(session_id: str, name: str):
    """Serve the file *name* from the session's work directory.

    *name* comes from the URL, so it is resolved and must stay inside
    ``session.work_dir``. Routing only blocks ``/``, so on Windows ``\\`` or a
    drive prefix could otherwise escape the directory.

    Raises:
        HTTPException: 404 if the session does not exist, or if *name* does
            not resolve to a regular file inside the work directory.
    """
    session = SESSIONS.get(session_id)
    if not session:
        raise HTTPException(404, "не найден")
    base = os.path.realpath(session.work_dir)
    target = os.path.realpath(os.path.join(base, name))
    try:
        inside = os.path.commonpath([base, target]) == base
    except ValueError:  # e.g. paths on different drives (Windows)
        inside = False
    if not inside or not os.path.isfile(target):
        raise HTTPException(404, "файл не найден")
    return FileResponse(target, filename=name)
@router.get("/sessions/{session_id}/archive.zip")
def download_zip(session_id: str):
    """Return a ZIP of the session's existing result files plus a fresh status.json.

    Entries are stored flat under their base names. The first file with a
    given base name wins. Later duplicates, and any result file named
    "status.json", are skipped so the archive never has duplicate entries.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    session = SESSIONS.get(session_id)
    if not session:
        raise HTTPException(404, "не найден")
    mem = io.BytesIO()
    # "status.json" is reserved: it is written last from the live session state.
    seen = {"status.json"}
    with zipfile.ZipFile(mem, "w", zipfile.ZIP_DEFLATED) as zf:
        for p in session.result_files:
            arcname = os.path.basename(p)
            if arcname in seen or not os.path.isfile(p):
                continue
            seen.add(arcname)
            zf.write(p, arcname=arcname)
        zf.writestr(
            "status.json",
            json.dumps(session.to_status().model_dump(), ensure_ascii=False, indent=2),
        )
    mem.seek(0)
    return StreamingResponse(
        mem,
        media_type="application/zip",
        headers={"Content-Disposition": f'attachment; filename="{session.session_id}.zip"'},
    )
@router.delete("/sessions/{session_id}")
def delete_session(session_id: str):
    """Unregister session *session_id* and remove its work directory.

    Directory removal ignores errors. Nothing here cancels a reconstruction
    job that may still be running for the session.

    Raises:
        HTTPException: 404 if no session with that id is registered.
    """
    import shutil

    removed = SESSIONS.pop(session_id, None)
    if not removed:
        raise HTTPException(404, "не найден")
    shutil.rmtree(removed.work_dir, ignore_errors=True)
    return {"deleted": session_id}
|