charenji / decoder.py
takarajordan's picture
Upload 2 files
1a61279 verified
raw
history blame
939 Bytes
import torch
from torchvision import transforms
from PIL import Image
# Import your model class and extraction functions
from train import SteganographyNet, extract_message, get_device
# safetensors is an optional dependency: it is only required when the
# checkpoint path below ends in '.safetensors'.
try:
    from safetensors.torch import load_file as load_safetensors
except ImportError:
    # Keep the name bound so a missing package surfaces later as an explicit
    # error at the point of use instead of an opaque NameError.
    load_safetensors = None
    print("safetensors not installed. Run: pip install safetensors")

# Build the network and move it to the device chosen by train.get_device().
device = get_device()
# Original author's note: message_length doesn't matter for extraction.
# NOTE(review): not verifiable from this file -- confirm against train.py.
model = SteganographyNet(message_length=1024).to(device)

# Load model weights based on file extension.
model_path = 'model.safetensors'  # or 'stego_model_3.safetensors'
if model_path.endswith('.safetensors'):
    if load_safetensors is None:
        raise ImportError(
            f"Cannot load {model_path!r}: the 'safetensors' package is not "
            "installed. Run: pip install safetensors"
        )
    state_dict = load_safetensors(model_path)
else:
    # map_location="cpu" lets a checkpoint saved on an accelerator be read on
    # a machine without one; load_state_dict() below copies the values into
    # the model's existing parameters.
    # SECURITY: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location="cpu")
model.load_state_dict(state_dict)

# Switch to evaluation mode before running extraction.
model.eval()

# Test extraction: run the helper from train.py on a sample image and print
# whatever it returns.
extracted_message = extract_message(model, 'decode_me_3.png')
print(f"Extracted message: {extracted_message}")