File size: 939 Bytes
1a61279 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import torch
from torchvision import transforms
from PIL import Image
# Import your model class and extraction functions
from train import SteganographyNet, extract_message, get_device
# Import safetensors if available
try:
    from safetensors.torch import load_file as load_safetensors
except ImportError:
    # Bind a sentinel so a later .safetensors load fails with a clear
    # message instead of a NameError on an undefined name.
    load_safetensors = None
    print("safetensors not installed. Run: pip install safetensors")


def load_weights(model, model_path, device):
    """Load a saved state dict from ``model_path`` into ``model``.

    Paths ending in ``.safetensors`` are read with ``safetensors``; any other
    path is read with ``torch.load``.

    Args:
        model: the ``torch.nn.Module`` whose parameters are overwritten.
        model_path: path to the checkpoint file.
        device: device the checkpoint tensors are mapped onto when using
            ``torch.load`` (same value the model was moved to).

    Returns:
        ``model`` itself, with the weights loaded.

    Raises:
        ImportError: a ``.safetensors`` file was requested but the
            safetensors package is not installed.
    """
    if model_path.endswith('.safetensors'):
        if load_safetensors is None:
            raise ImportError(
                f"Cannot load {model_path!r}: safetensors not installed. "
                "Run: pip install safetensors"
            )
        state_dict = load_safetensors(model_path)
    else:
        # map_location lets a checkpoint saved on one device (e.g. CUDA)
        # load on whatever device get_device() chose (e.g. CPU-only box).
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints; consider weights_only=True on torch >= 1.13.
        state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    return model


# Load the saved model
device = get_device()
# message_length doesn't matter for extraction
model = SteganographyNet(message_length=1024).to(device)

# Load model weights based on file extension
model_path = 'model.safetensors'  # or 'stego_model_3.safetensors'
load_weights(model, model_path, device)
model.eval()

# Test extraction
extracted_message = extract_message(model, 'decode_me_3.png')
print(f"Extracted message: {extracted_message}")
|