import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from PIL import Image
import numpy as np
import torch.backends.mps
from math import exp
import torch.nn.functional as F
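# train.py: image steganography with a small encoder-decoder CNN.
# The encoder embeds a binary text message into a cover image as a low-amplitude
# residual (0.9 * cover + 0.1 * learned output); the decoder recovers the message
# plane from the stego image.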
class SteganographyNet(nn.Module):
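    """Encoder-decoder network: the encoder hides a one-bit-per-pixel message
    plane inside a 3-channel cover image; the decoder recovers that plane."""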
def __init__(self, message_length):
super(SteganographyNet, self).__init__()
self.message_length = message_length
# Modified encoder with skip connection
self.encoder_initial = nn.Sequential(
nn.Conv2d(4, 64, 3, padding=1),
nn.GroupNorm(8, 64),
nn.SiLU(),
)
self.encoder_backbone = nn.Sequential(
nn.Conv2d(64, 128, 3, padding=1),
nn.GroupNorm(16, 128),
nn.SiLU(),
SEBlock(128),
nn.Conv2d(128, 128, 3, padding=2, dilation=2),
nn.GroupNorm(16, 128),
nn.SiLU(),
ResidualBlock(128),
nn.Conv2d(128, 64, 1),
nn.GroupNorm(8, 64),
nn.SiLU(),
)
self.encoder_final = nn.Sequential(
nn.Conv2d(64, 3, 3, padding=1),
nn.Sigmoid()
)
# Add decoder
self.decoder = nn.Sequential(
# Initial feature extraction
nn.Conv2d(3, 64, 3, padding=1),
nn.GroupNorm(8, 64),
nn.SiLU(),
# Feature processing
nn.Conv2d(64, 128, 3, padding=1),
nn.GroupNorm(16, 128),
nn.SiLU(),
SEBlock(128),
ResidualBlock(128),
nn.Conv2d(128, 64, 3, padding=1),
nn.GroupNorm(8, 64),
nn.SiLU(),
# Final message extraction
nn.Conv2d(64, 1, 3, padding=1),
nn.Sigmoid()
)
def encode(self, x):
# Extract original image
original_img = x[:, :3, :, :]
# Process through encoder
initial = self.encoder_initial(x)
processed = self.encoder_backbone(initial)
output = self.encoder_final(processed)
# Add skip connection from input image
return 0.9 * original_img + 0.1 * output
def forward(self, x):
# This can be used for end-to-end training
encoded = self.encode(x)
decoded = self.decoder(encoded)
return encoded, decoded
# Building blocks used by the encoder and decoder above
class SEBlock(nn.Module):
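    """Squeeze-and-Excitation block: channel-wise attention from globally pooled features."""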
def __init__(self, channels, reduction=16):
super(SEBlock, self).__init__()
self.squeeze = nn.AdaptiveAvgPool2d(1)
self.excitation = nn.Sequential(
nn.Linear(channels, channels // reduction, bias=False),
nn.SiLU(),
nn.Linear(channels // reduction, channels, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.squeeze(x).view(b, c)
y = self.excitation(y).view(b, c, 1, 1)
return x * y.expand_as(x)
class ResidualBlock(nn.Module):
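    """Two-convolution residual block with GroupNorm and SiLU activations."""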
def __init__(self, channels):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
self.gn1 = nn.GroupNorm(8, channels)
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
self.gn2 = nn.GroupNorm(8, channels)
self.silu = nn.SiLU()
def forward(self, x):
residual = x
out = self.silu(self.gn1(self.conv1(x)))
out = self.gn2(self.conv2(out))
out += residual
return self.silu(out)
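# Differentiable SSIM metric, used as a perceptual term in the image loss.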
class SSIM(nn.Module):
def __init__(self, window_size=11, size_average=True, channel=3):
super(SSIM, self).__init__()
self.window_size = window_size
self.size_average = size_average
self.channel = channel
self.window = self.create_window(window_size, channel)
def gaussian(self, window_size, sigma):
gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
return gauss/gauss.sum()
def create_window(self, window_size, channel):
_1D_window = self.gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
return window
def ssim(self, img1, img2, window, size_average=True):
mu1 = F.conv2d(img1, window, padding=self.window_size//2, groups=self.channel)
mu2 = F.conv2d(img2, window, padding=self.window_size//2, groups=self.channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv2d(img1*img1, window, padding=self.window_size//2, groups=self.channel) - mu1_sq
sigma2_sq = F.conv2d(img2*img2, window, padding=self.window_size//2, groups=self.channel) - mu2_sq
sigma12 = F.conv2d(img1*img2, window, padding=self.window_size//2, groups=self.channel) - mu1_mu2
C1 = 0.01**2
C2 = 0.03**2
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
if size_average:
return ssim_map.mean()
else:
return ssim_map.mean(1).mean(1).mean(1)
def forward(self, img1, img2):
# Make sure window is on the same device as input
window = self.window.to(img1.device)
return self.ssim(img1, img2, window, self.size_average)
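# Prefer Apple MPS, then CUDA, then fall back to CPU.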
def get_device():
if torch.backends.mps.is_available():
return torch.device("mps")
elif torch.cuda.is_available():
return torch.device("cuda")
else:
return torch.device("cpu")
def text_to_binary_tensor(text, height, width):
    """Convert text to a binary (0/1) tensor of shape (1, height, width)."""
    # Convert text to UTF-8 bytes, then to a bit string
    binary = ''.join(format(byte, '08b') for byte in text.encode('utf-8'))
    if len(binary) > height * width:
        raise ValueError("Message is too long to fit in the image")
    # Zero-pad to fill the plane; the trailing zero bytes act as a null terminator
    binary = binary + '0' * (height * width - len(binary))
    binary_array = np.array([int(b) for b in binary]).reshape(1, height, width)
    return torch.FloatTensor(binary_array)
def binary_tensor_to_text(tensor):
    """Convert a binary tensor back to text (UTF-8, stopping at a null byte)."""
    # Threshold the tensor values to get clear 0s and 1s
    binary = ''.join([str(int(round(float(b)))) for b in tensor.flatten()])
    # Collect whole bytes until the null terminator, then decode as UTF-8
    bytes_data = bytearray()
    for i in range(0, len(binary) - 7, 8):
        byte_val = int(binary[i:i+8], 2)
        if byte_val == 0:  # stop at null terminator
            break
        bytes_data.append(byte_val)
    return bytes_data.decode('utf-8', errors='ignore')
def embed_message(model, image_path, message, output_path):
"""Embed a message into an image using the trained model"""
device = get_device()
# Load and preprocess image (now using 512x512)
transform = transforms.Compose([
transforms.Resize((512, 512)),
transforms.ToTensor()
])
    img = transform(Image.open(image_path).convert('RGB')).unsqueeze(0).to(device)  # force 3 channels
# Prepare message (now using 512x512)
msg_tensor = text_to_binary_tensor(message, 512, 512).to(device)
msg_tensor = msg_tensor.unsqueeze(0)
# Concatenate image and message
x = torch.cat([img, msg_tensor], dim=1)
# Generate stego image
model.eval()
with torch.no_grad():
stego_img = model.encode(x)
# Save image
stego_img = stego_img.squeeze(0).cpu()
transforms.ToPILImage()(stego_img).save(output_path, 'PNG')
return True
def extract_message(model, image_path):
"""Extract hidden message from image using the trained model"""
device = get_device()
transform = transforms.Compose([
transforms.Resize((512, 512)),
transforms.ToTensor()
])
    stego_img = transform(Image.open(image_path).convert('RGB')).unsqueeze(0).to(device)  # force 3 channels
# Extract message
model.eval()
with torch.no_grad():
msg_tensor = model.decoder(stego_img)
# Threshold the values more aggressively
msg_tensor = (msg_tensor > 0.5).float()
# Convert to text with better error handling
try:
# Convert binary tensor to bytes
binary = msg_tensor.cpu().numpy().flatten()
binary_str = ''.join(['1' if b > 0.5 else '0' for b in binary])
# Process in chunks until we hit invalid UTF-8 or null terminator
bytes_data = bytearray()
for i in range(0, len(binary_str) - 7, 8):
byte = binary_str[i:i+8]
byte_val = int(byte, 2)
if byte_val == 0: # Stop at null terminator
break
bytes_data.append(byte_val)
# Decode with explicit UTF-8 handling
message = bytes_data.decode('utf-8', errors='ignore')
# Clean up any trailing null characters
message = message.split('\x00')[0]
except Exception as e:
print(f"Error during message extraction: {e}")
message = ""
return message
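# The network is trained (effectively overfit) on a single cover image / message pair.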
def train_model(image_path, message, epochs=600):
"""Train the steganography model"""
device = get_device()
model = SteganographyNet(len(message) * 8).to(device)
# Use modern optimizer with weight decay
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
# Use cosine annealing scheduler
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, eta_min=1e-6)
# Use modern loss combination
mse_loss = nn.MSELoss()
ssim_loss = SSIM().to(device) # Structural Similarity Loss
# Prepare data (now using 512x512)
transform = transforms.Compose([
transforms.Resize((512, 512)),
transforms.ToTensor()
])
    img = transform(Image.open(image_path).convert('RGB')).unsqueeze(0).to(device)  # force 3 channels
msg_tensor = text_to_binary_tensor(message, 512, 512).to(device)
msg_tensor = msg_tensor.unsqueeze(0)
# Training loop
for epoch in range(epochs):
# Forward pass
x = torch.cat([img, msg_tensor], dim=1)
stego_img = model.encode(x)
recovered_msg = model.decoder(stego_img)
# Calculate losses with perceptual components
image_loss = 0.95 * mse_loss(stego_img, img) + 0.05 * (1 - ssim_loss(stego_img, img))
message_loss = mse_loss(recovered_msg, msg_tensor)
# Adjust alpha to prioritize image quality
        alpha = min(epoch / (epochs * 0.4), 0.3)  # message-loss weight, ramped up over training and capped at 0.3
        total_loss = (1 - alpha) * image_loss + (alpha * 5) * message_loss
# Backward pass
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
scheduler.step()
if (epoch + 1) % 100 == 0:
print(f'Epoch [{epoch+1}/{epochs}], Loss: {total_loss.item():.4f}')
return model
# Example usage
if __name__ == "__main__":
input_image = "steno_2(1).jpg"
output_image = "decode_me_3.png"
    secret_message = ""  # message to hide (set before running)
# Train model
model = train_model(input_image, secret_message)
# Save model weights
torch.save(model.state_dict(), 'stego_model_3.pth')
# Embed message
embed_message(model, input_image, secret_message, output_image)
print("Message embedded successfully!")
# Extract message
extracted_message = extract_message(model, output_image)
print(f"Extracted message: {extracted_message}")