Create app.py
app.py
ADDED
@@ -0,0 +1,162 @@
import EasyDel
import jax.lax
from EasyDel import JAXServer, get_mesh
from fjutils import get_float_dtype_by_name
from EasyDel.transform import llama_from_pretrained
from transformers import AutoTokenizer
import gradio as gr
from fjutils.tracker import initialise_tracking, get_mem
import argparse
from fjutils import make_shard_and_gather_fns, match_partition_rules
import threading
import typing
import IPython
import logging
import jax.numpy as jnp
import time

logging.basicConfig(
    level=logging.INFO
)

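# NOTE: several imports above (argparse, threading, typing, IPython, time, jnp, get_mesh,
# and the fjutils sharding/tracker helpers) as well as the `instruct` template below are
# defined but never used in this file.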
instruct = 'Context:\n{context}\nQuestion:\nYes or No question, can you answer to ' \
           '""{question}?"" only and only by using provided context?'


DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant. Always answer " \
                        "as helpfully as possible, while being safe. Your answers should not" \
                        " include any harmful, unethical, racist, sexist, toxic, dangerous, or " \
                        "illegal content. Please ensure that your responses are socially unbiased " \
                        "and positive in nature.\nIf a question does not make any sense, or is not " \
                        "factually coherent, explain why instead of answering something not correct. If " \
                        "you don't know the answer to a question, please don't share false information."

def get_prompt_llama2_format(message: str, chat_history,
                             system_prompt: str) -> str:
    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    do_strip = False
    for user_input, response in chat_history:
        user_input = user_input.strip() if do_strip else user_input
        do_strip = True
        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
    message = message.strip() if do_strip else message
    texts.append(f'{message} [/INST]')
    return ''.join(texts)

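# Illustrative sketch (kept as a comment so nothing runs at import time): with an empty
# chat history the formatter above produces the standard Llama-2 chat layout, e.g.
#
#     get_prompt_llama2_format('Is the sky blue?', [], DEFAULT_SYSTEM_PROMPT)
#     # -> '<s>[INST] <<SYS>>\n<system prompt>\n<</SYS>>\n\nIs the sky blue? [/INST]'
#
# Earlier turns are appended as '{user} [/INST] {response} </s><s>[INST] ' segments.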
class InTimeDataFinderJaxServerLlama2Type(JAXServer):
    def __init__(self, config=None):
        super().__init__(config=config)

    @classmethod
    def load_from_torch(cls, repo_id, config=None):
        with jax.default_device(jax.devices('cpu')[0]):
            param, config_model = llama_from_pretrained(
                repo_id
            )
        tokenizer = AutoTokenizer.from_pretrained(repo_id)
        model = EasyDel.FlaxLlamaForCausalLM(
            config=config_model,
            dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
            param_dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
            precision=jax.lax.Precision('fastest'),
            _do_init=False
        )
        return cls.load_from_params(
            config_model=config_model,
            model=model,
            config=config,
            params=param,
            tokenizer=tokenizer,
            add_param_field=True,
            do_memory_log=False
        )

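    # load_from_torch converts the PyTorch checkpoint to Flax parameters on CPU via
    # llama_from_pretrained, builds an uninitialised FlaxLlamaForCausalLM, and defers to
    # the inherited load_from_params to finish setting up the server.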
    @classmethod
    def load_from_jax(cls, repo_id, checkpoint_path, config_repo=None, config=None):
        from huggingface_hub import hf_hub_download
        path = hf_hub_download(repo_id, checkpoint_path)
        tokenizer = AutoTokenizer.from_pretrained(repo_id)
        config_model = EasyDel.LlamaConfig.from_pretrained(config_repo or repo_id)
        model = EasyDel.FlaxLlamaForCausalLM(
            config=config_model,
            dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
            param_dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
            precision=jax.lax.Precision('fastest'),
            _do_init=False
        )
        return cls.load(
            path=path,
            config_model=config_model,
            model=model,
            config=config,
            tokenizer=tokenizer,
            add_param_field=True,
            do_memory_log=False
        )

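    # load_from_jax is the alternative entry point: it downloads an already-converted
    # EasyDel/Flax checkpoint file from the Hub with hf_hub_download and restores it via
    # the inherited `load`. It is not used by the __main__ block below.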
    def process_gradio_chat(self, prompt, history, max_new_tokens, greedy, pbar=gr.Progress()):
        string = get_prompt_llama2_format(
            message=prompt,
            chat_history=history,
            system_prompt=DEFAULT_SYSTEM_PROMPT
        )
        if not self.config.stream_tokens_for_gradio:
            response, _ = self.process(
                string=string,
                greedy=greedy,
                max_new_tokens=max_new_tokens,
            )
            history.append([prompt, response])
            # This method is a generator, so the final state must be yielded as well;
            # a plain return value from a generator never reaches Gradio.
            yield '', history
        else:
            history.append([prompt, ''])
            for response, _ in self.process(
                    string=string,
                    greedy=greedy,
                    max_new_tokens=max_new_tokens,
                    stream=True
            ):
                history[-1][-1] = response
                yield '', history
        return '', history

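    # Like process_gradio_chat above, process_gradio_instruct below streams partial
    # responses when stream_tokens_for_gradio is enabled and yields a single final
    # response otherwise.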
    def process_gradio_instruct(self, prompt, system, max_new_tokens, greedy, pbar=gr.Progress()):
        # NOTE: the `system` argument is currently ignored; DEFAULT_SYSTEM_PROMPT is always used.
        string = get_prompt_llama2_format(system_prompt=DEFAULT_SYSTEM_PROMPT, message=prompt, chat_history=[])
        if not self.config.stream_tokens_for_gradio:
            response, _ = self.process(
                string=string,
                greedy=greedy,
                max_new_tokens=max_new_tokens,
            )
            # Yield once so the non-streaming result reaches Gradio (generator return values are discarded).
            yield '', response
        else:
            response = ''
            for response, _ in self.process(
                    string=string,
                    greedy=greedy,
                    max_new_tokens=max_new_tokens,
                    stream=True
            ):
                yield '', response
        return '', response

if __name__ == "__main__":

    configs = {
        "repo_id": "meta-llama/Llama-2-7b-chat-hf",
        "max_length": 4096,
        "max_new_tokens": 4096,
        "max_stream_tokens": 64,
        "dtype": 'fp16',
        "use_prefix_tokenizer": True
    }
    for key, value in configs.items():
        print('\033[1;36m{:<30}\033[1;0m : {:>30}'.format(key.replace('_', ' '), f"{value}"))

    server = InTimeDataFinderJaxServerLlama2Type.load_from_torch(
        repo_id=configs['repo_id'],
        config=configs
    )
    server.gradio_app_chat.launch(share=False)
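    # NOTE: meta-llama/Llama-2-7b-chat-hf is a gated repository, so loading it typically
    # requires a Hugging Face access token that has been granted access to the model
    # (e.g. configured as a secret on the Space). If the app needs to bind to a specific
    # interface or port, Gradio's launch() also accepts server_name and server_port.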