import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from PIL import Image
import numpy as np
import torch.backends.mps
from math import exp
import torch.nn.functional as F


class SteganographyNet(nn.Module):
    def __init__(self, message_length):
        super(SteganographyNet, self).__init__()
        self.message_length = message_length

        # Encoder with a skip connection from the cover image
        self.encoder_initial = nn.Sequential(
            nn.Conv2d(4, 64, 3, padding=1),
            nn.GroupNorm(8, 64),
            nn.SiLU(),
        )

        self.encoder_backbone = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1),
            nn.GroupNorm(16, 128),
            nn.SiLU(),
            SEBlock(128),
            nn.Conv2d(128, 128, 3, padding=2, dilation=2),
            nn.GroupNorm(16, 128),
            nn.SiLU(),
            ResidualBlock(128),
            nn.Conv2d(128, 64, 1),
            nn.GroupNorm(8, 64),
            nn.SiLU(),
        )

        self.encoder_final = nn.Sequential(
            nn.Conv2d(64, 3, 3, padding=1),
            nn.Sigmoid()
        )

        # Decoder: recovers the hidden message plane from the stego image
        self.decoder = nn.Sequential(
            # Initial feature extraction
            nn.Conv2d(3, 64, 3, padding=1),
            nn.GroupNorm(8, 64),
            nn.SiLU(),
            # Feature processing
            nn.Conv2d(64, 128, 3, padding=1),
            nn.GroupNorm(16, 128),
            nn.SiLU(),
            SEBlock(128),
            ResidualBlock(128),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.GroupNorm(8, 64),
            nn.SiLU(),
            # Final message extraction
            nn.Conv2d(64, 1, 3, padding=1),
            nn.Sigmoid()
        )

    def encode(self, x):
        # The first 3 channels are the cover image; channel 4 is the message plane
        original_img = x[:, :3, :, :]

        # Process through the encoder
        initial = self.encoder_initial(x)
        processed = self.encoder_backbone(initial)
        output = self.encoder_final(processed)

        # Skip connection keeps the stego image close to the cover image
        return 0.9 * original_img + 0.1 * output

    def forward(self, x):
        # End-to-end pass for training: encode the message, then decode it back
        encoded = self.encode(x)
        decoded = self.decoder(encoded)
        return encoded, decoded


class SEBlock(nn.Module):
    """Squeeze-and-Excitation block: channel-wise attention."""

    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.SiLU(),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.squeeze(x).view(b, c)
        y = self.excitation(y).view(b, c, 1, 1)
        return x * y.expand_as(x)


class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.gn1 = nn.GroupNorm(8, channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.gn2 = nn.GroupNorm(8, channels)
        self.silu = nn.SiLU()

    def forward(self, x):
        residual = x
        out = self.silu(self.gn1(self.conv1(x)))
        out = self.gn2(self.conv2(out))
        out += residual
        return self.silu(out)


class SSIM(nn.Module):
    """Structural Similarity (SSIM) metric, used here as a perceptual loss term."""

    def __init__(self, window_size=11, size_average=True, channel=3):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = channel
        self.window = self.create_window(window_size, channel)

    def gaussian(self, window_size, sigma):
        gauss = torch.Tensor([
            exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2))
            for x in range(window_size)
        ])
        return gauss / gauss.sum()

    def create_window(self, window_size, channel):
        _1D_window = self.gaussian(window_size, 1.5).unsqueeze(1)
        _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
        window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
        return window

    def ssim(self, img1, img2, window, size_average=True):
        mu1 = F.conv2d(img1, window, padding=self.window_size // 2, groups=self.channel)
        mu2 = F.conv2d(img2, window, padding=self.window_size // 2, groups=self.channel)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        sigma1_sq = F.conv2d(img1 * img1, window, padding=self.window_size // 2, groups=self.channel) - mu1_sq
        sigma2_sq = F.conv2d(img2 * img2, window, padding=self.window_size // 2, groups=self.channel) - mu2_sq
        sigma12 = F.conv2d(img1 * img2, window, padding=self.window_size // 2, groups=self.channel) - mu1_mu2

        C1 = 0.01 ** 2
        C2 = 0.03 ** 2

        ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
                   ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

        if size_average:
            return ssim_map.mean()
        else:
            return ssim_map.mean(1).mean(1).mean(1)

    def forward(self, img1, img2):
        # Make sure the window is on the same device as the input
        window = self.window.to(img1.device)
        return self.ssim(img1, img2, window, self.size_average)


def get_device():
    if torch.backends.mps.is_available():
        return torch.device("mps")
    elif torch.cuda.is_available():
        return torch.device("cuda")
    else:
        return torch.device("cpu")


def text_to_binary_tensor(text, height, width):
    """Convert text to a binary message tensor of shape (1, height, width)."""
    # Convert text to UTF-8 bytes, then to a bit string
    binary = ''.join(format(byte, '08b') for byte in text.encode('utf-8'))
    # Pad the bit string with zeros to fill the image plane
    # (the message must fit, i.e. len(binary) <= height * width)
    binary = binary + '0' * (height * width - len(binary))
    binary_array = np.array([int(b) for b in binary]).reshape(1, height, width)
    return torch.FloatTensor(binary_array)


def binary_tensor_to_text(tensor):
    """Convert a binary tensor back to text."""
    # Threshold the tensor values to get clear 0s and 1s
    binary = ''.join([str(int(round(float(b)))) for b in tensor.flatten()])

    # Process in 8-bit chunks without reading past the end
    message = ''
    for i in range(0, len(binary) - 7, 8):
        byte = binary[i:i + 8]
        try:
            char = chr(int(byte, 2))
            if ord(char) == 0:  # stop at null terminator
                break
            message += char
        except ValueError:
            continue  # skip invalid bytes
    return message


def embed_message(model, image_path, message, output_path):
    """Embed a message into an image using the trained model."""
    device = get_device()

    # Load and preprocess the cover image (512x512)
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor()
    ])
    img = transform(Image.open(image_path).convert('RGB')).unsqueeze(0).to(device)

    # Prepare the message plane (512x512)
    msg_tensor = text_to_binary_tensor(message, 512, 512).to(device)
    msg_tensor = msg_tensor.unsqueeze(0)

    # Concatenate image and message along the channel dimension
    x = torch.cat([img, msg_tensor], dim=1)

    # Generate the stego image
    model.eval()
    with torch.no_grad():
        stego_img = model.encode(x)

    # Save as PNG (lossless, so the embedded signal is preserved)
    stego_img = stego_img.squeeze(0).cpu()
    transforms.ToPILImage()(stego_img).save(output_path, 'PNG')
    return True


def extract_message(model, image_path):
    """Extract the hidden message from an image using the trained model."""
    device = get_device()

    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor()
    ])
    stego_img = transform(Image.open(image_path).convert('RGB')).unsqueeze(0).to(device)

    # Recover the message plane
    model.eval()
    with torch.no_grad():
        msg_tensor = model.decoder(stego_img)
        # Threshold the values to hard 0s and 1s
        msg_tensor = (msg_tensor > 0.5).float()

    # Convert to text with error handling
    try:
        binary = msg_tensor.cpu().numpy().flatten()
        binary_str = ''.join(['1' if b > 0.5 else '0' for b in binary])

        # Process in 8-bit chunks until we hit a null terminator
        bytes_data = bytearray()
        for i in range(0, len(binary_str) - 7, 8):
            byte = binary_str[i:i + 8]
            byte_val = int(byte, 2)
            if byte_val == 0:  # stop at null terminator
                break
            bytes_data.append(byte_val)

        # Decode with explicit UTF-8 handling
        message = bytes_data.decode('utf-8', errors='ignore')
        # Clean up any trailing null characters
        message = message.split('\x00')[0]
    except Exception as e:
        print(f"Error during message extraction: {e}")
        message = ""

    return message


def train_model(image_path, message, epochs=600):
    """Train the steganography model on a single image/message pair."""
    device = get_device()
    model = SteganographyNet(len(message) * 8).to(device)

    # AdamW optimizer with weight decay
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
    # Cosine annealing learning-rate schedule
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, eta_min=1e-6)

    # Loss combination: pixel-wise MSE plus a structural similarity term
    mse_loss = nn.MSELoss()
    ssim_loss = SSIM().to(device)

    # Prepare data (512x512)
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor()
    ])
    img = transform(Image.open(image_path).convert('RGB')).unsqueeze(0).to(device)
    msg_tensor = text_to_binary_tensor(message, 512, 512).to(device)
    msg_tensor = msg_tensor.unsqueeze(0)

    # Training loop
    for epoch in range(epochs):
        # Forward pass
        x = torch.cat([img, msg_tensor], dim=1)
        stego_img = model.encode(x)
        recovered_msg = model.decoder(stego_img)

        # Image loss with a perceptual (SSIM) component
        image_loss = 0.95 * mse_loss(stego_img, img) + 0.05 * (1 - ssim_loss(stego_img, img))
        message_loss = mse_loss(recovered_msg, msg_tensor)

        # Ramp alpha up over the first 40% of training, capped at 0.3,
        # so image quality is prioritized over message recovery early on
        alpha = min(epoch / (epochs * 0.4), 0.3)
        total_loss = (1 - alpha) * image_loss + (alpha * 5) * message_loss

        # Backward pass
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        scheduler.step()

        if (epoch + 1) % 100 == 0:
            print(f'Epoch [{epoch + 1}/{epochs}], Loss: {total_loss.item():.4f}')

    return model


# Example usage
if __name__ == "__main__":
    input_image = "steno_2(1).jpg"
    output_image = "decode_me_3.png"
    secret_message = ""  # text to hide goes here

    # Train model
    model = train_model(input_image, secret_message)

    # Save model weights
    torch.save(model.state_dict(), 'stego_model_3.pth')

    # Embed message
    embed_message(model, input_image, secret_message, output_image)
    print("Message embedded successfully!")

    # Extract message
    extracted_message = extract_message(model, output_image)
    print(f"Extracted message: {extracted_message}")
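# --- Optional: reloading the saved weights for later extraction (sketch) ---
# This helper is not part of the original pipeline; it is a minimal sketch of how
# the checkpoint saved above ('stego_model_3.pth') could be reloaded to run
# extraction without retraining. The default paths match the main block above;
# message_length only sets an unused attribute, so 0 is sufficient here.
def load_model_for_extraction(weights_path='stego_model_3.pth', message_length=0):
    device = get_device()
    model = SteganographyNet(message_length).to(device)
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()
    return model

# Example usage (commented out so the script above runs unchanged):
# reloaded = load_model_for_extraction()
# print(extract_message(reloaded, 'decode_me_3.png'))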