import io

import streamlit as st
import timm
import torch
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

MODEL_NAME = 'hf-hub:nateraw/resnet50-oxford-iiit-pet'


class ImageClassifier(object):
    """Wrap a timm classification model and map its outputs to class labels."""

    def __init__(self, model, labels):
        """
        Args:
            model: a timm model; it must expose ``pretrained_cfg``, which
                ``create_image_transform`` reads.
            labels: sequence of class names, indexed by output position.
        """
        self.model = model
        self.labels = labels

    def get_top_5_predictions(self, image, k=5):
        """Return the ``k`` most probable labels for ``image``.

        Args:
            image: a PIL image accepted by the model's preprocessing transform.
            k: number of predictions to return (default 5). It is clamped to
                the number of classes, so ``torch.topk`` does not raise for
                models with fewer than ``k`` outputs.

        Returns:
            A list of ``{'label': ..., 'score': float}`` dicts, ordered by
            descending probability.
        """
        probabilities = self.get_output_probabilities(image)
        k = min(k, probabilities.numel())
        values, indices = torch.topk(probabilities, k)
        return [
            {'label': self.labels[int(i)], 'score': v.item()}
            for i, v in zip(indices, values)
        ]

    def get_output_probabilities(self, image):
        """Return the softmax probability vector over the model's outputs."""
        output = self.classify_image(image)
        # classify_image feeds a batch of one, so take the single row.
        return torch.nn.functional.softmax(output[0], dim=0)

    def classify_image(self, image):
        """Run one image through the model and return its raw output logits.

        The image is transformed and given a leading batch dimension of 1.
        """
        self.model.eval()
        transform = self.create_image_transform()
        batch = transform(image).unsqueeze(0)
        # Pure inference: disabling autograd avoids building a gradient
        # graph, saving memory and time on every prediction.
        with torch.no_grad():
            return self.model(batch)

    def create_image_transform(self):
        """Build the preprocessing pipeline from the model's pretrained config."""
        return create_transform(**resolve_data_config(
            self.model.pretrained_cfg, model=self.model))


class ImageClassificationApp(object):
    """Streamlit page: upload an image and list the classifier's top predictions."""

    def __init__(self, title, classifier):
        """
        Args:
            title: page title shown at the top.
            classifier: object providing ``get_top_5_predictions(image)``.
        """
        self.title = title
        self.classifier = classifier

    def render(self):
        """Draw the title and uploader, then the results once a file is uploaded."""
        st.title(self.title)
        uploaded_image = self.get_uploaded_image()
        if uploaded_image is not None:
            self.show_image_and_results(uploaded_image)

    def get_uploaded_image(self):
        """Return the uploaded file, or None until the user picks one."""
        return st.file_uploader('Choose an image...', type=['jpg', 'png', 'jpeg'])

    def show_image_and_results(self, uploaded_image):
        """Decode the upload, then display it together with its predictions.

        The uploader only filters by file extension, so the bytes are decoded
        first. A file PIL cannot read produces an on-page error instead of an
        uncaught exception.
        """
        try:
            # getvalue() returns the whole buffer regardless of the stream
            # position; read() returns only what has not been consumed yet.
            image = self.get_image(uploaded_image.getvalue())
        except OSError:
            # PIL's UnidentifiedImageError and truncated-file errors are
            # both OSError subclasses.
            st.error('Could not read the uploaded file as an image.')
            return
        self.show_uploaded_image(uploaded_image)
        self.show_classification_results(image)

    def show_uploaded_image(self, uploaded_image):
        """Display the uploaded file."""
        # NOTE(review): newer Streamlit releases deprecate use_column_width in
        # favour of use_container_width; confirm against the installed version.
        st.image(uploaded_image, caption='Uploaded Image', use_column_width=True)

    def show_classification_results(self, image):
        """Write the results header followed by the top predictions."""
        st.subheader('Classification Results:')
        self.write_top_5_predictions(image)

    def write_top_5_predictions(self, image):
        """Write one ``- label: score`` line per top prediction."""
        for prediction in self.classifier.get_top_5_predictions(image):
            st.write(f"- {prediction['label']}: {prediction['score']:.4f}")

    def get_image(self, image_data):
        """Decode raw image bytes into an RGB PIL image.

        Palette, grayscale and RGBA inputs (e.g. PNGs with transparency) are
        converted to RGB. NOTE: this assumes the model takes 3-channel RGB
        input, as ImageNet-style timm models do.

        Raises:
            OSError: if the bytes cannot be decoded as an image.
        """
        return Image.open(io.BytesIO(image_data)).convert('RGB')


@st.cache_resource
def load_classifier(model_name=MODEL_NAME):
    """Create the classifier once and reuse it across reruns.

    Streamlit re-executes the whole script on every widget interaction.
    Caching stops the model from being rebuilt and its weights reloaded each
    time.
    """
    model = timm.create_model(model_name, pretrained=True)
    labels = model.pretrained_cfg['label_names']
    return ImageClassifier(model, labels)


if __name__ == '__main__':
    ImageClassificationApp(
        'Pet Image Classification App', load_classifier()
    ).render()