vilarin committed
Commit 51a7d9e
1 Parent(s): 1e24216

Create app.py

Files changed (1): app.py (+123, -0)
app.py ADDED
@@ -0,0 +1,123 @@
import torch
from PIL import Image
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import os
import time


HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL_ID = "CohereForAI/aya-23-35B"
MODEL_NAME = MODEL_ID.split("/")[-1]

TITLE = "<h1><center>Aya-23-35B-Chatbox</center></h1>"

DESCRIPTION = f'<h3><center>MODEL: <a href="https://hf.co/{MODEL_ID}">{MODEL_NAME}</a></center></h3>'

CSS = """
.duplicate-button {
    margin: auto !important;
    color: white !important;
    background: black !important;
    border-radius: 100vh !important;
}
"""


# QUANTIZE
QUANTIZE_4BIT = True
USE_GRAD_CHECKPOINTING = True
TRAIN_BATCH_SIZE = 2
TRAIN_MAX_SEQ_LENGTH = 512
USE_FLASH_ATTENTION = False
GRAD_ACC_STEPS = 16
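# Note: USE_GRAD_CHECKPOINTING, TRAIN_BATCH_SIZE, TRAIN_MAX_SEQ_LENGTH and
# GRAD_ACC_STEPS are fine-tuning settings that are not referenced anywhere below;
# only QUANTIZE_4BIT and USE_FLASH_ATTENTION affect this inference app.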

quantization_config = None
if QUANTIZE_4BIT:
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

attn_implementation = None
if USE_FLASH_ATTENTION:
    attn_implementation = "flash_attention_2"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
    attn_implementation=attn_implementation,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
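# Note: HF_TOKEN is read from the environment above but not used here. If the
# checkpoint requires authentication (the aya-23 weights are gated on the Hub),
# the token would normally be forwarded explicitly, e.g.
#   AutoModelForCausalLM.from_pretrained(MODEL_ID, token=HF_TOKEN, ...)
#   AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)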

@spaces.GPU
def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int):
    conversation = []
    for prompt, answer in history:
        conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])

    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)

    gen_tokens = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
    )

    # Decode only the newly generated tokens, not the prompt that was fed in.
    gen_text = tokenizer.decode(gen_tokens[0][input_ids.shape[-1]:], skip_special_tokens=True)

    return gen_text
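
# A sketch (assumption, not part of this file): stream_chat returns the full completion
# in one piece; to stream tokens to the UI incrementally, the usual pattern is
# transformers.TextIteratorStreamer plus a background thread, with stream_chat turned
# into a generator, roughly:
#
#   from threading import Thread
#   from transformers import TextIteratorStreamer
#
#   streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
#   Thread(target=model.generate, kwargs=dict(
#       input_ids=input_ids, streamer=streamer,
#       max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature,
#   )).start()
#   partial = ""
#   for new_text in streamer:
#       partial += new_text
#       yield partial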


chatbot = gr.Chatbot(height=450)

with gr.Blocks(css=CSS) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.8,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=128,
                maximum=4096,
                step=1,
                value=1024,
                label="Max new tokens",
                render=False,
            ),
        ],
        examples=[
            ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
            ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
            ["Tell me a random fun fact about the Roman Empire."],
            ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
        ],
        cache_examples=False,
    )
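    # The two sliders above are passed by gr.ChatInterface, in order, as the
    # `temperature` and `max_new_tokens` arguments of stream_chat, after the
    # message and history arguments that ChatInterface always supplies.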


if __name__ == "__main__":
    demo.launch()
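
# A quick local sanity check (hypothetical, not part of the Space UI): with the model
# loaded above, stream_chat can also be called directly, e.g.
#
#   reply = stream_chat(
#       "Tell me a random fun fact about the Roman Empire.",
#       history=[],
#       temperature=0.8,
#       max_new_tokens=256,
#   )
#   print(reply)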