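"""Streamlit demo for WavMark-style neural audio watermarking.

The app embeds a short payload into a 16 kHz mono WAV file and decodes it back
from a watermarked recording, using the invertible model checkpoint hosted at
https://huggingface.co/M4869/InvertibleWM.
"""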
import time
import streamlit as st
import os
from utils import wm_add_v2, file_reader, model_util, wm_decode_v2, bin_util
from models import my_model_v7_recover
import torch
import uuid
import datetime
import numpy as np
import soundfile
from huggingface_hub import hf_hub_download, HfApi
# Function to add watermark to audio
def add_watermark(audio_path, watermark_text):
    """Embed the given payload into the audio file and return the path of the watermarked WAV."""
    assert len(watermark_text) == 5
    start_bit, msg_bit, watermark = wm_add_v2.create_parcel_message(len_start_bit, 32, watermark_text)
    data, sr, audio_length_second = file_reader.read_as_single_channel_16k(audio_path, 16000)
    _, signal_wmd, time_cost = wm_add_v2.add_watermark(watermark, data, 16000, 0.1, device, model)
    # Write the watermarked signal to a uniquely named temporary file.
    tmp_file_name = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + "_" + str(uuid.uuid4()) + ".wav"
    tmp_file_path = '/tmp/' + tmp_file_name
    soundfile.write(tmp_file_path, signal_wmd, sr)
    return tmp_file_path
# Function to decode watermark from audio
def decode_watermark(audio_path):
    """Extract the payload from the first 5 seconds of the audio and return it as a hex string."""
    data, sr, audio_length_second = file_reader.read_as_single_channel_16k(audio_path, 16000)
    data = data[0:5 * sr]
    start_bit = wm_add_v2.fix_pattern[0:len_start_bit]
    support_count, mean_result, results = wm_decode_v2.extract_watermark_v2(
        data,
        start_bit,
        0.1,
        16000,
        0.3,
        model,
        device, "best")
    if mean_result is None:
        return "No Watermark"
    payload = mean_result[len_start_bit:]
    return bin_util.binArray2HexStr(payload)
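# Payload layout (as implied by the two functions above): the embedded message is
# 32 bits in total; the first len_start_bit (12) bits are a fixed synchronization
# pattern, and the remaining 20 bits are the user payload, shown as 5 hex characters.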
# Main web app
def main():
    if "def_value" not in st.session_state:
        # Pre-fill the watermark field with a random payload (32 - len_start_bit bits, rendered as hex).
        st.session_state.def_value = bin_util.binArray2HexStr(np.random.choice([0, 1], size=32 - len_start_bit))

    st.title("Neural Audio Watermark")
    st.write("Choose the action you want to perform:")

    action = st.selectbox("Select Action", ["Add Watermark", "Decode Watermark"])

    if action == "Add Watermark":
        # st.file_uploader has no per-call size argument; the intended 20 MB upload limit
        # must be enforced via Streamlit's server.maxUploadSize configuration instead.
        audio_file = st.file_uploader("Upload Audio File (WAV)", type=["wav"], accept_multiple_files=False)
        if audio_file:
            # Persist the upload to a temporary file so it can be read from disk.
            tmp_input_audio_file = os.path.join("/tmp/", audio_file.name)
            with open(tmp_input_audio_file, "wb") as f:
                f.write(audio_file.getbuffer())
            st.audio(tmp_input_audio_file, format="audio/wav")

            watermark_text = st.text_input("Enter Watermark Text (5 hex characters)", value=st.session_state.def_value)

            add_watermark_button = st.button("Add Watermark", key="add_watermark_btn")
            if add_watermark_button:  # Runs once the button is clicked.
                if audio_file and watermark_text:
                    with st.spinner("Adding Watermark..."):
                        t1 = time.time()
                        watermarked_audio = add_watermark(tmp_input_audio_file, watermark_text)
                        encode_time_cost = time.time() - t1

                        st.write("Watermarked Audio:")
                        st.audio(watermarked_audio, format="audio/wav")
                        st.write("Time Cost: %d seconds" % encode_time_cost)

    elif action == "Decode Watermark":
        audio_file = st.file_uploader("Upload Audio File (WAV/MP3)", type=["wav", "mp3"], accept_multiple_files=False)
        if audio_file:
            if st.button("Decode Watermark"):
                # 1. Save the uploaded file.
                tmp_file_for_decode_path = os.path.join("/tmp/", audio_file.name)
                with open(tmp_file_for_decode_path, "wb") as f:
                    f.write(audio_file.getbuffer())

                # 2. Run the decoder.
                with st.spinner("Decoding..."):
                    t1 = time.time()
                    decoded_watermark = decode_watermark(tmp_file_for_decode_path)
                    decode_cost = time.time() - t1

                print("decoded_watermark", decoded_watermark)
                # Display the decoded watermark.
                st.write("Decoded Watermark:", decoded_watermark)
                st.write("Time Cost: %d seconds" % decode_cost)
def load_model(resume_path):
    n_fft = 1000
    hop_length = 400

    # Checkpoint: https://huggingface.co/M4869/InvertibleWM/blob/main/step59000_snr39.99_pesq4.35_BERP_none0.30_mean1.81_std1.81.pkl
    # Read the Hub token from Streamlit secrets instead of hard-coding it in the source.
    api_key = st.secrets["api_key"]
    model_ckpt_path = hf_hub_download(repo_id="M4869/InvertibleWM",
                                      filename="step59000_snr39.99_pesq4.35_BERP_none0.30_mean1.81_std1.81.pkl",
                                      token=api_key)
    # The downloaded checkpoint overrides whatever path was passed in.
    resume_path = model_ckpt_path

    model = my_model_v7_recover.Model(16000, 32, n_fft, hop_length,
                                      use_recover_layer=False, num_layers=8).to(device)
    checkpoint = torch.load(resume_path, map_location=torch.device('cpu'))
    state_dict = model_util.map_state_dict(checkpoint['model'])
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    return model
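# Model configuration used above: 16 kHz audio, 32-bit message capacity, STFT with
# n_fft=1000 and hop_length=400, 8 invertible layers, and the recover layer disabled.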
if __name__ == "__main__":
    len_start_bit = 12
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = load_model("./data/step59000_snr39.99_pesq4.35_BERP_none0.30_mean1.81_std1.81.pkl")
    main()
# decode_watermark("/Users/my/Downloads/7a95b353a46893903e9f946c24170b210ce14e8c52c63bb2ab3d144e.wav")
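# To run the app locally (assuming Streamlit and the other dependencies are installed):
#   streamlit run app.py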