app.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. import time
  2. import os
  3. import io
  4. import zipfile
  5. from pathlib import Path
  6. from typing import List, Tuple
  7. from PIL import Image
  8. import streamlit as st
  9. st.set_page_config(
  10. page_title="MRI physics based augmentation",
  11. page_icon="🧠",
  12. layout="wide"
  13. )
  14. # ---------- Simple router in session_state ----------
  15. if "page" not in st.session_state:
  16. st.session_state.page = "home"
  17. # storage for generated phantom (appears after progress completes)
  18. if "phantom_blob" not in st.session_state:
  19. st.session_state.phantom_blob = None
  20. if "phantom_name" not in st.session_state:
  21. st.session_state.phantom_name = None
  22. def nav_to(p: str):
  23. st.session_state.page = p
  24. st.title("MRI physics based augmentation")
  25. # ---------- Theme-aware logo (top-right, base64-embedded) ----------
  26. import base64
  27. def _b64_img(path: str) -> str | None:
  28. try:
  29. with open(path, "rb") as f:
  30. return base64.b64encode(f.read()).decode("ascii")
  31. except Exception:
  32. return None
  33. def render_theme_logo(light_path: str = "logos/NEW_PHYSTECH_for_light.png",
  34. dark_path: str = "logos/NEW_PHYSTECH_for_dark.png",
  35. size_px: int = 250):
  36. """
  37. Рисует фиксированный логотип справа сверху, меняющийся по теме.
  38. Картинки вшиваются через base64, чтобы избежать проблем с путями.
  39. """
  40. light_b64 = _b64_img(light_path)
  41. dark_b64 = _b64_img(dark_path)
  42. # Если нет файлов, просто ничего не рисуем (и даём подсказку внизу страницы)
  43. if not light_b64 and not dark_b64:
  44. return
  45. light_src = f"data:image/png;base64,{light_b64}" if light_b64 else ""
  46. dark_src = f"data:image/png;base64,{dark_b64}" if dark_b64 else light_src
  47. html = f"""
  48. <style>
  49. .theme-logo-topright {{
  50. position: fixed; top: 10px; right: 16px; z-index: 9999;
  51. }}
  52. .theme-logo-topright img {{
  53. width: {size_px}px; height: {size_px}px; object-fit: contain;
  54. border-radius: 8px;
  55. }}
  56. /* скрываем одну из картинок в зависимости от темы ОС/браузера */
  57. .theme-logo-topright img.light {{ display: inline-block; }}
  58. .theme-logo-topright img.dark {{ display: none; }}
  59. @media (prefers-color-scheme: dark) {{
  60. .theme-logo-topright img.light {{ display: none; }}
  61. .theme-logo-topright img.dark {{ display: inline-block; }}
  62. }}
  63. @media (max-width: 480px) {{
  64. .theme-logo-topright img {{ width: {int(size_px*0.85)}px; height: {int(size_px*0.85)}px; }}
  65. }}
  66. </style>
  67. <div class="theme-logo-topright">
  68. <img src="{light_src}" alt="logo" class="light" />
  69. <img src="{dark_src}" alt="logo" class="dark" />
  70. </div>
  71. """
  72. st.markdown(html, unsafe_allow_html=True)
  73. render_theme_logo()
  74. # ---------- Image helpers ----------
  75. SUPPORTED_EXTS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff")
  76. SUPPORTED_EXTS_PHANTOM = (".dcm", ".nii", ".nii.gz", ".nrrd", ".npy", ".png", ".jpg", ".jpeg")
  77. MAX_VALUE_DATASET = 100000
  78. def center_crop_to_square(img: Image.Image) -> Image.Image:
  79. """Center-crop PIL image to a square based on the smaller side."""
  80. w, h = img.size
  81. s = min(w, h)
  82. left = (w - s) // 2
  83. top = (h - s) // 2
  84. return img.crop((left, top, left + s, top + s))
  85. def load_and_prepare_assets(asset_dir: str = "assets", count: int = 3, size: Tuple[int, int] = (320, 320)) -> List[Tuple[Image.Image, str]]:
  86. """Load up to `count` images from asset_dir, center-crop to square, resize to `size`."""
  87. results = []
  88. if not os.path.isdir(asset_dir):
  89. return results
  90. files = sorted([f for f in os.listdir(asset_dir) if os.path.splitext(f.lower())[1] in SUPPORTED_EXTS])
  91. for fname in files[:count]:
  92. path = os.path.join(asset_dir, fname)
  93. try:
  94. img = Image.open(path).convert("RGB")
  95. img = center_crop_to_square(img)
  96. img = img.resize(size, Image.LANCZOS)
  97. results.append((img, fname))
  98. except Exception:
  99. continue
  100. return results
  101. def run_job_stub(status_placeholder, progress_placeholder, steps=None, delay=0.9):
  102. """Simulate a long-running job with progress and 3-line status stream.
  103. Returns True when finished."""
  104. if steps is None:
  105. steps = [
  106. "Инициализация пайплайна...",
  107. "Обработка входных данных...",
  108. "Генерация синтетических изображений...",
  109. "Постобработка результатов...",
  110. "Готово!",
  111. ]
  112. progress = 0
  113. last3 = []
  114. progress_placeholder.progress(progress, text="Waiting to start...")
  115. status_placeholder.markdown("")
  116. for i, msg in enumerate(steps, 1):
  117. last3.append(msg)
  118. last3 = last3[-3:] # keep only the last 3
  119. # Newest at the top for the "falls down" feel:
  120. lines = []
  121. for idx, line in enumerate(reversed(last3)):
  122. if idx == 0:
  123. lines.append(f"- **{line}**")
  124. elif idx == 1:
  125. lines.append(f"- <span style='opacity:0.7'>{line}</span>")
  126. else:
  127. lines.append(f"- <span style='opacity:0.45'>{line}</span>")
  128. status_placeholder.markdown("<br/>".join(lines), unsafe_allow_html=True)
  129. progress = int(i * 100 / len(steps))
  130. progress_placeholder.progress(progress, text=f"Progress: {progress}%")
  131. time.sleep(delay)
  132. return True
  133. def make_demo_zip() -> bytes:
  134. """Create a demo ZIP to download as 'phantom' result."""
  135. buf = io.BytesIO()
  136. with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_DEFLATED) as zf:
  137. zf.writestr("phantom/README.txt", "Demo phantom result. Replace with real generated files.")
  138. buf.seek(0)
  139. return buf.getvalue()
  140. # ---------- Pages ----------
  141. def page_home():
  142. st.subheader("Информация о контрастах и генерации фантомов")
  143. st.write(
  144. "Здесь можно разместить описание ИП (T1/T2/PD), принципы формирования данных и примеры сканов. "
  145. "Картинки ниже подгружаются из папки `assets/`, **автоматически центр-обрезаются в квадрат** и "
  146. "**приводятся к одинаковому размеру** (320×320)."
  147. )
  148. # Load 3 prepared images from assets
  149. images = load_and_prepare_assets("assets", count=3, size=(320, 320))
  150. if images:
  151. cols = st.columns(len(images))
  152. for (img, name), col in zip(images, cols):
  153. with col:
  154. st.image(img, caption=name, use_container_width=False)
  155. else:
  156. st.info("Положите 1–3 изображения в папку `assets/` (png/jpg/tif), и они появятся здесь одинакового размера.")
  157. st.markdown("---")
  158. c1, c2 = st.columns(2)
  159. with c1:
  160. with st.container(border=True):
  161. st.markdown("#### 🧠 Phantom generation")
  162. st.write("Upload T1/T2/PD images and begin the generation.")
  163. st.button("Move to the phantom", type="primary", use_container_width=True, on_click=nav_to, args=("phantom",))
  164. with c2:
  165. with st.container(border=True):
  166. st.markdown("#### 📦 Dataset generation")
  167. st.write("Generation of a dataset based on pulse sequence parameters.")
  168. st.button("Move to the dataset", type="primary", use_container_width=True, on_click=nav_to, args=("dataset",))
  169. def page_phantom():
  170. st.button("← Homepage", on_click=nav_to, args=("home",))
  171. st.subheader("Generate the phantom")
  172. st.caption("Please upload T1/T2/PD images of a brain scan")
  173. c1, c2, c3 = st.columns(3)
  174. with c1:
  175. t1_file = st.file_uploader("T1", type=SUPPORTED_EXTS_PHANTOM)
  176. with c2:
  177. t2_file = st.file_uploader("T2", type=SUPPORTED_EXTS_PHANTOM)
  178. with c3:
  179. pd_file = st.file_uploader("PD", type=SUPPORTED_EXTS_PHANTOM)
  180. start_btn = st.button("Begin generation", type="primary")
  181. progress_ph = st.empty()
  182. statuses_ph = st.empty()
  183. if start_btn:
  184. if not (t1_file and t2_file and pd_file):
  185. st.error("Нужно загрузить все три файла: T1, T2, PD.")
  186. else:
  187. done = run_job_stub(statuses_ph, progress_ph)
  188. if done:
  189. # Save demo result into session_state for download button
  190. st.session_state.phantom_blob = make_demo_zip()
  191. st.session_state.phantom_name = "phantom_demo.zip"
  192. st.success("Готово! Можно скачать результат.")
  193. # The download button appears only when phantom_blob is present (after job completes)
  194. if st.session_state.phantom_blob:
  195. st.download_button(
  196. "Download phantom",
  197. data=st.session_state.phantom_blob,
  198. file_name=st.session_state.phantom_name or "phantom_demo.zip",
  199. mime="application/zip",
  200. use_container_width=False,
  201. type="primary",
  202. )
  203. def page_dataset():
  204. st.button("← Homepage", on_click=nav_to, args=("home",))
  205. st.subheader("Generate the dataset")
  206. st.caption("Dataset generation. Please choose parameters for the pulse sequence.")
  207. c1, c2, c3 = st.columns(3)
  208. with c1:
  209. t1d_file = st.file_uploader("T1", key="t1d", type=SUPPORTED_EXTS_PHANTOM)
  210. with c2:
  211. t2d_file = st.file_uploader("T2", key="t2d", type=SUPPORTED_EXTS_PHANTOM)
  212. with c3:
  213. pdd_file = st.file_uploader("PD", key="pdd", type=SUPPORTED_EXTS_PHANTOM)
  214. count = st.number_input("Сколько примеров сгенерировать", min_value=1, max_value=MAX_VALUE_DATASET, value=50, step=1)
  215. start_ds = st.button("Начать генерацию датасета", type="primary")
  216. progress2 = st.empty()
  217. status2 = st.empty()
  218. if start_ds:
  219. if not (t1d_file and t2d_file and pdd_file):
  220. st.error("Нужно загрузить все три файла: T1, T2, PD.")
  221. else:
  222. run_job_stub(status2, progress2)
  223. st.success(f"Датасет сформирован (демо). Количество: {int(count)}")
  224. # ---------- Router ----------
  225. if st.session_state.page == "home":
  226. page_home()
  227. elif st.session_state.page == "phantom":
  228. page_phantom()
  229. elif st.session_state.page == "dataset":
  230. page_dataset()
  231. else:
  232. st.session_state.page = "home"
  233. page_home()