gibhug commited on
Commit
d68f082
·
verified ·
1 Parent(s): 0cf8746

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import torch
4
+ from torchvision import transforms
5
+ import torch.nn.functional as F
6
+
7
+ # Load the trained model
8
+ MODEL_PATH = "resnet_model.pth" # Update with your actual model path
9
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+ model = torch.load(MODEL_PATH, map_location=device)
11
+ model.eval()
12
+
13
+ # Define the image transformation pipeline
14
+ transform = transforms.Compose([
15
+ transforms.Resize((224, 224)),
16
+ transforms.ToTensor(),
17
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
18
+ ])
19
+
20
+ # Streamlit UI
21
+ st.title("Saliva Disease Detection App")
22
+ st.subheader("Predict Streptococcosis vs NOT Streptococcosis from uploaded saliva images")
23
+
24
+ # Initialize session state for managing the uploaded file
25
+ if "uploaded_file" not in st.session_state:
26
+ st.session_state["uploaded_file"] = None
27
+
28
+ # File uploader
29
+ uploaded_file = st.file_uploader("Choose an image file", type=["jpg", "jpeg", "png"], key="file_uploader")
30
+
31
+ if uploaded_file is not None:
32
+ st.session_state["uploaded_file"] = uploaded_file
33
+
34
+ # If a file has been uploaded, process and predict
35
+ if st.session_state["uploaded_file"] is not None:
36
+ image = Image.open(st.session_state["uploaded_file"])
37
+ st.image(image, caption="Uploaded Image", use_container_width=True)
38
+
39
+ # Preprocess the image
40
+ input_image = transform(image).unsqueeze(0).to(device)
41
+
42
+ # Perform prediction
43
+ with torch.no_grad():
44
+ outputs = model(input_image)
45
+ probabilities = F.softmax(outputs, dim=1) # Convert to probabilities
46
+ _, predicted_class = torch.max(outputs, 1)
47
+
48
+ # Map predicted class to labels
49
+ class_names = ['Not_Streptococcosis', 'Streptococcosis']
50
+ predicted_label = class_names[predicted_class.item()]
51
+ predicted_probability = probabilities[0][predicted_class.item()].item() * 100 # Convert to percentage
52
+
53
+ # Display the result
54
+ st.write("### Prediction Result:")
55
+ if predicted_label == "Streptococcosis":
56
+ st.error(f"The sample is predicted as **{predicted_label}** with **{predicted_probability:.2f}%** probability.")
57
+ else:
58
+ st.success(f"The sample is predicted as **{predicted_label}** with **{predicted_probability:.2f}%** probability.")
59
+
60
+ # Show probabilities for all classes
61
+ st.write("### Class Probabilities:")
62
+ for idx, class_name in enumerate(class_names):
63
+ st.write(f"- **{class_name}**: {probabilities[0][idx].item() * 100:.2f}%")
64
+
65
+ # Button to reset the file uploader
66
+ if st.button("Upload Another Image"):
67
+ st.session_state["uploaded_file"] = None
68
+ st.rerun()