|
import torch |
|
from torchvision import transforms |
|
from PIL import Image |
|
|
|
|
|
from train import SteganographyNet, extract_message, get_device |
|
|
|
|
|
# safetensors is optional: it is only required when model_path points at a
# '.safetensors' file. Bind the name to None when it is missing so the load
# step below can raise an explicit, actionable error instead of a NameError.
try:
    from safetensors.torch import load_file as load_safetensors
except ImportError:
    load_safetensors = None
    print("safetensors not installed. Run: pip install safetensors")


device = get_device()

# NOTE(review): message_length presumably has to match the value used when
# the checkpoint was trained, otherwise load_state_dict will report shape
# mismatches — confirm against train.py.
model = SteganographyNet(message_length=1024).to(device)

model_path = 'model.safetensors'

if model_path.endswith('.safetensors'):
    if load_safetensors is None:
        raise ImportError(
            "safetensors is required to load "
            f"{model_path!r}. Run: pip install safetensors"
        )
    state_dict = load_safetensors(model_path)
else:
    # Deserialize onto the CPU so a checkpoint saved on a GPU also loads on a
    # CPU-only machine; load_state_dict copies the tensors onto the model's
    # device afterwards.
    # SECURITY: torch.load unpickles arbitrary objects — only load trusted
    # checkpoints (consider weights_only=True on torch versions that support it).
    state_dict = torch.load(model_path, map_location="cpu")

model.load_state_dict(state_dict)

# Switch to inference mode (affects layers such as dropout / batch norm).
model.eval()

extracted_message = extract_message(model, 'decode_me_3.png')

print(f"Extracted message: {extracted_message}")
|
|