|
import json

import gradio as gr
import torch
from transformers import (
    ViTFeatureExtractor,
    ViTForImageClassification,
    ViTImageProcessor,
)
|
|
|
|
|
# Local directory holding the fine-tuned checkpoint loaded below via
# from_pretrained. It is resolved relative to the current working directory,
# so the app must be started from the checkpoint's folder.
MODEL_PATH = "./"

# ViTImageProcessor is the drop-in replacement for the deprecated
# ViTFeatureExtractor (which emits a FutureWarning and is slated for removal).
# It loads the same preprocessing config and is called the same way. The
# variable keeps its original name because classify_image refers to it.
feature_extractor = ViTImageProcessor.from_pretrained(MODEL_PATH)

model = ViTForImageClassification.from_pretrained(MODEL_PATH)

# Inference only: switch off training-mode behaviour such as dropout.
model.eval()
|
|
|
|
|
def classify_image(image):
    """Classify one histopathological image and return per-class probabilities.

    Args:
        image: A PIL image, as delivered by the ``gr.Image(type="pil")``
            input, or ``None`` when the user submits without uploading one.

    Returns:
        A JSON string mapping every label in ``model.config.id2label`` to its
        softmax probability for this image.

    Raises:
        gr.Error: If no image was supplied. Gradio shows this message to the
            user instead of an opaque ``AttributeError`` on ``None.mode``.
    """
    if image is None:
        raise gr.Error("Please upload an image to classify.")

    # Normalise to 3-channel RGB; uploads may arrive as e.g. RGBA (PNG with
    # alpha) or single-channel grayscale.
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = feature_extractor(images=image, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # Only one image is passed in, so row 0 holds this image's class scores.
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]

    results = {
        model.config.id2label[index]: float(prob)
        for index, prob in enumerate(probabilities)
    }

    return json.dumps(results)
|
|
|
def launch_gradio(server_name="0.0.0.0", server_port=7860):
    """Build the Gradio UI around ``classify_image`` and start serving it.

    Args:
        server_name: Network interface to bind. The default ``"0.0.0.0"``
            listens on all interfaces, so the app is reachable from other
            machines (e.g. when running in a container). Pass
            ``"127.0.0.1"`` to keep it local to this host.
        server_port: TCP port to listen on.
    """
    interface = gr.Interface(
        fn=classify_image,
        inputs=gr.Image(type="pil", label="Upload Cancer Histopathological Image"),
        outputs=gr.JSON(label="Classification Probabilities"),
        title="Lung and Colon Cancer Image Classifier",
        description="Upload a histopathological image to classify cancer type",
    )

    interface.launch(server_name=server_name, server_port=server_port)
|
|
|
if __name__ == "__main__":
    # Running this file directly starts the web app. Importing it only loads
    # the model (module-level code above) and defines the functions, without
    # starting a server.
    launch_gradio()
|
|