import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image


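# A small fully convolutional enhancement network: 3 -> 64 -> 64 -> 3 channels,
# 3x3 convolutions with padding=1 so the spatial size of the input is preserved.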
class ImageEnhancementModel(nn.Module):
    def __init__(self):
        super(ImageEnhancementModel, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))
        x = self.conv3(x)
        return x


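# Dataset of (before, after) image pairs. This assumes the 'before' and 'after'
# directories hold correspondingly named files, and that all images share the
# same spatial size so they can be stacked into batches.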
class CustomDataset(Dataset):
    def __init__(self, before_dir, after_dir):
        self.before_dir = before_dir
        self.after_dir = after_dir
        # Sort so each index maps to the same filename in both directories.
        self.image_files = sorted(os.listdir(before_dir))
        self.transform = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        return len(self.image_files)

    def _load(self, path):
        image = Image.open(path)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return self.transform(image)

    def __getitem__(self, idx):
        name = self.image_files[idx]
        before = self._load(os.path.join(self.before_dir, name))
        after = self._load(os.path.join(self.after_dir, name))
        return before, after


batch_size = 8
learning_rate = 0.001
num_epochs = 50

model = ImageEnhancementModel()

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

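# Paired training data: inputs are read from 'before', targets from 'after';
# shuffling reorders whole pairs, so inputs stay aligned with their targets.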
train_dataset = CustomDataset(before_dir='before', after_dir='after')
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

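# Training loop: minimise the MSE between the model output for each 'before'
# batch and its paired 'after' batch.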
for epoch in range(num_epochs):
    for before_images, after_images in train_loader:
        outputs = model(before_images)
        loss = criterion(outputs, after_images)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

torch.save(model.state_dict(), 'image_enhancement_model.pth')

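# Inference: enhance a single test image with the trained model.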
model.eval()

input_image = Image.open('testb.jpg')
if input_image.mode != 'RGB':
    input_image = input_image.convert('RGB')
input_tensor = train_dataset.transform(input_image).unsqueeze(0)

with torch.no_grad():
    enhanced_image = model(input_tensor)

# ToTensor() maps pixel values to [0, 1], so clamp the network output to that
# range before converting back to an 8-bit image.
output_image = enhanced_image.squeeze(0).clamp(0, 1).permute(1, 2, 0).cpu().numpy()
output_image = (output_image * 255.0).astype('uint8')
Image.fromarray(output_image).save('enhanced_image.jpg')