Antoine Chaffin committed on
Commit
6758170
1 Parent(s): 87801f9

Test debug using mistral

Browse files
Files changed (1) hide show
  1. app.py +7 -13
app.py CHANGED
@@ -13,7 +13,7 @@ hf_token = os.getenv('HF_TOKEN')
13
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
14
 
15
  parser = argparse.ArgumentParser(description='Generative Text Watermarking demo')
16
- parser.add_argument('--model', '-m', type=str, default="meta-llama/Llama-2-7b-chat-hf", help='Language model')
17
  parser.add_argument('--key', '-k', type=int, default=42,
18
  help='The seed of the pseudo random number generator')
19
 
@@ -41,16 +41,9 @@ def embed(user, max_length, window_size, method, prompt):
41
  uid = USERS.index(user)
42
  watermarker = Watermarker(tokenizer=tokenizer, model=model, window_size=window_size, payload_bits=PAYLOAD_BITS)
43
  prompt = get_prompt(prompt)
44
- print("prompt:", prompt)
45
  watermarked_texts = watermarker.embed(key=args.key, messages=[ uid ],
46
  max_length=max_length+LEN_DEFAULT_PROMPT, method=method, prompt=prompt)
47
- print("===")
48
- print(watermarked_text)
49
- print("===")
50
- print(watermarked_texts[0].split("[/INST]")[1][0])
51
- print("===")
52
- print(watermarked_texts[0].split("[/INST]"))
53
- return watermarked_texts[0].split("[/INST]")[1][0]
54
 
55
  def detect(attacked_text, window_size, method, prompt):
56
  watermarker = Watermarker(tokenizer=tokenizer, model=model, window_size=window_size, payload_bits=PAYLOAD_BITS)
@@ -67,10 +60,11 @@ def detect(attacked_text, window_size, method, prompt):
67
  return label
68
 
69
  def get_prompt(message: str) -> str:
70
- texts = [f'<s>[INST] <<SYS>>\n{DEFAULT_SYSTEM_PROMPT}\n<</SYS>>\n\n']
71
- # The first user input is _not_ stripped
72
- texts.append(f'{message} [/INST]')
73
- return ''.join(texts)
 
74
 
75
  with gr.Blocks() as demo:
76
  gr.Markdown("""# LLM generation watermarking
 
13
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
14
 
15
  parser = argparse.ArgumentParser(description='Generative Text Watermarking demo')
16
+ parser.add_argument('--model', '-m', type=str, default="mistralai/Mistral-7B-Instruct-v0.1", help='Language model')
17
  parser.add_argument('--key', '-k', type=int, default=42,
18
  help='The seed of the pseudo random number generator')
19
 
 
41
  uid = USERS.index(user)
42
  watermarker = Watermarker(tokenizer=tokenizer, model=model, window_size=window_size, payload_bits=PAYLOAD_BITS)
43
  prompt = get_prompt(prompt)
 
44
  watermarked_texts = watermarker.embed(key=args.key, messages=[ uid ],
45
  max_length=max_length+LEN_DEFAULT_PROMPT, method=method, prompt=prompt)
46
+ return watermarked_texts[0].split("[/INST]")[1]
 
 
 
 
 
 
47
 
48
  def detect(attacked_text, window_size, method, prompt):
49
  watermarker = Watermarker(tokenizer=tokenizer, model=model, window_size=window_size, payload_bits=PAYLOAD_BITS)
 
60
  return label
61
 
62
  def get_prompt(message: str) -> str:
63
+ # texts = [f'<s>[INST] <<SYS>>\n{DEFAULT_SYSTEM_PROMPT}\n<</SYS>>\n\n']
64
+ # # The first user input is _not_ stripped
65
+ # texts.append(f'{message} [/INST]')
66
+ # return ''.join(texts)
67
+ return f"[INST] "+message+ " [/INST]"
68
 
69
  with gr.Blocks() as demo:
70
  gr.Markdown("""# LLM generation watermarking