File size: 939 Bytes
1a61279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
from torchvision import transforms
from PIL import Image

# Import your model class and extraction functions
from train import SteganographyNet, extract_message, get_device

# Import safetensors if available. It is only required when model_path ends in
# ".safetensors"; bind a sentinel so that case can fail with a clear message
# instead of a NameError.
try:
    from safetensors.torch import load_file as load_safetensors
except ImportError:
    load_safetensors = None
    print("safetensors not installed. Run: pip install safetensors")

# Load the saved model
device = get_device()
# NOTE(review): message_length is assumed not to matter for extraction. The
# load_state_dict call below is strict, so if message_length affects any
# parameter shape in SteganographyNet the load will fail — confirm in train.py.
model = SteganographyNet(message_length=1024).to(device)

# Load model weights based on file extension
model_path = 'model.safetensors'  # or 'stego_model_3.safetensors'
if model_path.endswith('.safetensors'):
    if load_safetensors is None:
        raise ImportError(
            f"Cannot load {model_path!r}: safetensors is not installed. "
            "Run: pip install safetensors"
        )
    state_dict = load_safetensors(model_path)
else:
    # map_location remaps tensors saved on another device (e.g. CUDA) so the
    # checkpoint also loads on CPU-only machines; weights_only restricts
    # unpickling to tensors/containers, since only a state dict is expected.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)

# Switch layers such as dropout/batch-norm to inference behaviour.
model.eval()

# Test extraction
extracted_message = extract_message(model, 'decode_me_3.png')
print(f"Extracted message: {extracted_message}")