# Spaces:
# Runtime error
# Runtime error
import functools
import io
import os
from datetime import datetime

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from flask import Flask, request, send_file
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
app = Flask(__name__)


# NOTE(review): this copy had no route decorator, so the view was never
# registered with the app; "/" is assumed -- confirm against clients.
@app.route("/")
def dummy_get():
    """Return a static greeting string."""
    return "Welcome to Flask App"
# Directory that uploaded images are written to before classification.
UPLOAD_DIR = "uploads"
# Checkpoint file; must contain a 'model_state_dict' entry.
CHECKPOINT_PATH = "model_cnn_final.pth"
# Label returned for each output index of the classifier.
# NOTE(review): 'no- cancer' (stray space) is returned verbatim to clients;
# confirm this order matches the class order used during training.
CLASS_NAMES = ["cancer", "no- cancer"]

# Preprocessing applied to every uploaded image. The CNN below only accepts
# 22x22 inputs (fc1 expects 96 * 3 * 3 features after the five stages), and the
# 3-value mean/std requires 3-channel input.
# NOTE(review): presumably identical to the training-time transform -- confirm.
INFERENCE_TRANSFORM = transforms.Compose([
    transforms.Resize((22, 22)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


class CNN_Stage3(nn.Module):
    """Dilated 5x5 conv -> ReLU -> 2x2 stride-1 max-pool -> ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names are state_dict keys; do not rename them or saved
        # checkpoints will no longer load.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5,
                               stride=1, dilation=2, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        return self.relu(x)


class CNN_Stage1(nn.Module):
    """5x5 conv -> ReLU -> 2x2 stride-1 max-pool -> ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names are state_dict keys; keep them stable.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5,
                               stride=1, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        return self.relu(x)


class CNN(nn.Module):
    """Five conv stages followed by a two-layer classifier head.

    Expects input of shape (batch, 3, 22, 22); returns (batch, num_classes)
    raw scores.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.cnn_stage_1 = CNN_Stage1(3, 6)
        self.cnn_stage_2 = CNN_Stage1(6, 12)
        self.cnn_stage_3 = CNN_Stage3(12, 24)
        self.cnn_stage_4 = CNN_Stage1(24, 48)
        self.cnn_stage_5 = CNN_Stage1(48, 96)
        self.fc1 = nn.Linear(96 * 3 * 3, 64)
        self.fc2 = nn.Linear(64, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.cnn_stage_1(x)
        x = self.cnn_stage_2(x)
        x = self.cnn_stage_3(x)
        x = self.cnn_stage_4(x)
        x = self.cnn_stage_5(x)
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the CNN, load trained weights from CHECKPOINT_PATH, and cache it.

    Raises whatever torch.load / load_state_dict raise; failures are not
    cached, so a later request retries.
    """
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
    state_dict = checkpoint["model_state_dict"]
    # Take the number of classes from the weights themselves, so the training
    # dataset directory is not needed at serving time. (A mismatch with the
    # dataset would have failed load_state_dict anyway.)
    num_classes = state_dict["fc2.weight"].shape[0]
    model = CNN(num_classes)
    model.load_state_dict(state_dict)
    model.eval()
    return model


def _save_upload(file):
    """Save an uploaded file under UPLOAD_DIR with a unique name; return its path."""
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    # Drop any client-supplied directory components (either separator) so the
    # upload can never be written outside UPLOAD_DIR.
    base_name = os.path.basename(file.filename.replace("\\", "/"))
    # Microsecond resolution keeps same-named uploads within one second apart.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
    path = os.path.join(UPLOAD_DIR, f"{timestamp}_{base_name}")
    file.save(path)
    return path


# NOTE(review): this copy had no route decorator, so the view was never
# registered; '/upload' (POST) is assumed -- confirm against the client.
@app.route("/upload", methods=["POST"])
def upload_file():
    """Classify an image sent as the multipart form field ``file``.

    Returns a plain-text response:
      * 'No file part' if the request has no ``file`` field;
      * 'No selected file' if the field has an empty filename;
      * ('Invalid image file', 400) if the upload cannot be decoded by PIL;
      * otherwise 'The image is classified as: <label>'.
    """
    if "file" not in request.files:
        return "No file part"
    file = request.files["file"]
    # A browser submits an empty part with an empty filename when nothing was
    # chosen; reject it before touching the filesystem.
    if not file.filename:
        return "No selected file"

    path = _save_upload(file)
    try:
        with Image.open(path) as img:
            # Force 3 channels so grayscale / RGBA / palette uploads match the
            # 3-channel Normalize in INFERENCE_TRANSFORM.
            input_image = img.convert("RGB")
    except OSError:  # includes PIL.UnidentifiedImageError
        return "Invalid image file", 400

    input_batch = INFERENCE_TRANSFORM(input_image).unsqueeze(0)
    model = _load_model()
    with torch.no_grad():
        output = model(input_batch)

    _, predicted_class = torch.max(output, 1)
    predicted_label = CLASS_NAMES[predicted_class.item()]
    print(f'The image is classified as: {predicted_label}')
    return f'The image is classified as: {predicted_label}'
if __name__ == "__main__":
    # The Werkzeug debugger allows arbitrary code execution from the browser,
    # so with host 0.0.0.0 it must be opt-in (FLASK_DEBUG=1/true/yes), not on
    # by default. PORT overrides the listening port (default 5000).
    debug_enabled = os.environ.get("FLASK_DEBUG", "0").strip().lower() in ("1", "true", "yes")
    port = int(os.environ.get("PORT", "5000"))
    app.run(host="0.0.0.0", debug=debug_enabled, port=port)