"""Gradio app that classifies cancer histopathology images.

Each upload is first screened by Gemini to reject images that are obviously
not histopathological, then classified by a local ViT checkpoint.
"""

import json
import logging
import os
import tempfile

import google.generativeai as genai
import gradio as gr
import torch
from dotenv import load_dotenv
from transformers import ViTFeatureExtractor, ViTForImageClassification

logger = logging.getLogger(__name__)

load_dotenv()
genai.configure(api_key=os.getenv("GEMINI_API_KEY"))

generation_config = {
    "temperature": 1,
    "top_p": 0.95,
    "top_k": 64,
    "max_output_tokens": 8192,
    "response_mime_type": "application/json",
}

gmodel = genai.GenerativeModel(
    model_name="gemini-1.5-flash",
    generation_config=generation_config,
)

# Directory holding the ViT checkpoint and preprocessor config loaded below.
MODEL_PATH = "./"
feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_PATH)
model = ViTForImageClassification.from_pretrained(MODEL_PATH)
model.eval()  # inference only

# Sentinel returned by image_detection when the Gemini check cannot be completed.
DETECTION_ERROR = "error"

# NOTE(review): the prompt mentions lung cells while the UI title says
# "Lung and Colon" -- confirm which tissue types the screen should accept.
DETECTION_PROMPT = (
    "You are an AI model designed to identify images (of lung cells) if they "
    "are histopathological images or not, answer in 0 (not a histopathological "
    "image) or 1 (is a histopathological image), we mainly want to remove the "
    "most obviously not histopathological images. Format your response in "
    "JSON, including fields for 'is_histopathological'."
)


def image_detection(image_file):
    """Ask Gemini whether the image at *image_file* is histopathological.

    Args:
        image_file: Path to an image file on local disk.

    Returns:
        The raw ``is_histopathological`` value from Gemini's JSON reply (the
        prompt asks for 0 or 1), or ``DETECTION_ERROR`` (``"error"``) if the
        upload, generation, or JSON parsing fails.
    """
    try:
        uploaded = genai.upload_file(path=image_file)
    except Exception:
        logger.exception("Uploading %s to Gemini failed", image_file)
        return DETECTION_ERROR

    try:
        response = gmodel.generate_content([DETECTION_PROMPT, uploaded]).text
        res = json.loads(response)
        # Tolerate a reply wrapped in a list as well as a bare JSON object.
        if isinstance(res, list):
            res = res[0] if res else {}
        return res["is_histopathological"]
    except Exception:
        logger.exception("Gemini histopathology check failed")
        return DETECTION_ERROR
    finally:
        # Remove the remote copy even when generation or parsing failed;
        # a failed delete must not discard an otherwise valid answer.
        try:
            genai.delete_file(uploaded.name)
        except Exception:
            logger.warning(
                "Could not delete Gemini file %s", uploaded.name, exc_info=True
            )


def _is_histopathological(flag):
    """Interpret the detector's answer as a boolean.

    ``True``, the number 1, and the strings "1"/"true"/"yes" (case- and
    whitespace-insensitive) count as positive; anything else, including
    ``DETECTION_ERROR``, is negative.
    """
    if isinstance(flag, bool):
        return flag
    if isinstance(flag, (int, float)):
        return flag == 1
    if isinstance(flag, str):
        return flag.strip().lower() in {"1", "true", "yes"}
    return False


def classify_image(image):
    """Classify an uploaded image with the ViT model after a Gemini pre-check.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the user
            submitted without uploading anything.

    Returns:
        A JSON string mapping each label in ``model.config.id2label`` to its
        softmax probability, or a dict with an ``"error"`` key when the image
        is missing, rejected by the pre-check, or could not be checked.
    """
    if image is None:
        return {"error": "Please upload a histopathological image."}

    if image.mode != "RGB":
        image = image.convert("RGB")

    # Save into a temporary directory and run the check inside the `with`, so
    # the file still exists while it is uploaded. This also avoids reopening
    # an already-open NamedTemporaryFile by name, which fails on Windows.
    with tempfile.TemporaryDirectory() as tmp_dir:
        image_path = os.path.join(tmp_dir, "upload.png")
        image.save(image_path)
        detection = image_detection(image_path)

    if detection == DETECTION_ERROR:
        return {"error": "Could not verify the image right now. Please try again."}
    if not _is_histopathological(detection):
        return {"error": "Please upload a histopathological image."}

    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)

    results = {
        model.config.id2label[i]: float(prob)
        for i, prob in enumerate(probabilities[0])
    }
    # Returned as a JSON string (not a dict) to preserve the original return type.
    return json.dumps(results)


def launch_gradio(server_name="0.0.0.0", server_port=7860):
    """Build the Gradio interface around classify_image and launch it.

    Args:
        server_name: Host/interface to bind; defaults to all interfaces.
        server_port: TCP port to serve on.
    """
    interface = gr.Interface(
        fn=classify_image,
        inputs=gr.Image(type="pil", label="Upload Cancer Histopathological Image"),
        outputs=gr.JSON(label="Classification Probabilities"),
        title="Lung and Colon Cancer Image Classifier",
        description="Upload a histopathological image to classify cancer type",
    )
    interface.launch(server_name=server_name, server_port=server_port)


if __name__ == "__main__":
    launch_gradio()