# Captured from a Hugging Face Space (page header read "Spaces: Sleeping").
import streamlit as st
import torch
from PIL import Image
from torchvision import transforms

# Load the trained model
MODEL_PATH = "resnet_model.pth"  # Update with your actual model path
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def load_model(path: str, _device: torch.device) -> torch.nn.Module:
    """Load the pickled model at *path* onto *_device* and switch it to eval mode.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` keeps one loaded model per server process instead of
    deserializing the checkpoint again on each rerun. The leading underscore on
    ``_device`` tells Streamlit not to hash that argument for the cache key.

    Raises:
        FileNotFoundError: if *path* does not exist.
    """
    # The loaded object has .eval() called on it, i.e. the checkpoint is a whole
    # pickled module rather than a state_dict. That requires weights_only=False
    # on torch >= 2.6, where the default flipped to True. NOTE: this unpickles
    # arbitrary objects -- only point MODEL_PATH at a file you trust.
    loaded = torch.load(path, map_location=_device, weights_only=False)
    loaded.eval()
    return loaded


try:
    model = load_model(MODEL_PATH, device)
except FileNotFoundError:
    # Show a readable message in the app instead of a raw traceback.
    st.error(f"Model file not found: {MODEL_PATH}")
    st.stop()
# Preprocessing applied to each uploaded image before inference: resize to a
# fixed 224x224, convert the PIL image to a float tensor, then normalize the
# three channels with the standard ImageNet mean/std values (presumably the
# statistics the ResNet was trained with -- confirm against training code).
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.485, 0.456, 0.406],  # per-channel mean
            [0.229, 0.224, 0.225],  # per-channel std
        ),
    ]
)
# Streamlit UI
st.title("Saliva Disease Detection App")
st.subheader("Predict Streptococcosis vs NOT Streptococcosis from uploaded saliva images")

# Labels in the order of the model's output indices.
CLASS_NAMES = ["Not_Streptococcosis", "Streptococcosis"]

# st.file_uploader keeps its file across reruns and its value cannot be set
# through st.session_state, so mirroring the file into session state (and
# setting that copy to None) never actually resets the widget. Instead the
# widget key embeds a counter; bumping the counter makes Streamlit render a
# brand-new, empty uploader.
if "uploader_key" not in st.session_state:
    st.session_state["uploader_key"] = 0

# File uploader
uploaded_file = st.file_uploader(
    "Choose an image file",
    type=["jpg", "jpeg", "png"],
    key=f"file_uploader_{st.session_state['uploader_key']}",
)

# If a file has been uploaded, process and predict
if uploaded_file is not None:
    try:
        # convert("RGB") decodes the whole file (so corrupt uploads fail here)
        # and turns RGBA / grayscale / palette PNGs into the 3 channels that the
        # 3-value Normalize step of the preprocessing pipeline requires.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:  # also covers PIL.UnidentifiedImageError
        image = None
        st.error("The uploaded file could not be read as an image. Please try another file.")

    if image is not None:
        st.image(image, caption="Uploaded Image", use_container_width=True)

        # Preprocess the image into a batch of one on the model's device
        input_image = transform(image).unsqueeze(0).to(device)

        # Perform prediction (no autograd bookkeeping needed for inference)
        with torch.no_grad():
            outputs = model(input_image)
        predicted_label = CLASS_NAMES[outputs.argmax(dim=1).item()]

        # Display the result
        st.write("### Prediction Result:")
        if predicted_label == "Streptococcosis":
            st.error(f"The sample is predicted as **{predicted_label}**")
        else:
            st.success(f"The sample is predicted as **{predicted_label}**")

    # Button to reset the file uploader: a new key yields a fresh, empty widget.
    if st.button("Upload Another Image"):
        st.session_state["uploader_key"] += 1
        st.rerun()