"""Load a trained SteganographyNet checkpoint and extract a hidden message from an image."""
import torch
from torchvision import transforms
from PIL import Image

# Import your model class and extraction functions
from train import SteganographyNet, extract_message, get_device

# safetensors is optional: it is only required when the checkpoint path ends in
# ".safetensors". Bind a sentinel on failure so the error surfaces clearly at
# load time instead of as a NameError.
try:
    from safetensors.torch import load_file as load_safetensors
except ImportError:
    load_safetensors = None
    print("safetensors not installed. Run: pip install safetensors")


def _load_weights(model, model_path, device):
    """Load a state dict from *model_path* into *model* and return *model*.

    Files ending in ``.safetensors`` are read with ``safetensors.torch.load_file``.
    Any other path is read with ``torch.load``, which maps tensors onto *device*.

    Raises:
        ImportError: if a ``.safetensors`` path is given but safetensors is not installed.
    """
    if model_path.endswith('.safetensors'):
        if load_safetensors is None:
            raise ImportError(
                f"Cannot load {model_path!r}: safetensors is not installed. "
                "Run: pip install safetensors"
            )
        # load_file returns CPU tensors by default; load_state_dict copies them
        # into the model's parameters, which already live on `device`.
        state_dict = load_safetensors(model_path)
    else:
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        # weights_only=True avoids unpickling arbitrary objects, since only a
        # state dict is expected here.
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    return model


# Load the saved model
device = get_device()
# The original author notes that message_length doesn't matter for extraction.
# NOTE(review): if message_length affects layer shapes, load_state_dict will raise
# on a size mismatch, so confirm it matches the trained checkpoint.
model = SteganographyNet(message_length=1024).to(device)

# Load model weights based on file extension
model_path = 'model.safetensors'  # or 'stego_model_3.safetensors'
_load_weights(model, model_path, device)
model.eval()

# Test extraction. Inference only, so skip autograd bookkeeping.
with torch.no_grad():
    extracted_message = extract_message(model, 'decode_me_3.png')
print(f"Extracted message: {extracted_message}")