|
import streamlit as st
|
|
import cv2
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import joblib
|
|
import tensorflow as tf
|
|
from PIL import Image, ImageDraw
|
|
import os
|
|
|
|
|
|
|
|
class YOLOv5(nn.Module):
    """Small three-stage CNN image classifier.

    Despite the name, this is not a YOLOv5 detector: it is three
    conv(3x3, padding 1) + ReLU + 2x2 max-pool stages (3 -> 16 -> 32 -> 64
    channels) followed by fully connected layers 16384 -> 512 -> 128 ->
    ``num_classes``. Because ``fc1`` expects a 64 x 16 x 16 feature map and
    each pool halves the spatial size, inputs must be 128x128 (128 / 2**3 = 16).

    Args:
        num_classes: Number of output logits. Defaults to 7, the value that
            was previously hard-coded; checkpoints loaded with the default
            constructor must therefore have a 7-way ``fc3``.
    """

    def __init__(self, num_classes=7):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        # One parameter-free pooling module, reused after every conv stage.
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 16 * 16, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return raw logits of shape (N, num_classes) for x of shape (N, 3, 128, 128)."""
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = self.pool(torch.relu(self.conv3(x)))
        # Flatten per sample so the batch dimension N is preserved. A
        # reshape(-1, 64 * 16 * 16) would silently re-batch inputs of the
        # wrong spatial size; with flatten, fc1 raises a shape error instead.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)
|
|
|
|
@st.cache_resource
def _load_yolo_model():
    """Build the ``YOLOv5`` classifier and load its weights from ``yolo_model.pth``.

    Cached with ``st.cache_resource``: Streamlit re-executes the whole script
    on every widget interaction, and without caching the checkpoint would be
    re-read from disk on each rerun.
    """
    model = YOLOv5()
    # Nothing in this app moves models to a GPU, so map tensors to CPU; this
    # also lets checkpoints saved from CUDA tensors load on CPU-only hosts.
    # torch.load unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(torch.load("yolo_model.pth", map_location="cpu"))
    model.eval()
    return model


@st.cache_resource
def _load_cnn_model():
    """Load the Keras CNN from ``cnn_model.h5`` once and reuse it across reruns."""
    return tf.keras.models.load_model("cnn_model.h5")


@st.cache_resource
def _load_elastic_net_model():
    """Load the joblib-serialized ElasticNet model once and reuse it across reruns.

    joblib.load unpickles the file -- only load trusted model files.
    """
    return joblib.load("elastic_net_model.joblib")


yolo_model = _load_yolo_model()
cnn_model = _load_cnn_model()
elastic_net_model = _load_elastic_net_model()
|
|
|
|
|
|
class HybridYOLOCNN(nn.Module):
    """Small three-stage CNN image classifier used as the "hybrid" model.

    Its layer layout is the same as ``YOLOv5`` in this file: three
    conv(3x3, padding 1) + ReLU + 2x2 max-pool stages (3 -> 16 -> 32 -> 64
    channels) followed by fully connected layers 16384 -> 512 -> 128 ->
    ``num_classes``. The 64 x 16 x 16 input of ``fc1`` fixes the input
    resolution at 128x128 (128 / 2**3 = 16). It is kept as a separate class so
    each checkpoint maps to its own model type.

    Args:
        num_classes: Number of output logits. Defaults to 7, the value that
            was previously hard-coded; checkpoints loaded with the default
            constructor must therefore have a 7-way ``fc3``.
    """

    def __init__(self, num_classes=7):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        # One parameter-free pooling module, reused after every conv stage.
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 16 * 16, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return raw logits of shape (N, num_classes) for x of shape (N, 3, 128, 128)."""
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = self.pool(torch.relu(self.conv3(x)))
        # Flatten per sample so the batch dimension N is preserved. A
        # reshape(-1, 64 * 16 * 16) would silently re-batch inputs of the
        # wrong spatial size; with flatten, fc1 raises a shape error instead.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)
|
|
|
|
@st.cache_resource
def _load_hybrid_model():
    """Build ``HybridYOLOCNN`` and load its weights from ``hybrid_yolo_cnn_model.pth``.

    Cached with ``st.cache_resource``: Streamlit re-executes the whole script
    on every widget interaction, and without caching the checkpoint would be
    re-read from disk on each rerun.
    """
    model = HybridYOLOCNN()
    # Nothing in this app moves models to a GPU, so map tensors to CPU; this
    # also lets checkpoints saved from CUDA tensors load on CPU-only hosts.
    # torch.load unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(torch.load("hybrid_yolo_cnn_model.pth", map_location="cpu"))
    model.eval()
    return model


hybrid_model = _load_hybrid_model()
|
|
|
|
|
|
def predict_with_box(img, model_type="yolo"):
    """Classify ``img`` with one of the loaded models and draw a box on a copy.

    Args:
        img: RGB image as a NumPy array in HWC layout. The caller in this file
            passes ``np.array`` of a 128x128 RGB PIL image, which matches the
            128x128 input the torch models' ``fc1`` size requires.
        model_type: One of ``"yolo"``, ``"hybrid"``, ``"cnn"`` or
            ``"elastic_net"``.

    Returns:
        ``(img_with_box, pred_class)``: a new PIL image with a red square
        outline, and the predicted class index as a Python ``int``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.

    Note:
        The box is not produced by any model. Its centre is a fixed per-model
        constant and its half-size is 20 px, so it only marks a nominal region.
    """
    if model_type in ("yolo", "hybrid"):
        torch_model = yolo_model if model_type == "yolo" else hybrid_model
        # HWC -> NCHW float tensor. Unlike the "cnn" branch, pixel values are
        # not scaled to [0, 1]. NOTE(review): confirm this matches how the
        # torch checkpoints were trained.
        img_tensor = torch.tensor(img).unsqueeze(0).permute(0, 3, 1, 2).float()
        # Inference only: don't build an autograd graph.
        with torch.no_grad():
            output = torch_model(img_tensor)
        pred_class = int(torch.argmax(output, dim=1).item())
        x_center, y_center = 64, 64
    elif model_type == "cnn":
        img_array = np.expand_dims(img, axis=0) / 255.0
        # verbose=0 suppresses Keras' per-call progress bar in the server log.
        output = cnn_model.predict(img_array, verbose=0)
        pred_class = int(np.argmax(output, axis=1)[0])
        x_center, y_center = 64, 64
    elif model_type == "elastic_net":
        img_flattened = img.reshape(1, -1)
        # The regressor returns a continuous score; round it and clamp to 0..3.
        # NOTE(review): the torch models emit 7 logits but this clamps to four
        # labels -- confirm the intended class range.
        raw_score = elastic_net_model.predict(img_flattened)
        pred_class = int(np.clip(np.round(raw_score), 0, 3)[0])
        x_center, y_center = 84, 84
    else:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of "
            "'yolo', 'hybrid', 'cnn', 'elastic_net'"
        )

    img_with_box = Image.fromarray(img)
    draw = ImageDraw.Draw(img_with_box)
    box_size = 20
    box = (x_center - box_size, y_center - box_size, x_center + box_size, y_center + box_size)
    draw.rectangle(box, outline="red", width=3)

    return img_with_box, pred_class
|
|
|
|
|
|
# Maximum number of uploaded images processed per run.
MAX_IMAGES = 10
# The torch models require 128x128 input (their fc1 expects a 64x16x16 map
# after three 2x poolings), so every upload is resized to this size.
INPUT_SIZE = (128, 128)
# (caption prefix, model_type for predict_with_box), one output column each.
MODEL_COLUMNS = (
    ("YOLOv5 Prediction", "yolo"),
    ("CNN Prediction", "cnn"),
    ("ElasticNet Prediction", "elastic_net"),
    ("Hybrid YOLO-CNN Prediction", "hybrid"),
)

st.title("Skin Lesion Classification with Multiple Models")
st.write(f"Upload up to {MAX_IMAGES} images, and get predictions from each model with highlighted areas.")

uploaded_files = st.file_uploader("Choose images", accept_multiple_files=True, type=["jpg", "jpeg", "png"])

if uploaded_files:
    if len(uploaded_files) > MAX_IMAGES:
        # Tell the user instead of silently ignoring the extra files.
        st.warning(
            f"{len(uploaded_files)} images uploaded; only the first {MAX_IMAGES} will be processed."
        )

    for file in uploaded_files[:MAX_IMAGES]:
        try:
            # `with` closes the underlying file handle; convert()/resize()
            # return new images, so `img` stays valid after the block.
            with Image.open(file) as raw_img:
                img = raw_img.convert("RGB").resize(INPUT_SIZE)
        except (OSError, Image.DecompressionBombError) as err:
            # Corrupt/unsupported uploads raise UnidentifiedImageError (an
            # OSError subclass); report and skip rather than crash the page.
            st.error(f"Could not read {file.name}: {err}")
            continue
        img_np = np.array(img)

        st.image(img, caption="Uploaded Image", use_column_width=True)

        columns = st.columns(len(MODEL_COLUMNS))
        for column, (label, model_type) in zip(columns, MODEL_COLUMNS):
            annotated_img, pred = predict_with_box(img_np, model_type=model_type)
            with column:
                st.image(annotated_img, caption=f"{label}: {pred}", use_column_width=True)
|
|
|