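"""Gradio app for lung and colon cancer histopathological image classification.

A ViT image-classification model (loaded from the repository root) produces
class probabilities, and Gemini is used as a pre-filter to reject uploads that
are clearly not histopathological images. Requires GEMINI_API_KEY, read from
the environment or a local .env file.
"""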
import json
import os
import tempfile

import google.generativeai as genai
import gradio as gr
import torch
from dotenv import load_dotenv
from transformers import ViTFeatureExtractor, ViTForImageClassification

# The Gemini API key is read from the environment (a local .env file is supported).
load_dotenv()

genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
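# Generation settings for the Gemini pre-filter; the JSON response MIME type
# lets the reply be parsed directly with json.loads.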
generation_config = {
    "temperature": 1,
    "top_p": 0.95,
    "top_k": 64,
    "max_output_tokens": 8192,
    "response_mime_type": "application/json",
}
gmodel = genai.GenerativeModel(
    model_name="gemini-1.5-flash", generation_config=generation_config
)
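# ViT classifier and its feature extractor, loaded from the current working
# directory (MODEL_PATH = "./").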
MODEL_PATH = "./"
feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_PATH)
model = ViTForImageClassification.from_pretrained(MODEL_PATH)
model.eval()


def image_detection(image_file):
    """Ask Gemini whether the file is a histopathological image.

    Returns 1 if it is, 0 if it is not, and the string "error" if the call fails.
    """
    try:
        image_file = genai.upload_file(path=image_file)
        prompt = (
            "You are an AI model designed to identify whether an image (of lung cells) "
            "is a histopathological image or not. Answer with 0 (not a histopathological "
            "image) or 1 (is a histopathological image); we mainly want to reject images "
            "that are obviously not histopathological. Format your response as JSON with "
            "an 'is_histopathological' field."
        )
        response = gmodel.generate_content([prompt, image_file]).text
        genai.delete_file(image_file.name)
        res = json.loads(response)
        # The model may return the flag as a string, so normalise it to an int.
        return int(res["is_histopathological"])
    except Exception as e:
        print(e)
        return "error"


def classify_image(image):
    """Gradio handler: screen the upload with Gemini, then classify it with the ViT model."""
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Gemini's File API needs a file on disk, so write the PIL image to a temp file
    # and run the check while the file still exists.
    with tempfile.NamedTemporaryFile(delete=True, suffix=".png") as temp_image:
        image.save(temp_image.name)
        is_histopathological = image_detection(temp_image.name)

    if is_histopathological != 1:
        return {"error": "Please upload a histopathological image."}

    inputs = feature_extractor(images=image, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)

    # Map each class label to its predicted probability.
    results = {
        model.config.id2label[i]: float(prob)
        for i, prob in enumerate(probabilities[0])
    }

    # gr.JSON renders a dict directly; returning the dict also keeps the success
    # and error return types consistent.
    return results
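# Gradio UI: a single image input mapped to the JSON output of classify_image.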
def launch_gradio():
    interface = gr.Interface(
        fn=classify_image,
        inputs=gr.Image(type="pil", label="Upload Cancer Histopathological Image"),
        outputs=gr.JSON(label="Classification Probabilities"),
        title="Lung and Colon Cancer Image Classifier",
        description="Upload a histopathological image to classify the cancer type.",
    )

    # 0.0.0.0 binds all interfaces; 7860 is Gradio's default port.
    interface.launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    launch_gradio()