# lungcare / main.py
# Author: rjAnupam — commit 98cbf16 ("first")
import json
import os

import gradio as gr
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
# Directory holding the fine-tuned ViT checkpoint and its preprocessor config.
# Defaults to the directory containing this script, so loading does not depend
# on the current working directory (a bare "./" would). Set the MODEL_PATH
# environment variable to load a checkpoint from somewhere else.
MODEL_PATH = os.environ.get(
    "MODEL_PATH", os.path.dirname(os.path.abspath(__file__))
)

# Loaded once at import time; classify_image() reuses these module globals
# for every request.
# NOTE(review): ViTFeatureExtractor is deprecated in recent transformers
# releases in favour of ViTImageProcessor — consider switching once the
# pinned transformers version is confirmed.
feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_PATH)
model = ViTForImageClassification.from_pretrained(MODEL_PATH)
# Put the model in inference mode (disables dropout and similar layers).
model.eval()
def classify_image(image):
    """Classify a histopathology image with the fine-tuned ViT model.

    Args:
        image: A PIL image as delivered by ``gr.Image(type="pil")``, or
            ``None`` when the user submits without uploading anything.

    Returns:
        A JSON string mapping each label in ``model.config.id2label`` to its
        softmax probability (a Python float).

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable
            message instead of a stack trace.
    """
    # gr.Image passes None for an empty input; without this guard the
    # ``image.mode`` access below fails with an opaque AttributeError.
    if image is None:
        raise gr.Error("Please upload an image to classify.")

    # The processor expects 3-channel input, so convert any non-RGB image
    # (e.g. grayscale or RGBA) first.
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # One image was passed in, so take the first (only) row of the batch.
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]

    results = {
        model.config.id2label[i]: prob
        for i, prob in enumerate(probabilities.tolist())
    }
    return json.dumps(results)
def launch_gradio(*, server_name="0.0.0.0", server_port=7860):
    """Build the Gradio UI around ``classify_image`` and start serving it.

    Args:
        server_name: Network interface to bind. The default ``"0.0.0.0"``
            listens on all IPv4 interfaces, so the app is reachable from
            other hosts.
        server_port: TCP port to listen on.
    """
    interface = gr.Interface(
        fn=classify_image,
        inputs=gr.Image(type="pil", label="Upload Cancer Histopathological Image"),
        outputs=gr.JSON(label="Classification Probabilities"),
        title="Lung and Colon Cancer Image Classifier",
        description="Upload a histopathological image to classify cancer type",
    )
    interface.launch(server_name=server_name, server_port=server_port)
if __name__ == "__main__":
    # Start the web UI only when executed as a script, not on import.
    launch_gradio()