Source code for cyto_ml.models.utils
import torch
import torchvision
# The model definitions below are taken from:
# https://github.com/alan-turing-institute/ViT-LASNet/blob/main/test/test.py
# TODO: move these model definitions out of `utils` into a more specific module.
# TODO: consider adding the transformer model; for now, focus on the 3-class ResNet-18.
[docs]
def resnet18(
    num_classes: int,
    filename: str = "",
    strip_final_layer: bool = False,
) -> torch.nn.Module:
    """Build a ResNet-18, load saved weights into it, and switch it to eval mode.

    Args:
        num_classes: Number of outputs of the replacement ``fc`` layer. It must
            match the ``fc`` shape stored in the checkpoint at ``filename``,
            otherwise ``load_state_dict`` rejects the weights.
        filename: Path to a saved ``state_dict`` for this architecture. The file
            is always loaded, so a real path must be supplied; the ``""``
            default is presumably just a placeholder.
        strip_final_layer: If True, replace ``fc`` with ``torch.nn.Identity``
            after the weights are loaded. The model then returns the input that
            would have gone to the classifier (an embedding) rather than
            per-class scores.

    Returns:
        The model, with ``model.eval()`` already called.
    """
    model = torchvision.models.resnet18()
    # Swap the default head for one sized to our classes *before* loading, so the
    # checkpoint's ``fc.weight`` / ``fc.bias`` shapes match the module's.
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only machines.
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    model_state_dict = torch.load(filename, map_location="cpu")
    model.load_state_dict(model_state_dict)
    if strip_final_layer:
        # Return embeddings rather than class scores.
        model.fc = torch.nn.Identity()
    model.eval()
    return model
[docs]
def flat_embeddings(features: torch.Tensor) -> list:
    """Convert model output into a plain Python list for a vector database.

    Only ``features[0]`` (the first entry along dimension 0) is converted, with
    gradient tracking detached; any further entries are ignored. For a 2-D
    input this yields a flat list of numbers. For higher-rank input the result
    is a nested list.

    The input is presumably the embedding output of a truncated model, e.g.
    ``resnet18(..., strip_final_layer=True)``; verify against callers.
    """
    # TODO(review): only the 0th item of the batch is returned. Presumably
    # callers pass a batch of one image; confirm why the rest is dropped.
    detached = features.detach()
    first_item = detached[0]
    return first_item.tolist()