# Source code for cyto_ml.models.api

# Experimental API for serving a range of plankton models
# * Choose model endpoint
# * POST image contents
# * Probably parameter for a normalisation function with a sensible default
# * Get back a dict with classification
# * Option for confidence levels (if our models are calibrated)
# * Option to also return embeddings (could be enabled by default)
import logging
import os

import torch
import uvicorn
from fastapi import FastAPI, Form
from fastapi.responses import JSONResponse
from resnet50_cefas import load_model

from cyto_ml.data.image import load_image_from_url
from cyto_ml.data.labels import RESNET18_LABELS
from cyto_ml.models.utils import flat_embeddings, resnet18

STATE_FILE = "../../../data/weights/ResNet_18_3classes_RGB.pth"

# 3-class ResNet18, newer work from Turing Inst
# https://noushineftekhari.github.io/publication/2024-marine-plankton-classification
resnet18_classifier = None
resnet18_embeddings = None

# Fork of earlier Turing model via sci.vision
# https://github.com/ukceh-rse/resnet50-cefas
resnet50_model = None

app = FastAPI()

# TODO look at ProcessPoolExecutor with load function for concurrent requests
# https://luis-sena.medium.com/how-to-optimize-fastapi-for-ml-model-serving-6f75fb9e040d
# The load_models function here is made with that in mind.


# TODO pass state in a reproducible way - weights are on Google Drive
# This could easily become overkill if we start adding ViT, BioCLIP etc
def load_models() -> None:
    """Populate the module-level model globals read by the API endpoints.

    ``resnet50_model`` is always loaded through ``load_model`` with its final
    layer stripped. The two ResNet18 variants (classifier and embeddings) are
    both built from the weights file at ``STATE_FILE``, resolved relative to
    this module's directory.

    If loading the ResNet18 weights raises ``FileNotFoundError`` the error is
    logged as a warning and neither ResNet18 global is modified. Any other
    exception propagates to the caller unchanged.
    """
    global resnet50_model, resnet18_classifier, resnet18_embeddings  # noqa PLW0603

    resnet50_model = load_model(strip_final_layer=True)

    state_file = os.path.join(os.path.abspath(os.path.dirname(__file__)), STATE_FILE)

    # We expect some conditions (like the tests) not to have these weights
    try:
        classifier = resnet18(num_classes=3, filename=state_file)
        embeddings = resnet18(num_classes=3, filename=state_file, strip_final_layer=True)
    except FileNotFoundError as err:
        logging.warning(err)
    else:
        # Publish both models together so a failure on the second load can
        # never leave a classifier without its matching embeddings model.
        resnet18_classifier = classifier
        resnet18_embeddings = embeddings
# Load the models once, when this module is imported, so the endpoints
# defined below can read them from the module-level globals.
load_models()
@app.get("/")
async def root() -> JSONResponse:
    """Answer ``GET /`` with a fixed greeting payload."""
    greeting = {"message": "Hello World"}
    return greeting
# interfaces for each of the models
@app.post("/resnet50/")
async def resnet50(url: str = Form(...)) -> JSONResponse:
    """Return ResNet50 embeddings for the image found at ``url``.

    Args:
        url: Form field with the image location; handed to
            ``load_image_from_url``.

    Returns:
        A dict ``{"embeddings": ...}`` whose value is ``flat_embeddings``
        applied to the model output.
    """
    image = load_image_from_url(url)

    # Inference only: skip autograd bookkeeping for the forward pass.
    # (strip_final_layer is only if we want embeddings - the model is loaded that way.)
    with torch.no_grad():
        features = resnet50_model(image)

    return {"embeddings": flat_embeddings(features)}
@app.post("/resnet18/")
async def resnet18_3(url: str = Form(...)) -> JSONResponse:
    """Use the 3 class Resnet18 model to return both a prediction and a set of image embeddings.

    Args:
        url: Form field with the image location; handed to
            ``load_image_from_url``.

    Returns:
        A dict with ``"classification"`` (the ``RESNET18_LABELS`` entry for
        the highest-scoring class) and ``"embeddings"`` (``flat_embeddings``
        of the embeddings model output), or a 404 ``JSONResponse`` with
        ``{"error": "Model not found"}`` when the ResNet18 models are not
        loaded.
    """
    # The ResNet18 globals stay None when the weights were not available.
    # Compare against None explicitly (rather than relying on the truthiness
    # of a model object) and make sure both models needed below exist.
    if resnet18_classifier is None or resnet18_embeddings is None:
        return JSONResponse(status_code=404, content={"error": "Model not found"})

    image = load_image_from_url(url)
    # TODO look at the normalisation / resize functions in Vit-lasnet tests, use them?

    # Inference only: skip autograd bookkeeping for both forward passes.
    with torch.no_grad():
        logits = resnet18_classifier(image)
        features = resnet18_embeddings(image)

    # Index of the best-scoring class along dim 1.
    _, predicted = torch.max(logits, 1)

    # .item() turns the one-element tensor into a plain Python int before
    # it is used to look up the label.
    label = RESNET18_LABELS[predicted.item()]

    return {"classification": label, "embeddings": flat_embeddings(features)}
if __name__ == "__main__":
    # Run the API under uvicorn when this file is executed as a script,
    # listening on all interfaces at port 8081.
    server_options = {"host": "0.0.0.0", "port": 8081}
    uvicorn.run(app, **server_options)