"""Flask service that classifies uploaded images with a small CNN.

The CNN architecture is defined in this module. Its trained weights are read
from the checkpoint at ``MODEL_PATH`` once, on the first prediction request,
and reused for every later request.
"""

import functools
import io
import os
import re
from datetime import datetime

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from flask import Flask, request, send_file
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder

app = Flask(__name__)

# Checkpoint produced by training; must contain 'model_state_dict'.
MODEL_PATH = "model_cnn_final.pth"
# Directory where uploaded images are kept; created on demand.
UPLOAD_DIR = "uploads"
# Index order must match the class indices used at training time
# (ImageFolder assigns them in sorted directory-name order) — TODO confirm.
# NOTE(review): "no- cancer" looks like a typo, but it is returned to clients
# verbatim, so it is left unchanged.
CLASS_NAMES = ['cancer', 'no- cancer']

# Preprocessing applied to every input image. The 22x22 size is what the
# fully connected head was sized for: the five conv/pool stages shrink
# 22x22 down to 3x3 (see CNN.fc1).
transform = transforms.Compose([
    transforms.Resize((22, 22)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


class CNN_Stage3(nn.Module):
    """Dilated 5x5 conv -> ReLU -> 2x2 max-pool (stride 1) -> ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5,
                               stride=1, dilation=2, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        # Redundant (max-pooling non-negative values stays non-negative), but
        # kept so the module matches the architecture the weights were trained on.
        x = self.relu(x)
        return x


class CNN_Stage1(nn.Module):
    """5x5 conv -> ReLU -> 2x2 max-pool (stride 1) -> ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5,
                               stride=1, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        # Redundant second ReLU; kept to mirror the trained architecture.
        x = self.relu(x)
        return x


class CNN(nn.Module):
    """Five conv stages (3->6->12->24->48->96 channels) plus a 2-layer head.

    Expects a batch of 3-channel 22x22 images, so the last feature map is
    96x3x3 before flattening into ``fc1``.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.cnn_stage_1 = CNN_Stage1(3, 6)
        self.cnn_stage_2 = CNN_Stage1(6, 12)
        self.cnn_stage_3 = CNN_Stage3(12, 24)
        self.cnn_stage_4 = CNN_Stage1(24, 48)
        self.cnn_stage_5 = CNN_Stage1(48, 96)
        self.fc1 = nn.Linear(96 * 3 * 3, 64)
        self.fc2 = nn.Linear(64, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.cnn_stage_1(x)
        x = self.cnn_stage_2(x)
        x = self.cnn_stage_3(x)
        x = self.cnn_stage_4(x)
        x = self.cnn_stage_5(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


class CustomDataset(Dataset):
    """Thin wrapper around ImageFolder exposing its ``classes`` list.

    Not used for serving; kept for the training pipeline.
    """

    def __init__(self, root_dir, transform=None):
        self.dataset = ImageFolder(root_dir, transform=transform)
        self.classes = self.dataset.classes

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        return image, label


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the CNN, load trained weights from MODEL_PATH, and cache it.

    The class count comes from the checkpoint's ``fc2`` weight matrix, so the
    training dataset does not have to exist on the serving host. The
    checkpoint is mapped to CPU because inference here runs on CPU tensors.
    """
    checkpoint = torch.load(MODEL_PATH, map_location="cpu")
    state_dict = checkpoint['model_state_dict']
    num_classes = state_dict['fc2.weight'].shape[0]
    model = CNN(num_classes)
    model.load_state_dict(state_dict)
    model.eval()
    return model


def _safe_filename(filename):
    """Reduce a client-supplied filename to a basename safe to join to a path."""
    # Drop any directory part, whether written with '/' or '\\'.
    name = filename.replace("\\", "/").rsplit("/", 1)[-1]
    # Keep only a conservative character set; strip leading dots so the
    # result can never be '.' / '..' or a hidden file.
    name = re.sub(r"[^A-Za-z0-9._-]", "_", name).lstrip(".")
    return name or "upload"


@app.route('/', methods=['GET'])
def dummy_get():
    """Health-check style landing route."""
    return "Welcome to Flask App"


@app.route('/upload', methods=['POST'])
def upload_file():
    """Save an uploaded image and return the model's predicted class.

    Expects a multipart form field named ``file``. Returns a plain-text
    message. Missing, empty, or non-image uploads get HTTP 400.
    """
    if 'file' not in request.files:
        return 'No file part', 400
    file = request.files['file']
    # Browsers submit an empty part with no filename when nothing was chosen;
    # reject it before touching the disk or the model.
    if not file.filename:
        return 'No selected file', 400

    # Microsecond timestamp keeps same-named uploads in the same second apart.
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    timestamp = datetime.now().strftime('%Y%m%d%H%M%S%f')
    unique_filename = f"{timestamp}_{_safe_filename(file.filename)}"
    save_path = os.path.join(UPLOAD_DIR, unique_filename)
    file.save(save_path)

    try:
        # convert() forces a full load, so the file can be closed right away;
        # RGB conversion handles RGBA/greyscale inputs for the 3-channel model.
        with Image.open(save_path) as img:
            input_image = img.convert('RGB')
    except OSError:  # includes PIL.UnidentifiedImageError
        return 'Uploaded file is not a valid image', 400

    input_batch = transform(input_image).unsqueeze(0)
    model = _load_model()
    with torch.no_grad():
        output = model(input_batch)

    _, predicted_class = torch.max(output, 1)
    predicted_label = CLASS_NAMES[predicted_class.item()]
    app.logger.info('The image is classified as: %s', predicted_label)
    return f'The image is classified as: {predicted_label}'


if __name__ == "__main__":
    # The Werkzeug debugger allows arbitrary code execution, so it must not be
    # exposed on 0.0.0.0 by default. Opt in explicitly with FLASK_DEBUG=1.
    app.run(host='0.0.0.0', debug=os.environ.get('FLASK_DEBUG') == '1', port=5000)