app.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. import time
  2. import os
  3. import io
  4. import zipfile
  5. from pathlib import Path
  6. from typing import List, Tuple
  7. from PIL import Image
  8. import streamlit as st
  9. import base64
  10. def nav_to(p: str):
  11. st.session_state.page = p
  12. # ---------- Theme-aware logo (top-right, base64-embedded) ----------
  13. def _b64_img(path: str) -> str | None:
  14. try:
  15. with open(path, "rb") as f:
  16. return base64.b64encode(f.read()).decode("ascii")
  17. except Exception:
  18. return None
  19. def header_with_theme_logo(title: str,
  20. light_path: str = "logos/NEW_PHYSTECH_for_light.png",
  21. dark_path: str = "logos/NEW_PHYSTECH_for_dark.png",
  22. size_px: int = 100):
  23. light_b64 = _b64_img(light_path)
  24. dark_b64 = _b64_img(dark_path)
  25. if not (light_b64 or dark_b64):
  26. # если нет логотипов — просто выводим заголовок
  27. st.markdown(f"## {title}")
  28. return
  29. light_src = f"data:image/png;base64,{light_b64}" if light_b64 else ""
  30. dark_src = f"data:image/png;base64,{dark_b64}" if dark_b64 else light_src
  31. html = f"""
  32. <style>
  33. /* убираем стандартный padding контейнера Streamlit сверху */
  34. section.main > div:first-child {{
  35. padding-top: 0rem;
  36. }}
  37. .hdr {{
  38. display: flex; align-items: center; justify-content: space-between;
  39. }}
  40. .hdr h1 {{
  41. margin: 0;
  42. font-size: 3.5rem;
  43. line-height: 1.2;
  44. }}
  45. .hdr .logo img {{
  46. width: 250px; height: {size_px}px; object-fit: contain;
  47. border-radius: 8px;
  48. display: inline-block;
  49. }}
  50. .hdr .logo img.light {{ display: inline-block; }}
  51. .hdr .logo img.dark {{ display: none; }}
  52. @media (prefers-color-scheme: dark) {{
  53. .hdr .logo img.light {{ display: none; }}
  54. .hdr .logo img.dark {{ display: inline-block; }}
  55. }}
  56. </style>
  57. <div class="hdr">
  58. <h1>{title}</h1>
  59. <div class="logo">
  60. <img src="{light_src}" alt="logo" class="light" />
  61. <img src="{dark_src}" alt="logo" class="dark" />
  62. </div>
  63. </div>
  64. """
  65. st.markdown(html, unsafe_allow_html=True)
  66. st.set_page_config(
  67. page_title="MRI physics based augmentation",
  68. page_icon="🧠",
  69. layout="wide"
  70. )
  71. header_with_theme_logo("MRI physics based augmentation")
  72. # ---------- Simple router in session_state ----------
  73. if "page" not in st.session_state:
  74. st.session_state.page = "home"
  75. # storage for generated phantom (appears after progress completes)
  76. if "phantom_blob" not in st.session_state:
  77. st.session_state.phantom_blob = None
  78. if "phantom_name" not in st.session_state:
  79. st.session_state.phantom_name = None
  80. # ---------- Image helpers ----------
  81. SUPPORTED_EXTS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff")
  82. SUPPORTED_EXTS_PHANTOM = (".dcm", ".nii", ".nii.gz", ".nrrd", ".npy", ".png", ".jpg", ".jpeg")
  83. MAX_VALUE_DATASET = 100000
  84. def center_crop_to_square(img: Image.Image) -> Image.Image:
  85. """Center-crop PIL image to a square based on the smaller side."""
  86. w, h = img.size
  87. s = min(w, h)
  88. left = (w - s) // 2
  89. top = (h - s) // 2
  90. return img.crop((left, top, left + s, top + s))
  91. def load_and_prepare_assets(asset_dir: str = "assets", count: int = 3, size: Tuple[int, int] = (320, 320)) -> List[Tuple[Image.Image, str]]:
  92. """Load up to `count` images from asset_dir, center-crop to square, resize to `size`."""
  93. results = []
  94. if not os.path.isdir(asset_dir):
  95. return results
  96. files = sorted([f for f in os.listdir(asset_dir) if os.path.splitext(f.lower())[1] in SUPPORTED_EXTS])
  97. for fname in files[:count]:
  98. path = os.path.join(asset_dir, fname)
  99. try:
  100. img = Image.open(path).convert("RGB")
  101. img = center_crop_to_square(img)
  102. img = img.resize(size, Image.LANCZOS)
  103. results.append((img, fname))
  104. except Exception:
  105. continue
  106. return results
  107. def run_job_stub(status_placeholder, progress_placeholder, steps=None, delay=0.9):
  108. """Simulate a long-running job with progress and 3-line status stream.
  109. Returns True when finished."""
  110. if steps is None:
  111. steps = [
  112. "Инициализация пайплайна...",
  113. "Обработка входных данных...",
  114. "Генерация синтетических изображений...",
  115. "Постобработка результатов...",
  116. "Готово!",
  117. ]
  118. progress = 0
  119. last3 = []
  120. progress_placeholder.progress(progress, text="Waiting to start...")
  121. status_placeholder.markdown("")
  122. for i, msg in enumerate(steps, 1):
  123. last3.append(msg)
  124. last3 = last3[-3:] # keep only the last 3
  125. # Newest at the top for the "falls down" feel:
  126. lines = []
  127. for idx, line in enumerate(reversed(last3)):
  128. if idx == 0:
  129. lines.append(f"- **{line}**")
  130. elif idx == 1:
  131. lines.append(f"- <span style='opacity:0.7'>{line}</span>")
  132. else:
  133. lines.append(f"- <span style='opacity:0.45'>{line}</span>")
  134. status_placeholder.markdown("<br/>".join(lines), unsafe_allow_html=True)
  135. progress = int(i * 100 / len(steps))
  136. progress_placeholder.progress(progress, text=f"Progress: {progress}%")
  137. time.sleep(delay)
  138. return True
  139. def make_demo_zip() -> bytes:
  140. """Create a demo ZIP to download as 'phantom' result."""
  141. buf = io.BytesIO()
  142. with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_DEFLATED) as zf:
  143. zf.writestr("phantom/README.txt", "Demo phantom result. Replace with real generated files.")
  144. buf.seek(0)
  145. return buf.getvalue()
  146. # ---------- Pages ----------
  147. def page_home():
  148. st.subheader("Информация о контрастах и генерации фантомов")
  149. st.write(
  150. "Здесь можно разместить описание ИП (T1/T2/PD), принципы формирования данных и примеры сканов. "
  151. "Картинки ниже подгружаются из папки `assets/`, **автоматически центр-обрезаются в квадрат** и "
  152. "**приводятся к одинаковому размеру** (320×320)."
  153. )
  154. # Load 3 prepared images from assets
  155. images = load_and_prepare_assets("assets", count=3, size=(320, 320))
  156. if images:
  157. cols = st.columns(len(images))
  158. for (img, name), col in zip(images, cols):
  159. with col:
  160. st.image(img, use_container_width=False)
  161. # st.image(img, caption=name, use_container_width=False)
  162. else:
  163. st.info("Положите 1–3 изображения в папку `assets/` (png/jpg/tif), и они появятся здесь одинакового размера.")
  164. st.markdown("---")
  165. c1, c2 = st.columns(2)
  166. with c1:
  167. with st.container(border=True):
  168. st.markdown("#### 🧠 Phantom generation")
  169. st.write("Upload T1/T2/PD images and begin the generation.")
  170. st.button("Move to the phantom", type="primary", use_container_width=True, on_click=nav_to, args=("phantom",))
  171. with c2:
  172. with st.container(border=True):
  173. st.markdown("#### 📦 Dataset generation")
  174. st.write("Generation of a dataset based on pulse sequence parameters.")
  175. st.button("Move to the dataset", type="primary", use_container_width=True, on_click=nav_to, args=("dataset",))
  176. def page_phantom():
  177. st.button("← Homepage", on_click=nav_to, args=("home",))
  178. st.subheader("Generate the phantom")
  179. st.caption("Please upload T1/T2/PD images of a brain scan")
  180. c1, c2, c3 = st.columns(3)
  181. with c1:
  182. t1_file = st.file_uploader("T1", type=SUPPORTED_EXTS_PHANTOM)
  183. with c2:
  184. t2_file = st.file_uploader("T2", type=SUPPORTED_EXTS_PHANTOM)
  185. with c3:
  186. pd_file = st.file_uploader("PD", type=SUPPORTED_EXTS_PHANTOM)
  187. start_btn = st.button("Begin generation", type="primary")
  188. progress_ph = st.empty()
  189. statuses_ph = st.empty()
  190. if start_btn:
  191. if not (t1_file and t2_file and pd_file):
  192. st.error("Нужно загрузить все три файла: T1, T2, PD.")
  193. else:
  194. done = run_job_stub(statuses_ph, progress_ph)
  195. if done:
  196. # Save demo result into session_state for download button
  197. st.session_state.phantom_blob = make_demo_zip()
  198. st.session_state.phantom_name = "phantom_demo.zip"
  199. st.success("Готово! Можно скачать результат.")
  200. # The download button appears only when phantom_blob is present (after job completes)
  201. if st.session_state.phantom_blob:
  202. st.download_button(
  203. "Download phantom",
  204. data=st.session_state.phantom_blob,
  205. file_name=st.session_state.phantom_name or "phantom_demo.zip",
  206. mime="application/zip",
  207. use_container_width=False,
  208. type="primary",
  209. )
  210. def page_dataset():
  211. st.button("← Homepage", on_click=nav_to, args=("home",))
  212. st.subheader("Generate the dataset")
  213. st.caption("Dataset generation. Please choose parameters for the pulse sequence.")
  214. c1, c2, c3 = st.columns(3)
  215. with c1:
  216. t1d_file = st.file_uploader("T1", key="t1d", type=SUPPORTED_EXTS_PHANTOM)
  217. with c2:
  218. t2d_file = st.file_uploader("T2", key="t2d", type=SUPPORTED_EXTS_PHANTOM)
  219. with c3:
  220. pdd_file = st.file_uploader("PD", key="pdd", type=SUPPORTED_EXTS_PHANTOM)
  221. count = st.number_input("Сколько примеров сгенерировать", min_value=1, max_value=MAX_VALUE_DATASET, value=50, step=1)
  222. start_ds = st.button("Начать генерацию датасета", type="primary")
  223. progress2 = st.empty()
  224. status2 = st.empty()
  225. if start_ds:
  226. if not (t1d_file and t2d_file and pdd_file):
  227. st.error("Нужно загрузить все три файла: T1, T2, PD.")
  228. else:
  229. run_job_stub(status2, progress2)
  230. st.success(f"Датасет сформирован (демо). Количество: {int(count)}")
  231. # ---------- Router ----------
  232. if st.session_state.page == "home":
  233. page_home()
  234. elif st.session_state.page == "phantom":
  235. page_phantom()
  236. elif st.session_state.page == "dataset":
  237. page_dataset()
  238. else:
  239. st.session_state.page = "home"
  240. page_home()