# app.py

import base64
import io
import mimetypes
import os
import time
import zipfile
from pathlib import Path
from typing import List, Tuple

import streamlit as st
from PIL import Image
  10. def nav_to(p: str):
  11. st.session_state.page = p
  12. # ---------- Theme-aware logo (top-right, base64-embedded) ----------
  13. def _b64_img(path: str) -> str | None:
  14. try:
  15. with open(path, "rb") as f:
  16. return base64.b64encode(f.read()).decode("ascii")
  17. except Exception:
  18. return None
  19. def header_with_theme_logo(title: str,
  20. light_path: str = "logos/NEW_PHYSTECH_for_light.png",
  21. dark_path: str = "logos/NEW_PHYSTECH_for_dark.png",
  22. size_px: int = 100):
  23. """
  24. Display a header with a theme-aware logo and title.
  25. Made in HTML/CSS to adapt to light/dark mode in Streamlit.
  26. Parameters
  27. ----------
  28. title : str
  29. The title to display in the header.
  30. light_path : str
  31. Path to the light theme logo image.
  32. dark_path : str
  33. Path to the dark theme logo image.
  34. size_px : int
  35. Size of the logo in pixels (height).
  36. Returns
  37. -------
  38. None
  39. """
  40. light_b64 = _b64_img(light_path)
  41. dark_b64 = _b64_img(dark_path)
  42. if not (light_b64 or dark_b64):
  43. st.markdown(f"## {title}")
  44. return
  45. light_src = f"data:image/png;base64,{light_b64}" if light_b64 else ""
  46. dark_src = f"data:image/png;base64,{dark_b64}" if dark_b64 else light_src
  47. html = f"""
  48. <style>
  49. section.main > div:first-child {{
  50. padding-top: 0rem;
  51. }}
  52. .hdr {{
  53. display: flex; align-items: center; justify-content: space-between;
  54. }}
  55. .hdr h1 {{
  56. margin: 0;
  57. font-size: 3.5rem;
  58. line-height: 1.2;
  59. }}
  60. .hdr .logo img {{
  61. width: 250px; height: {size_px}px; object-fit: contain;
  62. border-radius: 8px;
  63. display: inline-block;
  64. }}
  65. .hdr .logo img.light {{ display: inline-block; }}
  66. .hdr .logo img.dark {{ display: none; }}
  67. @media (prefers-color-scheme: dark) {{
  68. .hdr .logo img.light {{ display: none; }}
  69. .hdr .logo img.dark {{ display: inline-block; }}
  70. }}
  71. </style>
  72. <div class="hdr">
  73. <h1>{title}</h1>
  74. <div class="logo">
  75. <img src="{light_src}" alt="logo" class="light" />
  76. <img src="{dark_src}" alt="logo" class="dark" />
  77. </div>
  78. </div>
  79. """
  80. st.markdown(html, unsafe_allow_html=True)
  81. st.set_page_config(
  82. page_title="MRI physics based augmentation",
  83. page_icon="🧠",
  84. layout="wide"
  85. )
  86. header_with_theme_logo("MRI physics based augmentation")
  87. # ---------- Simple router in session_state ----------
  88. if "page" not in st.session_state:
  89. st.session_state.page = "home"
  90. # storage for generated phantom (appears after progress completes)
  91. if "phantom_blob" not in st.session_state:
  92. st.session_state.phantom_blob = None
  93. if "phantom_name" not in st.session_state:
  94. st.session_state.phantom_name = None
  95. # ---------- Image helpers ----------
  96. SUPPORTED_EXTS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff")
  97. SUPPORTED_EXTS_PHANTOM = (".dcm", ".nii", ".nii.gz", ".nrrd", ".npy", ".png", ".jpg", ".jpeg")
  98. MAX_VALUE_DATASET = 100000
  99. def center_crop_to_square(img: Image.Image) -> Image.Image:
  100. """Center-crop PIL image to a square based on the smaller side."""
  101. w, h = img.size
  102. s = min(w, h)
  103. left = (w - s) // 2
  104. top = (h - s) // 2
  105. return img.crop((left, top, left + s, top + s))
  106. def load_and_prepare_assets(asset_dir: str = "assets", count: int = 3, size: Tuple[int, int] = (320, 320)) -> List[Tuple[Image.Image, str]]:
  107. """Load up to `count` images from asset_dir, center-crop to square, resize to `size`."""
  108. results = []
  109. if not os.path.isdir(asset_dir):
  110. return results
  111. files = sorted([f for f in os.listdir(asset_dir) if os.path.splitext(f.lower())[1] in SUPPORTED_EXTS])
  112. for fname in files[:count]:
  113. path = os.path.join(asset_dir, fname)
  114. try:
  115. img = Image.open(path).convert("RGB")
  116. img = center_crop_to_square(img)
  117. img = img.resize(size, Image.LANCZOS)
  118. results.append((img, fname))
  119. except Exception:
  120. continue
  121. return results
  122. def run_job_stub(status_placeholder, progress_placeholder, steps=None, delay=0.9):
  123. """Simulate a long-running job with progress and 3-line status stream.
  124. Returns True when finished."""
  125. if steps is None:
  126. steps = [
  127. "Инициализация пайплайна...",
  128. "Обработка входных данных...",
  129. "Генерация синтетических изображений...",
  130. "Постобработка результатов...",
  131. "Готово!",
  132. ]
  133. progress = 0
  134. last3 = []
  135. progress_placeholder.progress(progress, text="Waiting to start...")
  136. status_placeholder.markdown("")
  137. for i, msg in enumerate(steps, 1):
  138. last3.append(msg)
  139. last3 = last3[-3:] # keep only the last 3
  140. # Newest at the top for the "falls down" feel:
  141. lines = []
  142. for idx, line in enumerate(reversed(last3)):
  143. if idx == 0:
  144. lines.append(f"- **{line}**")
  145. elif idx == 1:
  146. lines.append(f"- <span style='opacity:0.7'>{line}</span>")
  147. else:
  148. lines.append(f"- <span style='opacity:0.45'>{line}</span>")
  149. status_placeholder.markdown("<br/>".join(lines), unsafe_allow_html=True)
  150. progress = int(i * 100 / len(steps))
  151. progress_placeholder.progress(progress, text=f"Progress: {progress}%")
  152. time.sleep(delay)
  153. return True
  154. def make_demo_zip() -> bytes:
  155. """Create a demo ZIP to download as 'phantom' result."""
  156. buf = io.BytesIO()
  157. with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_DEFLATED) as zf:
  158. zf.writestr("phantom/README.txt", "Demo phantom result. Replace with real generated files.")
  159. buf.seek(0)
  160. return buf.getvalue()
  161. # ---------- Pages ----------
  162. def page_home():
  163. st.subheader("Информация о контрастах и генерации фантомов")
  164. st.write(
  165. "Здесь можно разместить описание ИП (T1/T2/PD), принципы формирования данных и примеры сканов. "
  166. "Картинки ниже подгружаются из папки `assets/`, **автоматически центр-обрезаются в квадрат** и "
  167. "**приводятся к одинаковому размеру** (320×320)."
  168. )
  169. # Load 3 prepared images from assets
  170. images = load_and_prepare_assets("assets", count=3, size=(320, 320))
  171. if images:
  172. cols = st.columns(len(images))
  173. for (img, name), col in zip(images, cols):
  174. with col:
  175. st.image(img, use_container_width=False)
  176. # st.image(img, caption=name, use_container_width=False)
  177. else:
  178. st.info("Положите 1–3 изображения в папку `assets/` (png/jpg/tif), и они появятся здесь одинакового размера.")
  179. st.markdown("---")
  180. c1, c2 = st.columns(2)
  181. with c1:
  182. with st.container(border=True):
  183. st.markdown("#### 🧠 Phantom generation")
  184. st.write("Upload T1/T2/PD images and begin the generation.")
  185. st.button("Move to the phantom", type="primary", use_container_width=True, on_click=nav_to, args=("phantom",))
  186. with c2:
  187. with st.container(border=True):
  188. st.markdown("#### 📦 Dataset generation")
  189. st.write("Generation of a dataset based on pulse sequence parameters.")
  190. st.button("Move to the dataset", type="primary", use_container_width=True, on_click=nav_to, args=("dataset",))
  191. def page_phantom():
  192. st.button("← Homepage", on_click=nav_to, args=("home",))
  193. st.subheader("Generate the phantom")
  194. st.caption("Please upload T1/T2/PD images of a brain scan")
  195. c1, c2, c3 = st.columns(3)
  196. with c1:
  197. t1_file = st.file_uploader("T1", type=SUPPORTED_EXTS_PHANTOM)
  198. with c2:
  199. t2_file = st.file_uploader("T2", type=SUPPORTED_EXTS_PHANTOM)
  200. with c3:
  201. pd_file = st.file_uploader("PD", type=SUPPORTED_EXTS_PHANTOM)
  202. start_btn = st.button("Begin generation", type="primary")
  203. progress_ph = st.empty()
  204. statuses_ph = st.empty()
  205. if start_btn:
  206. if not (t1_file and t2_file and pd_file):
  207. st.error("Нужно загрузить все три файла: T1, T2, PD.")
  208. else:
  209. done = run_job_stub(statuses_ph, progress_ph)
  210. if done:
  211. # Save demo result into session_state for download button
  212. st.session_state.phantom_blob = make_demo_zip()
  213. st.session_state.phantom_name = "phantom_demo.zip"
  214. st.success("Готово! Можно скачать результат.")
  215. # The download button appears only when phantom_blob is present (after job completes)
  216. if st.session_state.phantom_blob:
  217. st.download_button(
  218. "Download phantom",
  219. data=st.session_state.phantom_blob,
  220. file_name=st.session_state.phantom_name or "phantom_demo.zip",
  221. mime="application/zip",
  222. use_container_width=False,
  223. type="primary",
  224. )
  225. def page_dataset():
  226. st.button("← Homepage", on_click=nav_to, args=("home",))
  227. st.subheader("Generate the dataset")
  228. st.caption("Dataset generation. Please choose parameters for the pulse sequence.")
  229. c1, c2, c3 = st.columns(3)
  230. with c1:
  231. t1d_file = st.file_uploader("T1", key="t1d", type=SUPPORTED_EXTS_PHANTOM)
  232. with c2:
  233. t2d_file = st.file_uploader("T2", key="t2d", type=SUPPORTED_EXTS_PHANTOM)
  234. with c3:
  235. pdd_file = st.file_uploader("PD", key="pdd", type=SUPPORTED_EXTS_PHANTOM)
  236. count = st.number_input("Сколько примеров сгенерировать", min_value=1, max_value=MAX_VALUE_DATASET, value=50, step=1)
  237. start_ds = st.button("Начать генерацию датасета", type="primary")
  238. progress2 = st.empty()
  239. status2 = st.empty()
  240. if start_ds:
  241. if not (t1d_file and t2d_file and pdd_file):
  242. st.error("Нужно загрузить все три файла: T1, T2, PD.")
  243. else:
  244. run_job_stub(status2, progress2)
  245. st.success(f"Датасет сформирован (демо). Количество: {int(count)}")
  246. # ---------- Router ----------
  247. if st.session_state.page == "home":
  248. page_home()
  249. elif st.session_state.page == "phantom":
  250. page_phantom()
  251. elif st.session_state.page == "dataset":
  252. page_dataset()
  253. else:
  254. st.session_state.page = "home"
  255. page_home()