import logging

import gradio as gr
import jax
import jax.lax
from transformers import AutoTokenizer

import EasyDel
from EasyDel import JAXServer
from EasyDel.transform import llama_from_pretrained
from fjutils import get_float_dtype_by_name
logging.basicConfig(
level=logging.INFO
)
instruct = 'Context:\n{context}\nQuestion:\nYes or No question, can you answer ' \
           '"{question}" using only the provided context?'
DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant. Always answer " \
"as helpfully as possible, while being safe. Your answers should not" \
" include any harmful, unethical, racist, sexist, toxic, dangerous, or " \
"illegal content. Please ensure that your responses are socially unbiased " \
"and positive in nature.\nIf a question does not make any sense, or is not " \
"factually coherent, explain why instead of answering something not correct. If " \
"you don't know the answer to a question, please don't share false information."
def get_prompt_llama2_format(message: str, chat_history,
system_prompt: str) -> str:
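    """Build a single Llama-2 chat prompt string.

    The system prompt is wrapped in <<SYS>> ... <</SYS>> tags inside the first
    [INST] block, previous (user, response) turns from ``chat_history`` are
    appended as closed [INST] ... [/INST] blocks, and the new ``message`` is
    left as the final open [INST] ... [/INST] block for the model to complete.
    """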
    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
do_strip = False
for user_input, response in chat_history:
user_input = user_input.strip() if do_strip else user_input
do_strip = True
        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
message = message.strip() if do_strip else message
texts.append(f'{message} [/INST]')
return ''.join(texts)
class InTimeDataFinderJaxServerLlama2Type(JAXServer):
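    """JAXServer subclass that serves a Llama-2 chat model and wires the
    Llama-2 prompt format into the Gradio chat and instruct endpoints."""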
def __init__(self, config=None):
super().__init__(config=config)
@classmethod
def load_from_torch(cls, repo_id, config=None):
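        """Download the PyTorch weights of ``repo_id`` from the Hugging Face Hub,
        convert them to flax parameters on CPU, build the matching
        FlaxLlamaForCausalLM and tokenizer, and return a ready server instance."""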
with jax.default_device(jax.devices('cpu')[0]):
param, config_model = llama_from_pretrained(
repo_id
)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = EasyDel.FlaxLlamaForCausalLM(
config=config_model,
dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
param_dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
precision=jax.lax.Precision('fastest'),
_do_init=False
)
return cls.load_from_params(
config_model=config_model,
model=model,
config=config,
params=param,
tokenizer=tokenizer,
add_param_field=True,
do_memory_log=False
)
@classmethod
def load_from_jax(cls, repo_id, checkpoint_path, config_repo=None, config=None):
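        """Download an already-converted flax checkpoint file from the Hub,
        build the FlaxLlamaForCausalLM and tokenizer, and return a server
        loaded from that checkpoint via ``cls.load``."""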
from huggingface_hub import hf_hub_download
path = hf_hub_download(repo_id, checkpoint_path)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
config_model = EasyDel.LlamaConfig.from_pretrained(config_repo or repo_id)
model = EasyDel.FlaxLlamaForCausalLM(
config=config_model,
dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
param_dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
precision=jax.lax.Precision('fastest'),
_do_init=False
)
return cls.load(
path=path,
config_model=config_model,
model=model,
config=config,
tokenizer=tokenizer,
add_param_field=True,
do_memory_log=False
)
def process_gradio_chat(self, prompt, history, max_new_tokens, greedy, pbar=gr.Progress()):
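        """Gradio chat callback: format the conversation with the Llama-2
        template, then either yield the full generation once or stream partial
        generations, depending on ``config.stream_tokens_for_gradio``."""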
string = get_prompt_llama2_format(
message=prompt,
chat_history=history,
system_prompt=DEFAULT_SYSTEM_PROMPT
)
if not self.config.stream_tokens_for_gradio:
response, _ = self.process(
string=string,
greedy=greedy,
max_new_tokens=max_new_tokens,
)
            history.append([prompt, response])
            # This method is a generator (the streaming branch yields), so the
            # non-streaming result must also be yielded to reach Gradio.
            yield '', history
else:
history.append([prompt, ''])
for response, _ in self.process(
string=string,
greedy=greedy,
max_new_tokens=max_new_tokens,
stream=True
):
history[-1][-1] = response
yield '', history
return '', history
def process_gradio_instruct(self, prompt, system, max_new_tokens, greedy, pbar=gr.Progress()):
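        """Gradio instruct callback: single-turn variant of the chat handler,
        using the provided system prompt (or the default one) and no history."""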
        string = get_prompt_llama2_format(system_prompt=system or DEFAULT_SYSTEM_PROMPT, message=prompt, chat_history=[])
if not self.config.stream_tokens_for_gradio:
response, _ = self.process(
string=string,
greedy=greedy,
max_new_tokens=max_new_tokens,
            )
            # As in process_gradio_chat, yield the non-streaming result so the
            # generator actually delivers it to Gradio.
            yield '', response
else:
response = ''
for response, _ in self.process(
string=string,
greedy=greedy,
max_new_tokens=max_new_tokens,
stream=True
):
yield '', response
return '', response
if __name__ == "__main__":
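    # Serving configuration passed to JAXServer; 'dtype' is also read when the
    # flax model is instantiated in load_from_torch.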
configs = {
"repo_id": "meta-llama/Llama-2-7b-chat-hf",
"max_length": 4096,
"max_new_tokens": 4096,
"max_stream_tokens": 64,
"dtype": 'fp16',
"use_prefix_tokenizer": True
}
for key, value in configs.items():
print('\033[1;36m{:<30}\033[1;0m : {:>30}'.format(key.replace('_', ' '), f"{value}"))
server = InTimeDataFinderJaxServerLlama2Type.load_from_torch(
repo_id=configs['repo_id'],
config=configs
)
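    # Alternatively, a pre-converted flax/EasyDel checkpoint hosted on the Hub can
    # be served without the torch -> flax conversion step (the repo id and file
    # name below are hypothetical placeholders):
    # server = InTimeDataFinderJaxServerLlama2Type.load_from_jax(
    #     repo_id='<hub-repo-with-flax-checkpoint>',
    #     checkpoint_path='<checkpoint-file-name>',
    #     config_repo=configs['repo_id'],
    #     config=configs
    # )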
server.gradio_app_chat.launch(share=False)