123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277 |
- import time
- import os
- import io
- import zipfile
- from pathlib import Path
- from typing import List, Tuple
- from PIL import Image
- import streamlit as st
- import base64
- def nav_to(p: str):
- st.session_state.page = p
- # ---------- Theme-aware logo (top-right, base64-embedded) ----------
- def _b64_img(path: str) -> str | None:
- try:
- with open(path, "rb") as f:
- return base64.b64encode(f.read()).decode("ascii")
- except Exception:
- return None
- def header_with_theme_logo(title: str,
- light_path: str = "logos/NEW_PHYSTECH_for_light.png",
- dark_path: str = "logos/NEW_PHYSTECH_for_dark.png",
- size_px: int = 100):
- light_b64 = _b64_img(light_path)
- dark_b64 = _b64_img(dark_path)
- if not (light_b64 or dark_b64):
- # если нет логотипов — просто выводим заголовок
- st.markdown(f"## {title}")
- return
- light_src = f"data:image/png;base64,{light_b64}" if light_b64 else ""
- dark_src = f"data:image/png;base64,{dark_b64}" if dark_b64 else light_src
- html = f"""
- <style>
- /* убираем стандартный padding контейнера Streamlit сверху */
- section.main > div:first-child {{
- padding-top: 0rem;
- }}
- .hdr {{
- display: flex; align-items: center; justify-content: space-between;
- }}
- .hdr h1 {{
- margin: 0;
- font-size: 3.5rem;
- line-height: 1.2;
- }}
- .hdr .logo img {{
- width: 250px; height: {size_px}px; object-fit: contain;
- border-radius: 8px;
- display: inline-block;
- }}
- .hdr .logo img.light {{ display: inline-block; }}
- .hdr .logo img.dark {{ display: none; }}
- @media (prefers-color-scheme: dark) {{
- .hdr .logo img.light {{ display: none; }}
- .hdr .logo img.dark {{ display: inline-block; }}
- }}
- </style>
- <div class="hdr">
- <h1>{title}</h1>
- <div class="logo">
- <img src="{light_src}" alt="logo" class="light" />
- <img src="{dark_src}" alt="logo" class="dark" />
- </div>
- </div>
- """
- st.markdown(html, unsafe_allow_html=True)
- st.set_page_config(
- page_title="MRI physics based augmentation",
- page_icon="🧠",
- layout="wide"
- )
- header_with_theme_logo("MRI physics based augmentation")
- # ---------- Simple router in session_state ----------
- if "page" not in st.session_state:
- st.session_state.page = "home"
- # storage for generated phantom (appears after progress completes)
- if "phantom_blob" not in st.session_state:
- st.session_state.phantom_blob = None
- if "phantom_name" not in st.session_state:
- st.session_state.phantom_name = None
- # ---------- Image helpers ----------
- SUPPORTED_EXTS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff")
- SUPPORTED_EXTS_PHANTOM = (".dcm", ".nii", ".nii.gz", ".nrrd", ".npy", ".png", ".jpg", ".jpeg")
- MAX_VALUE_DATASET = 100000
- def center_crop_to_square(img: Image.Image) -> Image.Image:
- """Center-crop PIL image to a square based on the smaller side."""
- w, h = img.size
- s = min(w, h)
- left = (w - s) // 2
- top = (h - s) // 2
- return img.crop((left, top, left + s, top + s))
- def load_and_prepare_assets(asset_dir: str = "assets", count: int = 3, size: Tuple[int, int] = (320, 320)) -> List[Tuple[Image.Image, str]]:
- """Load up to `count` images from asset_dir, center-crop to square, resize to `size`."""
- results = []
- if not os.path.isdir(asset_dir):
- return results
- files = sorted([f for f in os.listdir(asset_dir) if os.path.splitext(f.lower())[1] in SUPPORTED_EXTS])
- for fname in files[:count]:
- path = os.path.join(asset_dir, fname)
- try:
- img = Image.open(path).convert("RGB")
- img = center_crop_to_square(img)
- img = img.resize(size, Image.LANCZOS)
- results.append((img, fname))
- except Exception:
- continue
- return results
- def run_job_stub(status_placeholder, progress_placeholder, steps=None, delay=0.9):
- """Simulate a long-running job with progress and 3-line status stream.
- Returns True when finished."""
- if steps is None:
- steps = [
- "Инициализация пайплайна...",
- "Обработка входных данных...",
- "Генерация синтетических изображений...",
- "Постобработка результатов...",
- "Готово!",
- ]
- progress = 0
- last3 = []
- progress_placeholder.progress(progress, text="Waiting to start...")
- status_placeholder.markdown("")
- for i, msg in enumerate(steps, 1):
- last3.append(msg)
- last3 = last3[-3:] # keep only the last 3
- # Newest at the top for the "falls down" feel:
- lines = []
- for idx, line in enumerate(reversed(last3)):
- if idx == 0:
- lines.append(f"- **{line}**")
- elif idx == 1:
- lines.append(f"- <span style='opacity:0.7'>{line}</span>")
- else:
- lines.append(f"- <span style='opacity:0.45'>{line}</span>")
- status_placeholder.markdown("<br/>".join(lines), unsafe_allow_html=True)
- progress = int(i * 100 / len(steps))
- progress_placeholder.progress(progress, text=f"Progress: {progress}%")
- time.sleep(delay)
- return True
- def make_demo_zip() -> bytes:
- """Create a demo ZIP to download as 'phantom' result."""
- buf = io.BytesIO()
- with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_DEFLATED) as zf:
- zf.writestr("phantom/README.txt", "Demo phantom result. Replace with real generated files.")
- buf.seek(0)
- return buf.getvalue()
- # ---------- Pages ----------
- def page_home():
- st.subheader("Информация о контрастах и генерации фантомов")
- st.write(
- "Здесь можно разместить описание ИП (T1/T2/PD), принципы формирования данных и примеры сканов. "
- "Картинки ниже подгружаются из папки `assets/`, **автоматически центр-обрезаются в квадрат** и "
- "**приводятся к одинаковому размеру** (320×320)."
- )
- # Load 3 prepared images from assets
- images = load_and_prepare_assets("assets", count=3, size=(320, 320))
- if images:
- cols = st.columns(len(images))
- for (img, name), col in zip(images, cols):
- with col:
- st.image(img, use_container_width=False)
- # st.image(img, caption=name, use_container_width=False)
- else:
- st.info("Положите 1–3 изображения в папку `assets/` (png/jpg/tif), и они появятся здесь одинакового размера.")
- st.markdown("---")
- c1, c2 = st.columns(2)
- with c1:
- with st.container(border=True):
- st.markdown("#### 🧠 Phantom generation")
- st.write("Upload T1/T2/PD images and begin the generation.")
- st.button("Move to the phantom", type="primary", use_container_width=True, on_click=nav_to, args=("phantom",))
- with c2:
- with st.container(border=True):
- st.markdown("#### 📦 Dataset generation")
- st.write("Generation of a dataset based on pulse sequence parameters.")
- st.button("Move to the dataset", type="primary", use_container_width=True, on_click=nav_to, args=("dataset",))
- def page_phantom():
- st.button("← Homepage", on_click=nav_to, args=("home",))
- st.subheader("Generate the phantom")
- st.caption("Please upload T1/T2/PD images of a brain scan")
- c1, c2, c3 = st.columns(3)
- with c1:
- t1_file = st.file_uploader("T1", type=SUPPORTED_EXTS_PHANTOM)
- with c2:
- t2_file = st.file_uploader("T2", type=SUPPORTED_EXTS_PHANTOM)
- with c3:
- pd_file = st.file_uploader("PD", type=SUPPORTED_EXTS_PHANTOM)
- start_btn = st.button("Begin generation", type="primary")
- progress_ph = st.empty()
- statuses_ph = st.empty()
- if start_btn:
- if not (t1_file and t2_file and pd_file):
- st.error("Нужно загрузить все три файла: T1, T2, PD.")
- else:
- done = run_job_stub(statuses_ph, progress_ph)
- if done:
- # Save demo result into session_state for download button
- st.session_state.phantom_blob = make_demo_zip()
- st.session_state.phantom_name = "phantom_demo.zip"
- st.success("Готово! Можно скачать результат.")
- # The download button appears only when phantom_blob is present (after job completes)
- if st.session_state.phantom_blob:
- st.download_button(
- "Download phantom",
- data=st.session_state.phantom_blob,
- file_name=st.session_state.phantom_name or "phantom_demo.zip",
- mime="application/zip",
- use_container_width=False,
- type="primary",
- )
- def page_dataset():
- st.button("← Homepage", on_click=nav_to, args=("home",))
- st.subheader("Generate the dataset")
- st.caption("Dataset generation. Please choose parameters for the pulse sequence.")
- c1, c2, c3 = st.columns(3)
- with c1:
- t1d_file = st.file_uploader("T1", key="t1d", type=SUPPORTED_EXTS_PHANTOM)
- with c2:
- t2d_file = st.file_uploader("T2", key="t2d", type=SUPPORTED_EXTS_PHANTOM)
- with c3:
- pdd_file = st.file_uploader("PD", key="pdd", type=SUPPORTED_EXTS_PHANTOM)
- count = st.number_input("Сколько примеров сгенерировать", min_value=1, max_value=MAX_VALUE_DATASET, value=50, step=1)
- start_ds = st.button("Начать генерацию датасета", type="primary")
- progress2 = st.empty()
- status2 = st.empty()
- if start_ds:
- if not (t1d_file and t2d_file and pdd_file):
- st.error("Нужно загрузить все три файла: T1, T2, PD.")
- else:
- run_job_stub(status2, progress2)
- st.success(f"Датасет сформирован (демо). Количество: {int(count)}")
- # ---------- Router ----------
- if st.session_state.page == "home":
- page_home()
- elif st.session_state.page == "phantom":
- page_phantom()
- elif st.session_state.page == "dataset":
- page_dataset()
- else:
- st.session_state.page = "home"
- page_home()
|