# eyegazer-demo / app.py
# Last commit by Nick Doiron: "gradio fixes" (ae753cc)
# Hugging Face file-viewer metadata: raw | history | blame | No virus | 2.31 kB
import gradio as gr
import os
from peft import PeftModel
from PIL import Image
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
from torchvision.transforms import (
CenterCrop,
Compose,
Normalize,
RandomHorizontalFlip,
RandomResizedCrop,
Resize,
ToTensor,
)
# Base ViT checkpoint and the LoRA adapter that is loaded on top of it.
model_name = 'google/vit-large-patch16-224'
adapter = 'monsoon-nlp/eyegazer-vit-binary'

# The base checkpoint's processor supplies the normalization statistics and
# the square input side length used by both transform pipelines below.
image_processor = AutoImageProcessor.from_pretrained(model_name)
_crop_size = image_processor.size["height"]
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

# Randomized augmentation pipeline (random crop + horizontal flip).
# NOTE(review): not referenced by the inference code in this file; appears to
# be kept from training.
train_transforms = Compose([
    RandomResizedCrop(_crop_size),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
])

# Deterministic pipeline used at inference: resize, center-crop, normalize.
val_transforms = Compose([
    Resize(_crop_size),
    CenterCrop(_crop_size),
    ToTensor(),
    normalize,
])

# Load the base model with a 2-output classification head. The checkpoint's
# original head has a different size, so ignore_mismatched_sizes=True lets
# from_pretrained re-initialize it instead of failing.
# NOTE(review): presumably the adapter below carries the trained head
# weights — confirm against the adapter config.
model = AutoModelForImageClassification.from_pretrained(
    model_name,
    ignore_mismatched_sizes=True,
    num_labels=2,
)
lora_model = PeftModel.from_pretrained(model, adapter)
def query(img):
    """Classify a fundus camera image for signs of diabetic retinopathy.

    Args:
        img: The uploaded image as a PIL ``Image`` (the Gradio input is
            configured with ``type='pil'``), or ``None`` when the form is
            submitted without an image.

    Returns:
        A human-readable prediction string shown in the Markdown output.
    """
    if img is None:
        # Without this guard an empty submission fails on None.convert().
        return "Please upload a fundus camera image"
    # Apply the deterministic resize/center-crop/normalize pipeline, then add
    # a leading batch dimension of size 1.
    batch = val_transforms(img.convert("RGB")).unsqueeze(0)
    # Inference only: no_grad skips building an autograd graph for the
    # forward pass, saving memory and time on every request.
    with torch.no_grad():
        logits = lora_model(batch).logits.tolist()[0]
    # NOTE(review): index 0 is treated as "unaffected" and index 1 as
    # "affected"; this relies on the label order used when training the
    # adapter — confirm. Ties fall through to "affected".
    if logits[0] > logits[1]:
        return "Predicted unaffected"
    return "Predicted affected to some degree"
# Example fundus images shipped in the ./images directory next to this script.
_EXAMPLE_DIR = os.path.join(os.path.dirname(__file__), "images")
_EXAMPLE_FILES = (
    "i1.png",
    "0a09aa7356c0.png",
    "0a4e1a29ffff.png",
    "0c43c79e8cfb.png",
    "0c7e82daf5a0.png",
)

# Web UI: one image in (delivered to query() as a PIL image), a Markdown
# prediction string out.
iface = gr.Interface(
    fn=query,
    examples=[os.path.join(_EXAMPLE_DIR, name) for name in _EXAMPLE_FILES],
    inputs=[
        gr.Image(
            image_mode='RGB',
            sources=['upload', 'clipboard'],
            type='pil',
            label='Input Fundus Camera Image',
            show_label=True,
        ),
    ],
    outputs=[
        gr.Markdown(value="", label="Predicted label"),
    ],
    title="ViT retinopathy model",
    description="Diabetic retinopathy model trained on APTOS 2019 dataset; demonstration, not medical advice",
    # Hide Gradio's "Flag" button so inputs are never saved via flagging.
    allow_flagging="never",
)
iface.launch()