reshinthadith committed on
Commit
bd0305e
1 Parent(s): b058a81

Create app.py

Files changed (1)
  1. app.py +99 -0
app.py ADDED
@@ -0,0 +1,99 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, StoppingCriteria, StoppingCriteriaList
+ import time
+ import numpy as np
+ from torch.nn import functional as F
+
+
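+ # Load the fine-tuned StableLM checkpoint in fp16 on the GPU; note the
+ # hard-coded paths are local to the author's machine.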
+ m = AutoModelForCausalLM.from_pretrained("/mnt/nvme/home/dakota/ckpts/stablelm/7B-sft-combined/checkpoint-8000", torch_dtype=torch.float16).cuda()
+ tok = AutoTokenizer.from_pretrained("/mnt/nvme/home/dakota/stablelm_tokenizer")
+ generator = pipeline('text-generation', model=m, tokenizer=tok, device=0)
+
+
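+ # Default system prompt in StableLM's <|SYSTEM|>/<|USER|>/<|ASSISTANT|>
+ # chat format; it also seeds the "System Message" textbox below.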
+ start_message = """<|SYSTEM|># StableAssistant
+ - StableAssistant is a helpful and harmless open-source AI language model developed by Stability and CarperAI.
+ - StableAssistant is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
+ - StableAssistant is more than just an information source; StableAssistant is also able to write poetry, short stories, and make jokes.
+ - StableAssistant will refuse to participate in anything that could harm a human."""
+
+
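+ # Halt generation as soon as the last emitted token is one of the chat
+ # special tokens (role markers / end-of-text / padding).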
+ class StopOnTokens(StoppingCriteria):
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         stop_ids = [50278, 50279, 50277, 1, 0]
+         for stop_id in stop_ids:
+             if input_ids[0][-1] == stop_id:
+                 return True
+         return False
+
+
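+ # Contrastive generation: sample the next token from the model's distribution
+ # on `text`, penalising candidates that are also likely under `bad_text`
+ # (the penalty is only applied where p > 0.1). Defined here but not wired
+ # into the UI below.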
+ def contrastive_generate(text, bad_text):
+     with torch.no_grad():
+         tokens = tok(text, return_tensors="pt")['input_ids'].cuda()[:, :4096 - 1024]
+         bad_tokens = tok(bad_text, return_tensors="pt")['input_ids'].cuda()[:, :4096 - 1024]
+         history = None
+         bad_history = None
+         curr_output = list()
+         for i in range(1024):
+             # Advance both contexts by one step, reusing their KV caches.
+             out = m(tokens, past_key_values=history, use_cache=True)
+             logits = out.logits
+             history = out.past_key_values
+             bad_out = m(bad_tokens, past_key_values=bad_history, use_cache=True)
+             bad_logits = bad_out.logits
+             bad_history = bad_out.past_key_values
+             probs = F.softmax(logits.float(), dim=-1)[0][-1].cpu()
+             bad_probs = F.softmax(bad_logits.float(), dim=-1)[0][-1].cpu()
+             logits = torch.log(probs)
+             bad_logits = torch.log(bad_probs)
+             # Penalise likely tokens (p > 0.1) by their log-probability under the bad context.
+             logits[probs > 0.1] = logits[probs > 0.1] - bad_logits[probs > 0.1]
+             probs = F.softmax(logits, dim=-1)
+             out = int(torch.multinomial(probs, 1))
+             if out in [50278, 50279, 50277, 1, 0]:
+                 break
+             else:
+                 curr_output.append(out)
+             # Feed only the newly sampled token back in (shape: batch 1, length 1).
+             tokens = torch.from_numpy(np.array([[out]])).to(tokens.device)
+             bad_tokens = torch.from_numpy(np.array([[out]])).to(tokens.device)
+         return tok.decode(curr_output)
+
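+ # Standard sampling path used by the chat UI: sample up to 1024 new tokens,
+ # halting early on any of the stop tokens above.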
+ def generate(text, bad_text=None):
+     stop = StopOnTokens()
+     result = generator(text, max_new_tokens=1024, num_return_sequences=1, num_beams=1, do_sample=True,
+                        temperature=1.0, top_p=0.95, top_k=1000, stopping_criteria=StoppingCriteriaList([stop]))
+     # The pipeline output includes the prompt; strip it so only the new reply remains.
+     return result[0]["generated_text"].replace(text, "")
+
+
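+ # Append the submitted message to the history with an empty assistant slot,
+ # and clear the input box.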
+ def user(user_message, history):
+     return "", history + [[user_message, ""]]
+
+
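+ # Rebuild the full prompt from the system message and the chat history,
+ # generate the assistant reply, and write it into the last history slot.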
+ def bot(history, curr_system_message):
+     messages = curr_system_message + "".join(
+         ["".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]]) for item in history])
+     output = generate(messages)
+     history[-1][1] = output
+     time.sleep(1)
+     return history
+
+
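+ # Keep a module-level copy of the system message in sync with the textbox.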
+ def system_update(msg):
+     global curr_system_message
+     curr_system_message = msg
+
+
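+ # UI layout: chat window and Clear button on the left, editable system
+ # message and chat input on the right.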
+ with gr.Blocks() as demo:
+     with gr.Row():
+         with gr.Column():
+             chatbot = gr.Chatbot([])
+             clear = gr.Button("Clear")
+         with gr.Column():
+             system_msg = gr.Textbox(start_message, label="System Message", interactive=True)
+             msg = gr.Textbox(label="Chat Message")
+
+     # Submitting a message first echoes it into the chat (user), then
+     # generates the assistant reply (bot).
+     msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+         bot, [chatbot, system_msg], chatbot
+     )
+     system_msg.change(system_update, system_msg, None, queue=False)
+     clear.click(lambda: None, None, chatbot, queue=False)
+ # share=True serves the demo through a temporary public gradio.live link.
+ demo.launch(share=True)