Spaces:
Sleeping
Sleeping
import streamlit as st | |
from PIL import Image | |
import torch | |
from transformers import ViTForImageClassification, ViTImageProcessor | |
# Hugging Face Hub repository holding the fine-tuned ViT classifier and the
# image processor that matches it.
repository_id = "Hammad712/brainmri-vit-model"


@st.cache_resource
def _load_model_and_processor(repo_id):
    """Load and return ``(model, processor)`` for the given Hub repository.

    Streamlit re-executes this entire script on every user interaction.
    ``st.cache_resource`` keeps a single shared copy of the loaded objects
    in the server process, so the weights are not re-loaded on every rerun.

    Args:
        repo_id: Hugging Face Hub repository id to load from.

    Returns:
        A ``(ViTForImageClassification, ViTImageProcessor)`` tuple.
    """
    loaded_model = ViTForImageClassification.from_pretrained(repo_id)
    processor = ViTImageProcessor.from_pretrained(repo_id)
    return loaded_model, processor


# Load the model and feature extractor from Hugging Face (cached across reruns).
model, feature_extractor = _load_model_and_processor(repository_id)
def predict(image):
    """Run the classifier on one uploaded image and return its answer label.

    The image is converted to RGB, preprocessed with the module-level
    ``feature_extractor`` and scored by the module-level ``model``.

    Args:
        image: A PIL image object (anything exposing ``.convert("RGB")``).

    Returns:
        ``"No"`` if the top-scoring class index is 0, ``"Yes"`` if it is 1.
        The app shows this as its answer to "does the image contain a
        tumor?". Any other class index raises ``KeyError``.
    """
    rgb_image = image.convert("RGB")
    batch = feature_extractor(images=rgb_image, return_tensors="pt")

    # Prefer the GPU when one is available. The model and every input
    # tensor must live on the same device before the forward pass.
    if torch.cuda.is_available():
        target = torch.device("cuda")
    else:
        target = torch.device("cpu")
    model.to(target)
    batch = {name: tensor.to(target) for name, tensor in batch.items()}

    # Inference only, so skip gradient bookkeeping.
    with torch.no_grad():
        class_scores = model(**batch).logits

    # A single image is scored, so argmax yields one element and .item() is valid.
    top_index = class_scores.argmax(dim=-1).item()
    answers = {0: "No", 1: "Yes"}
    return answers[top_index]
# Streamlit app: page header and instructions.
st.title("Brain MRI Tumor Detection")
st.write("Upload an MRI image to predict whether it contains a tumor.")

# File uploader restricted to common image extensions. The extension check
# does not guarantee the content is actually a decodable image.
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        # Image.open only parses the header, and pixel data is decoded lazily.
        # Calling load() forces a full decode now, so a mislabeled or truncated
        # upload fails here rather than in the middle of inference.
        image = Image.open(uploaded_file)
        image.load()
    except OSError:
        # PIL's UnidentifiedImageError is a subclass of OSError.
        st.error(
            "The uploaded file could not be read as an image. "
            "Please upload a valid JPG or PNG file."
        )
    else:
        # Display the uploaded image.
        st.image(image, caption="Uploaded Image", use_column_width=True)
        # The spinner disappears once inference finishes, unlike a static
        # "Classifying..." line that would linger next to the result.
        with st.spinner("Classifying..."):
            label = predict(image)
        st.write(f"Predicted label: {label}")