# Source code for cyto_ml.models.api
# Experimental API for serving a range of plankton models
# * Choose model endpoint
# * POST image contents
# * Probably a parameter for a normalisation function, with a sensible default
# * Get back a dict with the classification
# * Option for confidence levels (if our models are calibrated)
# * Option to also return embeddings (could be enabled by default)
import logging
import os
import torch
import uvicorn
from fastapi import FastAPI, Form
from fastapi.responses import JSONResponse
from resnet50_cefas import load_model
from import load_image_from_url
from import RESNET18_LABELS
from cyto_ml.models.utils import flat_embeddings, resnet18
# Weights for the 3-class ResNet18. The path is relative to this module's
# directory; load_models() resolves it against __file__.
STATE_FILE = "../../../data/weights/ResNet_18_3classes_RGB.pth"

# Model handles, populated by load_models(). Until then they are None, and the
# ResNet18 pair may stay None when the weights file is missing.
# 3-class ResNet18, newer work from the Turing Institute: one copy classifies,
# the other (final layer stripped) produces embeddings.
resnet18_classifier = resnet18_embeddings = None
# Fork of an earlier Turing model, loaded through resnet50_cefas.load_model.
resnet50_model = None

app = FastAPI()
# TODO look at ProcessPoolExecutor with load function for concurrent requests
# The load_models function here is made with that in mind.
# TODO pass state in a reproducible way - weights are on Google Drive
# This could easily become overkill if we start adding ViT, BioCLIP etc
def load_models() -> None:
    """Load the models into the module-level handles used by the endpoints.

    Always sets ``resnet50_model`` via ``load_model(strip_final_layer=True)``.
    It then tries to build ``resnet18_classifier`` and ``resnet18_embeddings``
    (the latter with its final layer stripped) from the weights file named by
    ``STATE_FILE``, resolved relative to this module's directory.

    A missing ResNet18 weights file is logged and tolerated. Both ResNet18
    handles are then left unchanged, so the endpoints never see one loaded
    without the other. Any other exception propagates to the caller.
    """
    global resnet50_model  # noqa PLW0603
    global resnet18_classifier  # noqa PLW0603
    global resnet18_embeddings  # noqa PLW0603

    # strip_final_layer=True: this model is only used for embeddings
    resnet50_model = load_model(strip_final_layer=True)

    # Resolve against this file, not the process's working directory
    state_file = os.path.join(os.path.abspath(os.path.dirname(__file__)), STATE_FILE)
    # We expect some conditions (like the tests) not to have these weights
    try:
        classifier = resnet18(num_classes=3, filename=state_file)
        embedder = resnet18(num_classes=3, filename=state_file, strip_final_layer=True)
    except FileNotFoundError as err:
        logging.warning("ResNet18 weights unavailable, skipping load: %s", err)
        return
    # Publish both handles together, only once both loads have succeeded
    resnet18_classifier = classifier
    resnet18_embeddings = embedder
async def root() -> JSONResponse:
    """Return a fixed greeting payload; handy as a basic liveness check.

    NOTE(review): no route decorator is visible on this handler in this copy of
    the file; confirm it is registered with ``app``.
    """
    greeting = "Hello World"
    return dict(message=greeting)
# interfaces for each of the models
async def resnet50(url: str = Form(...)) -> JSONResponse:
    """Return ResNet50 image embeddings for the image at ``url``.

    ``url`` is read from a form field. On success the response body is
    ``{"embeddings": ...}``. If ``resnet50_model`` has not been loaded, the
    handler returns a 404 ``{"error": "Model not found"}`` response, matching
    the ResNet18 endpoint.

    NOTE(review): ``load_image_from_url`` (presumably a blocking HTTP fetch) and
    the forward pass run inside an ``async`` handler, so they block the event
    loop. Consider a plain ``def`` handler or ``run_in_threadpool``.
    """
    # Fail with a clear 404 rather than "'NoneType' object is not callable"
    if resnet50_model is None:
        return JSONResponse(status_code=404, content={"error": "Model not found"})
    image = load_image_from_url(url)
    # Inference only: don't build an autograd graph.
    # strip_final_layer is only if we want embeddings, so the output is features.
    with torch.no_grad():
        features = resnet50_model(image)
    embeddings = flat_embeddings(features)
    return {"embeddings": embeddings}
async def resnet18_3(url: str = Form(...)) -> JSONResponse:
    """Use the 3 class Resnet18 model to return both a prediction
    and a set of image embeddings.

    ``url`` is read from a form field. On success the response body is
    ``{"classification": <label>, "embeddings": ...}``. If either ResNet18
    handle has not been loaded, the handler returns a 404
    ``{"error": "Model not found"}`` response.
    """
    # Both handles are needed below, so check both, and check explicitly for
    # None rather than relying on the truthiness of a module object
    if resnet18_classifier is None or resnet18_embeddings is None:
        return JSONResponse(status_code=404, content={"error": "Model not found"})
    image = load_image_from_url(url)
    # TODO look at the normalisation / resize functions in Vit-lasnet tests, use them?
    # Inference only: don't build an autograd graph
    with torch.no_grad():
        logits = resnet18_classifier(image)
        features = resnet18_embeddings(image)
    # Best class along dim 1. .item() assumes a single-image batch (as the
    # original tensor indexing did) and yields a plain int, which works whether
    # RESNET18_LABELS is a list or an int-keyed dict
    _, predicted = torch.max(logits, 1)
    label = RESNET18_LABELS[predicted.item()]
    embeddings = flat_embeddings(features)
    return {"classification": label, "embeddings": embeddings}
if __name__ == "__main__":
    # Serve the FastAPI app directly with uvicorn when run as a script.
    # NOTE(review): the bind-address literal appears to have been lost from
    # this copy. An empty host string is passed straight to uvicorn; confirm
    # the intended address (e.g. "0.0.0.0" or "127.0.0.1").
    uvicorn.run(app, host="", port=8081)