"""Gradio front end for a ViT image classifier loaded from a local directory."""

import json

import gradio as gr
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification

# Directory holding the pretrained model and preprocessor configuration.
MODEL_PATH = "./"

# Loaded once at import time so every request reuses the same objects.
# NOTE(review): ViTFeatureExtractor is deprecated in recent transformers
# releases in favour of ViTImageProcessor -- confirm the pinned version
# before switching.
feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_PATH)
model = ViTForImageClassification.from_pretrained(MODEL_PATH)
model.eval()


def classify_image(image):
    """Classify an image and return the per-class probabilities as JSON.

    Args:
        image: The PIL image supplied by the Gradio ``Image`` input, or
            ``None`` when the user submitted without providing one.

    Returns:
        A JSON string mapping each label in ``model.config.id2label`` to its
        softmax probability.

    Raises:
        gr.Error: If no image was provided.
    """
    # Without this guard a missing image fails on ``image.mode`` with an
    # opaque AttributeError; gr.Error is shown to the user in the UI instead.
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    # Convert any non-RGB input (grayscale, RGBA, palette, ...) to RGB
    # before preprocessing.
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = model(**inputs)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)

    # Only the first batch entry is reported (a single image was passed in);
    # tolist() converts it to plain Python floats in one call.
    results = {
        model.config.id2label[index]: probability
        for index, probability in enumerate(probabilities[0].tolist())
    }

    return json.dumps(results)


def launch_gradio():
    """Build the Gradio interface and serve it on port 7860."""
    interface = gr.Interface(
        fn=classify_image,
        inputs=gr.Image(type="pil", label="Upload Cancer Histopathological Image"),
        outputs=gr.JSON(label="Classification Probabilities"),
        title="Lung and Colon Cancer Image Classifier",
        description="Upload a histopathological image to classify cancer type",
    )
    # 0.0.0.0 binds to all network interfaces rather than only localhost.
    interface.launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    launch_gradio()