Spaces:
Runtime error
Runtime error
Antoine Chaffin
committed on
Commit
•
6758170
1
Parent(s):
87801f9
Test debug using mistral
Browse files
app.py
CHANGED
@@ -13,7 +13,7 @@ hf_token = os.getenv('HF_TOKEN')
|
|
13 |
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
14 |
|
15 |
parser = argparse.ArgumentParser(description='Generative Text Watermarking demo')
|
16 |
-
parser.add_argument('--model', '-m', type=str, default="
|
17 |
parser.add_argument('--key', '-k', type=int, default=42,
|
18 |
help='The seed of the pseudo random number generator')
|
19 |
|
@@ -41,16 +41,9 @@ def embed(user, max_length, window_size, method, prompt):
|
|
41 |
uid = USERS.index(user)
|
42 |
watermarker = Watermarker(tokenizer=tokenizer, model=model, window_size=window_size, payload_bits=PAYLOAD_BITS)
|
43 |
prompt = get_prompt(prompt)
|
44 |
-
print("prompt:", prompt)
|
45 |
watermarked_texts = watermarker.embed(key=args.key, messages=[ uid ],
|
46 |
max_length=max_length+LEN_DEFAULT_PROMPT, method=method, prompt=prompt)
|
47 |
-
|
48 |
-
print(watermarked_text)
|
49 |
-
print("===")
|
50 |
-
print(watermarked_texts[0].split("[/INST]")[1][0])
|
51 |
-
print("===")
|
52 |
-
print(watermarked_texts[0].split("[/INST]"))
|
53 |
-
return watermarked_texts[0].split("[/INST]")[1][0]
|
54 |
|
55 |
def detect(attacked_text, window_size, method, prompt):
|
56 |
watermarker = Watermarker(tokenizer=tokenizer, model=model, window_size=window_size, payload_bits=PAYLOAD_BITS)
|
@@ -67,10 +60,11 @@ def detect(attacked_text, window_size, method, prompt):
|
|
67 |
return label
|
68 |
|
69 |
def get_prompt(message: str) -> str:
|
70 |
-
texts = [f'<s>[INST] <<SYS>>\n{DEFAULT_SYSTEM_PROMPT}\n<</SYS>>\n\n']
|
71 |
-
# The first user input is _not_ stripped
|
72 |
-
texts.append(f'{message} [/INST]')
|
73 |
-
return ''.join(texts)
|
|
|
74 |
|
75 |
with gr.Blocks() as demo:
|
76 |
gr.Markdown("""# LLM generation watermarking
|
|
|
13 |
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
14 |
|
15 |
parser = argparse.ArgumentParser(description='Generative Text Watermarking demo')
|
16 |
+
parser.add_argument('--model', '-m', type=str, default="mistralai/Mistral-7B-Instruct-v0.1", help='Language model')
|
17 |
parser.add_argument('--key', '-k', type=int, default=42,
|
18 |
help='The seed of the pseudo random number generator')
|
19 |
|
|
|
41 |
uid = USERS.index(user)
|
42 |
watermarker = Watermarker(tokenizer=tokenizer, model=model, window_size=window_size, payload_bits=PAYLOAD_BITS)
|
43 |
prompt = get_prompt(prompt)
|
|
|
44 |
watermarked_texts = watermarker.embed(key=args.key, messages=[ uid ],
|
45 |
max_length=max_length+LEN_DEFAULT_PROMPT, method=method, prompt=prompt)
|
46 |
+
return watermarked_texts[0].split("[/INST]")[1]
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
def detect(attacked_text, window_size, method, prompt):
|
49 |
watermarker = Watermarker(tokenizer=tokenizer, model=model, window_size=window_size, payload_bits=PAYLOAD_BITS)
|
|
|
60 |
return label
|
61 |
|
62 |
def get_prompt(message: str) -> str:
|
63 |
+
# texts = [f'<s>[INST] <<SYS>>\n{DEFAULT_SYSTEM_PROMPT}\n<</SYS>>\n\n']
|
64 |
+
# # The first user input is _not_ stripped
|
65 |
+
# texts.append(f'{message} [/INST]')
|
66 |
+
# return ''.join(texts)
|
67 |
+
return f"[INST] "+message+ " [/INST]"
|
68 |
|
69 |
with gr.Blocks() as demo:
|
70 |
gr.Markdown("""# LLM generation watermarking
|