Tawkat committed on
Commit 86f1147 · verified · 1 Parent(s): 3befda6

Update app.py

Files changed (1):
  1. app.py +74 -103
app.py CHANGED
@@ -1,119 +1,90 @@
- import os
- from threading import Thread
- from typing import Iterator
-
  import gradio as gr
- import spaces
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

- MAX_MAX_NEW_TOKENS = 2048
- DEFAULT_MAX_NEW_TOKENS = 1024
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

- #if torch.cuda.is_available():
- model_id = "openai-community/gpt2" #"mistralai/Mistral-7B-Instruct-v0.1"
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, trust_remote_code=True, device_map="auto")
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- #tokenizer.use_default_system_prompt = False

- #@spaces.GPU
  def generate(
-     message: str,
-     chat_history: list[tuple[str, str]],
-     system_prompt: str,
-     max_new_tokens: int = 1024,
-     temperature: float = 0.6,
-     top_p: float = 0.9,
-     top_k: int = 50,
- ) -> Iterator[str]:
-     '''conversation = []
-     if system_prompt:
-         conversation.append({"role": "system", "content": system_prompt})
-     for user, assistant in chat_history:
-         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-     conversation.append({"role": "user", "content": message})'''
-     prompt = "<s>"
-     for user_prompt, bot_response in chat_history:
-         prompt += f"[INST] {user_prompt} [/INST]"
-         prompt += f" {bot_response}</s> "
-     prompt += f"[INST] {message} [/INST]"

-     input_ids = tokenizer(prompt, return_tensors="pt")['input_ids']
-     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-     input_ids = input_ids.to(model.device)
-
-     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
      generate_kwargs = dict(
-         {"input_ids": input_ids},
-         streamer=streamer,
          max_new_tokens=max_new_tokens,
-         do_sample=True,
          top_p=top_p,
-         top_k=top_k,
-         temperature=temperature,
-         num_beams=1,
      )
-     t = Thread(target=model.generate, kwargs=generate_kwargs)
-     t.start()

-     outputs = []
-     for text in streamer:
-         outputs.append(text)
-         yield "".join(outputs)

- chat_interface = gr.ChatInterface(
-     fn=generate,
-     additional_inputs=[
-         gr.Textbox(label="System prompt", lines=6),
-         gr.Slider(
-             label="Max new tokens",
-             minimum=1,
-             maximum=MAX_MAX_NEW_TOKENS,
-             step=1,
-             value=DEFAULT_MAX_NEW_TOKENS,
-         ),
-         gr.Slider(
-             label="Temperature",
-             minimum=0.1,
-             maximum=4.0,
-             step=0.1,
-             value=0.6,
-         ),
-         gr.Slider(
-             label="Top-p (nucleus sampling)",
-             minimum=0.05,
-             maximum=1.0,
-             step=0.05,
-             value=0.9,
-         ),
-         gr.Slider(
-             label="Top-k",
-             minimum=1,
-             maximum=1000,
-             step=1,
-             value=50,
-         ),
-     ],
-     stop_btn=None,
-     examples=[
-         ["Hello there! How are you doing?"],
-         ["Can you explain briefly to me what is the Python programming language?"],
-         ["Explain the plot of Cinderella in a sentence."],
-         ["How many hours does it take a man to eat a Helicopter?"],
-         ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
-     ],
- )

- with gr.Blocks(css="style.css") as demo:
-     gr.Markdown(DESCRIPTION)
-     gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
-     chat_interface.render()
-     gr.Markdown(LICENSE)

- if __name__ == "__main__":
-     demo.queue(max_size=5).launch()
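For reference, the removed implementation generated text locally with transformers: model.generate blocks until completion, so it ran on a background Thread while the Gradio callback drained a TextIteratorStreamer for incremental output. A minimal, self-contained sketch of that pattern (using the same gpt2 checkpoint referenced above, not the Space's exact code):

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "openai-community/gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello there! How are you doing?", return_tensors="pt")
# skip_prompt=True yields only newly generated tokens, not the echoed prompt.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks, so run it on a worker thread and consume the
# streamer (an iterator of decoded text chunks) on the main thread.
thread = Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=64))
thread.start()
for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()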
 
 
 
 
 
+ from huggingface_hub import InferenceClient
  import gradio as gr

+ client = InferenceClient(
+     "mistralai/Mistral-7B-Instruct-v0.1"
+ )

+ def format_prompt(message, history):
+     prompt = "<s>"
+     for user_prompt, bot_response in history:
+         prompt += f"[INST] {user_prompt} [/INST]"
+         prompt += f" {bot_response}</s> "
+     prompt += f"[INST] {message} [/INST]"
+     return prompt

  def generate(
+     prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
+ ):
+     temperature = float(temperature)
+     if temperature < 1e-2:
+         temperature = 1e-2
+     top_p = float(top_p)

      generate_kwargs = dict(
+         temperature=temperature,
          max_new_tokens=max_new_tokens,
          top_p=top_p,
+         repetition_penalty=repetition_penalty,
+         do_sample=True,
+         seed=42,
      )

+     formatted_prompt = format_prompt(prompt, history)

+     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+     output = ""
+
+     for response in stream:
+         output += response.token.text
+         yield output
+     return output

+ additional_inputs = [
+     gr.Slider(
+         label="Temperature",
+         value=0.9,
+         minimum=0.0,
+         maximum=1.0,
+         step=0.05,
+         interactive=True,
+         info="Higher values produce more diverse outputs",
+     ),
+     gr.Slider(
+         label="Max new tokens",
+         value=256,
+         minimum=0,
+         maximum=1048,
+         step=64,
+         interactive=True,
+         info="The maximum number of new tokens",
+     ),
+     gr.Slider(
+         label="Top-p (nucleus sampling)",
+         value=0.90,
+         minimum=0.0,
+         maximum=1.0,
+         step=0.05,
+         interactive=True,
+         info="Higher values sample more low-probability tokens",
+     ),
+     gr.Slider(
+         label="Repetition penalty",
+         value=1.2,
+         minimum=1.0,
+         maximum=2.0,
+         step=0.05,
+         interactive=True,
+         info="Penalize repeated tokens",
+     ),
+ ]

+ gr.ChatInterface(
+     fn=generate,
+     chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
+     additional_inputs=additional_inputs,
+     title="Mistral 7B",
+ ).launch(show_api=False)
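The replacement offloads generation to the hosted Inference API via huggingface_hub. format_prompt assembles Mistral's instruction format: each past (user, assistant) turn becomes "[INST] user [/INST] assistant</s>", and the new message is appended as a trailing [INST] block. A small sketch of the resulting string, plus a hedged standalone streaming call (endpoint availability and the token-level response shape are as used in the code above):

from huggingface_hub import InferenceClient

def format_prompt(message, history):
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt

history = [("Name a fast language.", "C is fast.")]
print(format_prompt("What about Rust?", history))
# <s>[INST] Name a fast language. [/INST] C is fast.</s> [INST] What about Rust? [/INST]

client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")
# stream=True with details=True yields token-level responses;
# .token.text carries each decoded piece, as consumed in generate() above.
for response in client.text_generation(
    format_prompt("Hello!", []),
    max_new_tokens=64, stream=True, details=True, return_full_text=False,
):
    print(response.token.text, end="", flush=True)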