# lungcare / app.py
# Last commit bb31844 by rjAnupam ("ok")
import gradio as gr
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
import json
import google.generativeai as genai
from dotenv import load_dotenv
import os
import tempfile
# Pull variables such as GEMINI_API_KEY from a .env file into the process
# environment before the Gemini client is configured with them.
load_dotenv()
genai.configure(api_key=os.environ.get("GEMINI_API_KEY"))

# Sampling settings for the Gemini screening model. The JSON response MIME
# type is requested because image_detection parses the reply with json.loads.
generation_config = dict(
    temperature=1,
    top_p=0.95,
    top_k=64,
    max_output_tokens=8192,
    response_mime_type="application/json",
)
gmodel = genai.GenerativeModel(
    model_name="gemini-1.5-flash",
    generation_config=generation_config,
)
# Directory holding the fine-tuned ViT checkpoint; "./" is resolved relative
# to the current working directory.
MODEL_PATH = "./"
feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_PATH)
# Module.eval() puts the network in inference mode and returns the module
# itself, so `model` is the loaded classifier.
model = ViTForImageClassification.from_pretrained(MODEL_PATH).eval()
def image_detection(image_file):
    """Ask Gemini whether an image file looks like a histopathological image.

    The file is uploaded through the Gemini Files API and classified with a
    JSON-mode prompt. The uploaded copy is always deleted afterwards.

    Args:
        image_file: Local filesystem path of the image to check.

    Returns:
        The model's ``is_histopathological`` verdict. The prompt asks for 0
        or 1; a string "0"/"1" answer is coerced to int. Returns the string
        ``"error"`` if uploading, generation, or JSON parsing fails.
    """
    prompt = "You are an AI model designed to identify images (of lung cells) if they are histopathological images or not, answer in 0 (not a histopathological image) or 1 (is a histopathological image), we mainly want to remove the most obviously not histopathological images. Format your response in JSON, including fields for 'is_histopathological'."
    uploaded = None
    try:
        uploaded = genai.upload_file(path=image_file)
        response = gmodel.generate_content([prompt, uploaded]).text
        verdict = json.loads(response)["is_histopathological"]
    except Exception as e:
        print(e)
        return "error"
    finally:
        # Remove the remote copy even when generation or parsing failed, so
        # failed requests do not leave files behind in the Gemini project.
        if uploaded is not None:
            try:
                genai.delete_file(uploaded.name)
            except Exception as e:
                # Cleanup is best-effort; a failed delete should not discard
                # an otherwise valid verdict.
                print(e)
    # JSON mode may still emit the digit as a string. Normalise it so callers
    # that compare against the integer 1 get the intended result.
    if isinstance(verdict, str) and verdict.strip() in ("0", "1"):
        verdict = int(verdict.strip())
    return verdict
def classify_image(image):
    """Screen an image with Gemini, then classify it with the local ViT model.

    Only images that ``image_detection`` judges histopathological (verdict 1)
    are passed to the classifier.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            if no image was provided.

    Returns:
        On success, a JSON string mapping each label in
        ``model.config.id2label`` to its softmax probability. On failure, a
        dict with a single ``"error"`` message.
    """
    if image is None:
        return {"error": "Please upload an image."}

    if image.mode != "RGB":
        image = image.convert("RGB")

    # Write the image into a temporary *directory* rather than a
    # NamedTemporaryFile: that file would still be open while PIL and the
    # Gemini upload reopen it by name, which is not possible on Windows.
    with tempfile.TemporaryDirectory() as tmp_dir:
        image_path = os.path.join(tmp_dir, "upload.png")
        image.save(image_path)
        is_histopathological = image_detection(image_path)

    # image_detection signals a service/parsing failure with "error"; report
    # that separately instead of blaming the user's image.
    if is_histopathological == "error":
        return {"error": "Could not verify the image right now. Please try again."}
    if is_histopathological != 1:
        return {"error": "Please upload a histopathological image."}

    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Softmax over the last (class) axis; row 0 is the single input image.
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)

    results = {
        model.config.id2label[i]: float(prob)
        for i, prob in enumerate(probabilities[0])
    }
    return json.dumps(results)
def launch_gradio(server_name="0.0.0.0", server_port=7860):
    """Build the Gradio interface around ``classify_image`` and launch it.

    Args:
        server_name: Network interface to bind. The default ``"0.0.0.0"``
            listens on all interfaces.
        server_port: TCP port to serve on.
    """
    interface = gr.Interface(
        fn=classify_image,
        inputs=gr.Image(type="pil", label="Upload Cancer Histopathological Image"),
        outputs=gr.JSON(label="Classification Probabilities"),
        title="Lung and Colon Cancer Image Classifier",
        description="Upload a histopathological image to classify cancer type",
    )
    interface.launch(server_name=server_name, server_port=server_port)
# Start the web UI only when this file is run directly, not when imported.
if __name__ == "__main__":
    launch_gradio()