Source code for cyto_ml.models.api
# Experimental API for serving a range of plankton models
# * Choose a model endpoint
# * POST the image contents
# * Probably a parameter for a normalisation function, with a sensible default
# * Get back a dict with the classification
# * Option for confidence levels (if our models are calibrated)
# * Option to also return embeddings (could be enabled by default)
import logging
import os

import torch
import uvicorn
from fastapi import FastAPI, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from resnet50_cefas import load_model

from cyto_ml.data.image import load_image_from_url
from cyto_ml.data.labels import RESNET18_LABELS
from cyto_ml.models.utils import flat_embeddings, resnet18
# Weights for the 3-class ResNet18; load_models() resolves this path
# relative to the directory containing this module.
STATE_FILE = "../../../data/weights/ResNet_18_3classes_RGB.pth"

# 3-class ResNet18, newer work from the Turing Institute
# https://noushineftekhari.github.io/publication/2024-marine-plankton-classification
# Both placeholders stay None unless load_models() finds the weights file.
resnet18_classifier = resnet18_embeddings = None

# Fork of the earlier Turing model via sci.vision
# https://github.com/ukceh-rse/resnet50-cefas
resnet50_model = None

app = FastAPI()
# TODO: look at ProcessPoolExecutor with a load function for concurrent requests
# https://luis-sena.medium.com/how-to-optimize-fastapi-for-ml-model-serving-6f75fb9e040d
# The load_models function here is written with that in mind.
# TODO: pass state in a reproducible way - the weights are on Google Drive
# This could easily become overkill if we start adding ViT, BioCLIP, etc.
[docs]
def load_models() -> None:
    """Load every model served by this API into the module-level globals.

    The ResNet50 model (final layer stripped, used for embeddings) is always
    loaded. The 3-class ResNet18 classifier and its embeddings variant both
    need the local weights file ``STATE_FILE``, resolved relative to this
    module's directory. If that file is missing, a warning is logged and
    neither ResNet18 global is changed, so the API can still start (e.g.
    under the tests). The two ResNet18 globals are assigned together only
    after both models have loaded, so callers never see one set without
    the other.

    Any exception other than ``FileNotFoundError`` propagates unchanged.
    """
    global resnet50_model, resnet18_classifier, resnet18_embeddings  # noqa: PLW0603

    resnet50_model = load_model(strip_final_layer=True)

    state_file = os.path.join(os.path.abspath(os.path.dirname(__file__)), STATE_FILE)
    # We expect some environments (like the tests) not to have these weights
    try:
        classifier = resnet18(num_classes=3, filename=state_file)
        embeddings = resnet18(num_classes=3, filename=state_file, strip_final_layer=True)
    except FileNotFoundError as err:
        logging.warning(err)
        return

    resnet18_classifier = classifier
    resnet18_embeddings = embeddings


load_models()
[docs]
@app.get("/")
async def root() -> JSONResponse:
    """Answer the bare root path with a fixed greeting.

    Touches no models, so it works as a simple check that the app is up.
    """
    payload = {"message": "Hello World"}
    return payload
# HTTP endpoints, one per model
[docs]
@app.post("/resnet50/")
async def resnet50(url: str = Form(...)) -> JSONResponse:
    """Return flattened ResNet50 embeddings for the image at ``url``.

    ``url`` is read from a form field. ``load_models`` loads this model with
    ``strip_final_layer=True``, so its output is a feature tensor, not class
    scores.

    Returns a dict ``{"embeddings": ...}`` built by ``flat_embeddings``.
    """

    def _embed():
        # Runs in a worker thread: load_image_from_url presumably does a
        # synchronous fetch, and the forward pass is CPU/GPU-bound. Neither
        # should block the event loop. no_grad is thread-local, so it is
        # entered here, in the thread that runs the model. Inference needs
        # no autograd graph.
        with torch.no_grad():
            features = resnet50_model(load_image_from_url(url))
            return flat_embeddings(features)

    embeddings = await run_in_threadpool(_embed)
    return {"embeddings": embeddings}
[docs]
@app.post("/resnet18/")
async def resnet18_3(url: str = Form(...)) -> JSONResponse:
    """Use the 3-class ResNet18 model to return both a prediction and a
    set of image embeddings for the image at ``url``.

    ``resnet18_classifier`` produces the class scores. ``resnet18_embeddings``
    (same weights, final layer stripped) produces the features passed to
    ``flat_embeddings``.

    Returns ``{"classification": <label>, "embeddings": ...}``, or a 404
    JSON error when the ResNet18 models were not loaded (e.g. the weights
    file was missing at startup).
    """
    # Compare against None explicitly rather than relying on the truthiness
    # of a model object (some torch containers define __len__).
    if resnet18_classifier is None or resnet18_embeddings is None:
        return JSONResponse(status_code=404, content={"error": "Model not found"})

    def _classify_and_embed():
        # Blocking work (image fetch + two forward passes) runs in a worker
        # thread so the event loop stays free. no_grad is entered in this
        # thread because grad mode is thread-local.
        image = load_image_from_url(url)
        # TODO look at the normalisation / resize functions in Vit-lasnet tests, use them?
        with torch.no_grad():
            scores = resnet18_classifier(image)
            _, predicted = torch.max(scores, 1)
            embeddings = flat_embeddings(resnet18_embeddings(image))
        # .item() turns the single-element index tensor into a plain int, so
        # the label lookup works for any sequence or int-keyed mapping.
        return int(predicted.item()), embeddings

    label_index, embeddings = await run_in_threadpool(_classify_and_embed)
    return {"classification": RESNET18_LABELS[label_index], "embeddings": embeddings}
if __name__ == "__main__":
    # Running the module directly starts a uvicorn server for the app,
    # listening on every interface.
    bind_host, bind_port = "0.0.0.0", 8081
    uvicorn.run(app, host=bind_host, port=bind_port)