z11h reshinthadith commited on
Commit
cd843ed
·
0 Parent(s):

Duplicate from stabilityai/stablelm-tuned-alpha-chat

Browse files

Co-authored-by: reshinth.adith <reshinthadith@users.noreply.huggingface.co>

Files changed (4) hide show
  1. .gitattributes +34 -0
  2. README.md +13 -0
  3. app.py +111 -0
  4. requirements.txt +4 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Stablelm Tuned Alpha Chat
3
+ emoji: 👀
4
+ colorFrom: purple
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 3.27.0
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: stabilityai/stablelm-tuned-alpha-chat
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, StoppingCriteria, StoppingCriteriaList
4
+ import time
5
+ import numpy as np
6
+ from torch.nn import functional as F
7
+ import os
8
+ # auth_key = os.environ["HF_ACCESS_TOKEN"]
9
+ print(f"Starting to load the model to memory")
10
+ m = AutoModelForCausalLM.from_pretrained(
11
+ "stabilityai/stablelm-tuned-alpha-7b", torch_dtype=torch.float16).cuda()
12
+ tok = AutoTokenizer.from_pretrained("stabilityai/stablelm-tuned-alpha-7b")
13
+ generator = pipeline('text-generation', model=m, tokenizer=tok, device=0)
14
+ print(f"Sucessfully loaded the model to the memory")
15
+
16
+ start_message = """<|SYSTEM|># StableAssistant
17
+ - StableAssistant is A helpful and harmless Open Source AI Language Model developed by Stability and CarperAI.
18
+ - StableAssistant is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
19
+ - StableAssistant is more than just an information source, StableAssistant is also able to write poetry, short stories, and make jokes.
20
+ - StableAssistant will refuse to participate in anything that could harm a human."""
21
+
22
+
23
+ class StopOnTokens(StoppingCriteria):
24
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
25
+ stop_ids = [50278, 50279, 50277, 1, 0]
26
+ for stop_id in stop_ids:
27
+ if input_ids[0][-1] == stop_id:
28
+ return True
29
+ return False
30
+
31
+
32
+ def contrastive_generate(text, bad_text):
33
+ with torch.no_grad():
34
+ tokens = tok(text, return_tensors="pt")[
35
+ 'input_ids'].cuda()[:, :4096-1024]
36
+ bad_tokens = tok(bad_text, return_tensors="pt")[
37
+ 'input_ids'].cuda()[:, :4096-1024]
38
+ history = None
39
+ bad_history = None
40
+ curr_output = list()
41
+ for i in range(1024):
42
+ out = m(tokens, past_key_values=history, use_cache=True)
43
+ logits = out.logits
44
+ history = out.past_key_values
45
+ bad_out = m(bad_tokens, past_key_values=bad_history,
46
+ use_cache=True)
47
+ bad_logits = bad_out.logits
48
+ bad_history = bad_out.past_key_values
49
+ probs = F.softmax(logits.float(), dim=-1)[0][-1].cpu()
50
+ bad_probs = F.softmax(bad_logits.float(), dim=-1)[0][-1].cpu()
51
+ logits = torch.log(probs)
52
+ bad_logits = torch.log(bad_probs)
53
+ logits[probs > 0.1] = logits[probs > 0.1] - bad_logits[probs > 0.1]
54
+ probs = F.softmax(logits)
55
+ out = int(torch.multinomial(probs, 1))
56
+ if out in [50278, 50279, 50277, 1, 0]:
57
+ break
58
+ else:
59
+ curr_output.append(out)
60
+ out = np.array([out])
61
+ tokens = torch.from_numpy(np.array([out])).to(
62
+ tokens.device)
63
+ bad_tokens = torch.from_numpy(np.array([out])).to(
64
+ tokens.device)
65
+ return tok.decode(curr_output)
66
+
67
+
68
+ def generate(text, bad_text=None):
69
+ stop = StopOnTokens()
70
+ result = generator(text, max_new_tokens=1024, num_return_sequences=1, num_beams=1, do_sample=True,
71
+ temperature=1.0, top_p=0.95, top_k=1000, stopping_criteria=StoppingCriteriaList([stop]))
72
+ return result[0]["generated_text"].replace(text, "")
73
+
74
+
75
+ def user(user_message, history):
76
+ history = history + [[user_message, ""]]
77
+ return "", history, history
78
+
79
+
80
+ def bot(history, curr_system_message):
81
+ messages = curr_system_message + \
82
+ "".join(["".join(["<|USER|>"+item[0], "<|ASSISTANT|>"+item[1]])
83
+ for item in history])
84
+ output = generate(messages)
85
+ history[-1][1] = output
86
+ time.sleep(1)
87
+ return history, history
88
+
89
+
90
+ with gr.Blocks() as demo:
91
+ history = gr.State([])
92
+ gr.Markdown("## StableLM-Tuned-Alpha-7b Chat")
93
+ gr.HTML('''<center><a href="https://huggingface.co/spaces/stabilityai/stablelm-tuned-alpha-chat?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space to skip the queue and run in a private space</center>''')
94
+ chatbot = gr.Chatbot().style(height=500)
95
+ with gr.Row():
96
+ with gr.Column(scale=0.70):
97
+ msg = gr.Textbox(label="", placeholder="Chat Message Box")
98
+ with gr.Column(scale=0.30, min_width=0):
99
+ with gr.Row():
100
+ submit = gr.Button("Submit")
101
+ clear = gr.Button("Clear")
102
+ system_msg = gr.Textbox(
103
+ start_message, label="System Message", interactive=False, visible=False)
104
+
105
+ msg.submit(fn=user, inputs=[msg, history], outputs=[msg, chatbot, history], queue=False).then(
106
+ fn=bot, inputs=[chatbot, system_msg], outputs=[chatbot, history], queue=True)
107
+ submit.click(fn=user, inputs=[msg, history], outputs=[msg, chatbot, history], queue=False).then(
108
+ fn=bot, inputs=[chatbot, system_msg], outputs=[chatbot, history], queue=True)
109
+ clear.click(lambda: [None, []], None, [chatbot, history], queue=False)
110
+ demo.queue(concurrency_count=5)
111
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ transformers
4
+ numpy