Source code for cyto_ml.visualisation.app
"""
Streamlit application to visualise how plankton cluster
based on their embeddings from a deep learning model
* Metadata in intake catalogue (basically a dataframe of filenames
  - later this could have lon/lat, date, depth read from Exif headers)
* Embeddings in chromadb, linked by filename
"""
import logging
import os
import random
from io import BytesIO
from typing import Any, List, Optional

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import requests
import streamlit as st
from dotenv import load_dotenv
from PIL import Image

from cyto_ml.data.db_config import OPTIONS
from cyto_ml.data.image import normalise_flowlr
from cyto_ml.data.vectorstore import vector_store
from cyto_ml.visualisation.config import COLLECTIONS
# Request INFO-level logging on the root logger; at this level the
# logging.debug() calls further down are filtered out.
logging.basicConfig(level=logging.INFO)
# Pull settings from a local .env file into the process environment.
load_dotenv()
# Backend name: handed to vector_store() and used as the key into OPTIONS.
STORE_TYPE = "sqlite"
[docs]
def collections() -> List[str]:
    """
    List the names of the image collections the app can show.

    For now this simply hands back the COLLECTIONS constant imported from
    cyto_ml.visualisation.config.
    """
    # TODO improve this when switching from chroma to different db backends
    available = COLLECTIONS
    return available
[docs]
@st.cache_resource
def store(coll: str) -> Any:
    """
    Load the vector store with image embeddings for one collection.

    Args:
        coll: Collection name. The database file is expected at
            ``<this module's directory>/../../../data/<coll>.db``.

    Returns:
        The object built by ``vector_store`` for ``STORE_TYPE``. Callers in
        this module use its ``ids()``, ``embeddings()`` and ``closest()``
        methods. The concrete class lives in cyto_ml.data.vectorstore, so it
        is annotated as ``Any`` here (the previous ``None`` annotation was
        wrong: a store is returned).
    """
    # st.cache_resource memoises the returned object per distinct `coll`, so
    # repeated calls within the app reuse one store per collection.
    # E.g. chroma will have one store per collection...
    module_dir = os.path.abspath(os.path.dirname(__file__))
    db_name = os.path.join(module_dir, "../../../data", f"{coll}.db")
    return vector_store(STORE_TYPE, db_name, **OPTIONS[STORE_TYPE])
[docs]
@st.cache_data
def image_ids(coll: str) -> list:
    """
    Return the ids of every image held in the vector store for `coll`.

    The result is cached by Streamlit per collection name.
    TODO Revisit our available metadata
    """
    collection_store = store(coll)
    return collection_store.ids()
[docs]
@st.cache_data
def _collection_embeddings(coll: str) -> list:
    """Return the embeddings stored for `coll`, cached per collection name."""
    return store(coll).embeddings()


def image_embeddings() -> list:
    """
    Return the embeddings of the currently selected collection.

    The collection name is read from ``st.session_state["collection"]``.
    Caching happens in ``_collection_embeddings``, keyed by that name.
    Caching this zero-argument function directly would give the cache a
    single entry, which would keep serving the first collection's embeddings
    after the user switched to another one.
    """
    return _collection_embeddings(st.session_state["collection"])
[docs]
def closest_n(url: str, n: Optional[int] = 26) -> list:
    """
    Look up the `n` entries nearest to the image at `url`.

    The search runs against the store of the collection currently selected
    in ``st.session_state["collection"]``; the result of the store's
    ``closest`` call is returned unchanged.
    """
    active_collection = st.session_state["collection"]
    return store(active_collection).closest(url, n_results=n)
[docs]
@st.cache_data
def cached_image(url: str) -> Image.Image:
    """
    Fetch an image over HTTP and return it as a PIL Image.

    Streamlit caches the result per URL, so each image is only downloaded
    once. We tried streamlit_clickable_images but it has no tiff support.

    Args:
        url: Location of the image (e.g. an object in s3 exposed over HTTP).

    Returns:
        The decoded image; 16-bit greyscale images are passed through
        ``normalise_flowlr`` first.

    Raises:
        requests.HTTPError: the server answered with an error status.
        requests.Timeout: no response arrived within the timeout.
    """
    # Bound the wait so a stalled download cannot hang the app indefinitely.
    response = requests.get(url, timeout=30)
    # Fail here with a clear HTTP error instead of handing an error page
    # to the image decoder.
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))
    # Special handling for Flow Cytometer images,
    # All 16 bit greyscale in low range of values
    if image.mode == "I;16":
        image = normalise_flowlr(image)
    return image
[docs]
def closest_grid(size: Optional[int] = 65) -> None:
    """
    Render an 8x8 grid of images related to the current start image.

    The start image is read from ``st.session_state["start_img"]`` and up to
    `size` (default 65) results are requested from ``closest_n``. Results
    are taken from the end of that list, one per grid cell, until either
    the grid or the list is exhausted. Each cell shows a thumbnail plus a
    button that makes that image the new start image.
    """
    closest = closest_n(st.session_state["start_img"], size)
    # TODO understand where layout should happen
    rows = [st.columns(8) for _ in range(8)]
    for row in rows:
        for cell in row:
            try:
                source_image, _distance = closest.pop()
            except IndexError:
                # Nothing left to show; the remaining cells stay empty.
                break
            # The thumbnail shown is a .png copy kept in the "-ls" bucket.
            next_image = source_image.replace(".tif", ".png").replace(
                "untagged-images-lana", "untagged-images-lana-ls"
            )
            cell.image(next_image, width=60)
            cell.button("this", key=source_image, on_click=pick_image, args=[source_image])
[docs]
def random_image() -> str:
    """
    Pick one image id at random from the currently selected collection.

    Only returns the pick; it does not touch ``st.session_state``.
    """
    # starting image
    return random.choice(image_ids(st.session_state["collection"]))
[docs]
def pick_image(image: str) -> None:
    """Store `image` in the session state as the current start image."""
    st.session_state["start_img"] = image
[docs]
def show_random_image() -> None:
    """
    Display the current start image followed by its identifier.

    Reads ``st.session_state["start_img"]``; when it is missing or empty
    nothing is rendered.
    """
    current = st.session_state.get("start_img")
    # Lazy %-formatting: unlike string concatenation this cannot raise when
    # the start image is None, which the check below is meant to handle.
    logging.debug("show%s", current)
    if current:
        st.image(cached_image(current))
        st.write(current)
[docs]
def _pick_random_image() -> None:
    """Button callback: make a freshly drawn random image the start image."""
    pick_image(random_image())


def main() -> None:
    """
    Main method that sets up the streamlit app and builds the visualisation.

    Layout, top to bottom: title, image count for the selected collection,
    collection selector, the current start image, a "try again" button and
    the grid of related images.
    """
    colls = collections()
    # Seed the selection before the selectbox below binds to the same key.
    if "collection" not in st.session_state:
        st.session_state["collection"] = colls[0]
    st.set_page_config(layout="wide", page_title="Plankton image embeddings")
    st.title("Image embeddings")
    st.write(f"{len(image_ids(st.session_state['collection']))} images in {st.session_state['collection']}")
    # the generated HTML is not lovely at all
    st.selectbox(
        "image collection",
        colls,
        key="collection",
    )
    if "start_img" not in st.session_state or st.session_state["start_img"] is None:
        st.session_state["start_img"] = random_image()
    show_random_image()
    st.text("<-- random image")
    # A callback's return value is discarded, so passing random_image itself
    # would leave the start image untouched; the callback has to store it.
    st.button("try again", on_click=_pick_random_image)
    # TODO figure out how streamlit is supposed to work
    closest_grid()
# Build the app when this file is executed as a script.
if __name__ == "__main__":
    main()