arjunanand13 committed
Commit c65ba42
1 Parent(s): 3c6573c

Update app.py

Files changed (1)
  1. app.py +154 -34
app.py CHANGED
@@ -1,48 +1,168 @@
- from huggingface_hub import InferenceClient
  import gradio as gr

- client = InferenceClient("meta-llama/Meta-Llama-3.1-8B")

- def format_prompt(message, history):
-     fixed_prompt = """ """
-     prompt = f"<s>{fixed_prompt}"
-     for user_prompt, bot_response in history:
-         prompt += f"\n User:{user_prompt}\n LLM Response:{bot_response}"
-     prompt += f"\nUser: {message}\nLLM Response:"
-     return prompt

- def generate(
-     prompt, history, temperature=0.1, max_new_tokens=2048, top_p=0.8, repetition_penalty=1.0,
  ):
-     temperature = float(temperature)
-     if temperature < 1e-2:
-         temperature = 1e-2
-     top_p = float(top_p)
      generate_kwargs = dict(
-         temperature=temperature,
-         max_new_tokens=max_new_tokens,
-         top_p=top_p,
-         repetition_penalty=repetition_penalty,
-         do_sample=True,
-         seed=42,
      )

-     formatted_prompt = format_prompt(prompt, history)
-     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-     yield stream

- demo = gr.ChatInterface(
-     fn=generate,
-     title="Mood-Based Music Recommender",
-     retry_btn=None,
-     undo_btn=None,
-     clear_btn=None,
-     description="<span style='font-size: larger; font-weight: bold;'>Hi! I'm your music buddy—tell me about your mood and the type of tunes you're in the mood for today!</span>",
- )

- demo.queue().launch()
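(The replacement below drops the hosted InferenceClient call and the hand-rolled format_prompt string: it loads Meta-Llama-3.1-8B-Instruct locally with 4-bit BitsAndBytesConfig quantization, converts the chat history to messages with tokenizer.apply_chat_template, runs model.generate in a background Thread, and streams tokens back through TextIteratorStreamer into a gr.Blocks + gr.ChatInterface UI.)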
 
 
+ import os
+ import time
+ import spaces
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
  import gradio as gr
+ from threading import Thread

+ MODEL_LIST = ["meta-llama/Meta-Llama-3.1-8B-Instruct"]
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
+ MODEL = os.environ.get("MODEL_ID")

+ TITLE = "<h1><center>Meta-Llama3.1-8B</center></h1>"

+ PLACEHOLDER = """
+ <center>
+ <p>Hi! How can I help you today?</p>
+ </center>
+ """

+ CSS = """
+ .duplicate-button {
+     margin: auto !important;
+     color: white !important;
+     background: black !important;
+     border-radius: 100vh !important;
+ }
+ h3 {
+     text-align: center;
+ }
+ """

+ device = "cuda"  # for GPU usage or "cpu" for CPU usage

+ quantization_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_compute_dtype=torch.bfloat16,
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_quant_type="nf4")

+ tokenizer = AutoTokenizer.from_pretrained(MODEL)
+ model = AutoModelForCausalLM.from_pretrained(
+     MODEL,
+     torch_dtype=torch.bfloat16,
+     device_map="auto",
+     quantization_config=quantization_config)

+ @spaces.GPU()
+ def stream_chat(
+     message: str,
+     history: list,
+     system_prompt: str,
+     temperature: float = 0.8,
+     max_new_tokens: int = 1024,
+     top_p: float = 1.0,
+     top_k: int = 20,
+     penalty: float = 1.2,
  ):
+     print(f'message: {message}')
+     print(f'history: {history}')

+     conversation = [
+         {"role": "system", "content": system_prompt}
+     ]
+     for prompt, answer in history:
+         conversation.extend([
+             {"role": "user", "content": prompt},
+             {"role": "assistant", "content": answer},
+         ])
+
+     conversation.append({"role": "user", "content": message})
+
+     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
+
+     streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+
      generate_kwargs = dict(
+         input_ids=input_ids,
+         max_new_tokens=max_new_tokens,
+         do_sample=False if temperature == 0 else True,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         repetition_penalty=penalty,
+         eos_token_id=[128001, 128008, 128009],
+         streamer=streamer,
      )

+     with torch.no_grad():
+         thread = Thread(target=model.generate, kwargs=generate_kwargs)
+         thread.start()
+
+     buffer = ""
+     for new_text in streamer:
+         buffer += new_text
+         yield buffer

+ chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)

+ with gr.Blocks(css=CSS, theme="soft") as demo:
+     gr.HTML(TITLE)
+     gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
+     gr.ChatInterface(
+         fn=stream_chat,
+         chatbot=chatbot,
+         fill_height=True,
+         additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
+         additional_inputs=[
+             gr.Textbox(
+                 value="You are a helpful assistant",
+                 label="System Prompt",
+                 render=False,
+             ),
+             gr.Slider(
+                 minimum=0,
+                 maximum=1,
+                 step=0.1,
+                 value=0.8,
+                 label="Temperature",
+                 render=False,
+             ),
+             gr.Slider(
+                 minimum=128,
+                 maximum=8192,
+                 step=1,
+                 value=1024,
+                 label="Max new tokens",
+                 render=False,
+             ),
+             gr.Slider(
+                 minimum=0.0,
+                 maximum=1.0,
+                 step=0.1,
+                 value=1.0,
+                 label="top_p",
+                 render=False,
+             ),
+             gr.Slider(
+                 minimum=1,
+                 maximum=20,
+                 step=1,
+                 value=20,
+                 label="top_k",
+                 render=False,
+             ),
+             gr.Slider(
+                 minimum=0.0,
+                 maximum=2.0,
+                 step=0.1,
+                 value=1.2,
+                 label="Repetition penalty",
+                 render=False,
+             ),
+         ],
+         examples=[
+             ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
+             ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
+             ["Tell me a random fun fact about the Roman Empire."],
+             ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
+         ],
+         cache_examples=False,
+     )

+ if __name__ == "__main__":
+     demo.launch()
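
Below is a minimal smoke test for the new stream_chat generator outside the Gradio UI. It is a sketch under several assumptions not stated in the commit: a CUDA GPU with bitsandbytes installed, access to the gated Llama 3.1 weights via an HF_TOKEN environment variable, MODEL_ID pointing at meta-llama/Meta-Llama-3.1-8B-Instruct, and the spaces.GPU decorator acting as a pass-through when run outside a ZeroGPU Space.

import os

# Assumed value; the commit itself reads MODEL_ID from the environment without a default.
os.environ.setdefault("MODEL_ID", "meta-llama/Meta-Llama-3.1-8B-Instruct")

from app import stream_chat  # importing app.py loads the tokenizer and the 4-bit model

# Gradio's ChatInterface passes history as a list of (user, assistant) pairs; mimic that shape.
history = [("Hi!", "Hello! How can I help you today?")]

latest = ""
for partial in stream_chat(
    "Tell me a random fun fact about the Roman Empire.",
    history,
    system_prompt="You are a helpful assistant",
    max_new_tokens=256,
):
    latest = partial  # each yield is the cumulative text generated so far

print(latest)

Each yield returns the whole reply so far rather than a delta, which is what lets gr.ChatInterface simply repaint the streaming message as new tokens arrive from TextIteratorStreamer.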