"""Streamlit app: classify a saliva image as Streptococcal infection or not.

Loads a pickled model from ``MODEL_PATH``, lets the user upload an image,
and shows the predicted class together with every class's softmax
probability.
"""
import streamlit as st
from PIL import Image
import torch
from torchvision import transforms
import torch.nn.functional as F

# Load the trained model
MODEL_PATH = "resnet_model.pth"  # Update with your actual model path
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model output index -> label. Presumably this order must match the label
# encoding used at training time — verify against the training script.
CLASS_NAMES = ['Not_Streptococcal_Infection', 'Streptococcal_Infection']


@st.cache_resource
def load_model(path: str = MODEL_PATH) -> torch.nn.Module:
    """Load the model once per server process and switch it to eval mode.

    ``st.cache_resource`` keeps the model alive across Streamlit reruns;
    otherwise the checkpoint would be re-read from disk on every widget
    interaction.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects, so only
    # load checkpoints from a trusted source. It is needed because the file is
    # loaded as a whole module (``.eval()`` is called on the result), which
    # torch>=2.6's default of weights_only=True refuses to unpickle.
    loaded = torch.load(path, map_location=device, weights_only=False)
    loaded.eval()
    return loaded


model = load_model()

# Define the image transformation pipeline. Must match the preprocessing used
# at training time (the mean/std below are the standard ImageNet statistics).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict(image: Image.Image) -> tuple[str, float, torch.Tensor]:
    """Classify one PIL image.

    Returns:
        (predicted_label, predicted_probability_percent, probabilities), where
        ``probabilities`` is the 1-D softmax vector over ``CLASS_NAMES``.
    """
    # Force 3 channels: PNG uploads can be RGBA / greyscale / palette images,
    # which would not match the 3-channel Normalize above.
    input_image = transform(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(input_image)  # assumes shape (1, len(CLASS_NAMES))
        probabilities = F.softmax(outputs, dim=1)[0]  # Convert to probabilities
    predicted_index = int(torch.argmax(probabilities))
    predicted_probability = probabilities[predicted_index].item() * 100  # percentage
    return CLASS_NAMES[predicted_index], predicted_probability, probabilities


# Streamlit UI
st.title("Saliva Disease Detection App")
st.subheader("Predict Streptococcal infection vs No Streptococcal infection from saliva images")

# Session state:
#  - "uploaded_file": the file currently being analysed (None when cleared).
#  - "uploader_key": suffix of the uploader's widget key. Streamlit does not
#    allow assigning a file_uploader's value through session state, so the
#    reset button bumps this counter to get a fresh, empty uploader widget.
if "uploaded_file" not in st.session_state:
    st.session_state["uploaded_file"] = None
if "uploader_key" not in st.session_state:
    st.session_state["uploader_key"] = 0

# File uploader
uploaded_file = st.file_uploader(
    "Choose an image file",
    type=["jpg", "jpeg", "png"],
    key=f"file_uploader_{st.session_state['uploader_key']}",
)
# Mirror the widget exactly, so removing the file in the widget also clears
# the displayed prediction.
st.session_state["uploaded_file"] = uploaded_file

# If a file has been uploaded, process and predict
if st.session_state["uploaded_file"] is not None:
    image = Image.open(st.session_state["uploaded_file"])
    st.image(image, caption="Uploaded Image", use_container_width=True)

    predicted_label, predicted_probability, probabilities = predict(image)

    # Display the result
    st.write("### Prediction Result:")
    message = (
        f"The sample is predicted as **{predicted_label}** "
        f"with **{predicted_probability:.2f}%** probability."
    )
    if predicted_label == "Streptococcal_Infection":
        st.error(message)
    else:
        st.success(message)

    # Show probabilities for all classes
    st.write("### Class Probabilities:")
    for class_name, probability in zip(CLASS_NAMES, probabilities.tolist()):
        st.write(f"- **{class_name}**: {probability * 100:.2f}%")

    # Button to reset the file uploader
    if st.button("Upload Another Image"):
        st.session_state["uploaded_file"] = None
        st.session_state["uploader_key"] += 1  # new key -> empty uploader
        st.rerun()