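"""Gradio chat demo for microsoft/Orca-2-7b.

Prompts are rendered in the ChatML-style format Orca-2 expects and responses are
streamed back token by token using greedy decoding."""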
import os
from threading import Thread
from typing import Iterator

import gradio as gr
import torch
import transformers
from transformers import TextIteratorStreamer

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

model_id = "microsoft/Orca-2-7b"
# device_map="auto" lets accelerate place the weights on the available GPU(s)/CPU.
model = transformers.AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

# use_fast=False selects the slow (SentencePiece) tokenizer.
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, use_fast=False)

system_message = "You are Orca, an AI language model created by Microsoft. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior."
user_message = "How can you determine if a restaurant is popular among locals or mainly attracts tourists, and why might this information be useful?"

DESCRIPTION = """
# Orca-2 7B

This Space demonstrates the [Orca-2-7B](https://huggingface.co/microsoft/Orca-2-7B) model by Microsoft, a Llama 2 derivative with 7B parameters fine-tuned for single-turn instructions. This Space runs on Inference Endpoints using the text-generation-inference library. If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://ui.endpoints.huggingface.co/).

The system message is set to the cautious system message:
You are Orca, an AI language model created by Microsoft. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.
Feel free to modify it in the additional inputs section. The demo uses greedy decoding.

For more details about the Orca family of models, take a look [at our blog post](https://msft.it/6042iGtzK).

Note: Orca 2 is licensed under the [Microsoft Research License](LICENSE). Llama 2 is licensed under the [LLAMA 2 Community License](https://ai.meta.com/llama/license/).
"""


def to_prompt(conversations):
    # Render the conversation in the ChatML-style format Orca-2 expects:
    # <|im_start|>{role}\n{content}<|im_end|> for system/user turns, with the EOS
    # token closing assistant turns and an open assistant turn appended at the end.
    text = ""
    for message in conversations:
        if message["role"] != "assistant":
            text += f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n"
        else:
            text += f"<|im_start|>{message['role']}\n{message['content']}{tokenizer.eos_token}\n"
    prompt = text + "<|im_start|>assistant\n"
    inputs = tokenizer(prompt, return_tensors="pt").input_ids
    return inputs
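
# For illustration, a single-turn conversation renders to a prompt of this shape
# (the bracketed parts stand in for the actual message contents):
#   <|im_start|>system
#   {system message}<|im_end|>
#   <|im_start|>user
#   {user message}<|im_end|>
#   <|im_start|>assistant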


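# generate() is the ChatInterface callback: it receives the new user message, the chat
# history as (user, assistant) pairs, and the additional inputs, and yields the growing
# response string so Gradio can display it as it streams.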
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
) -> Iterator[str]:
    # Rebuild the full conversation: system prompt, past turns, then the new message.
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt.strip()})
    else:
        conversation.append({"role": "system", "content": ""})
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    input_ids = to_prompt(conversation)
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens so the prompt fits the input budget.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # Generate in a background thread and stream partial text through the streamer.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=False,  # greedy decoding; temperature and top_p are left unused
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


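# gr.ChatInterface forwards the values of additional_inputs (the system prompt textbox
# and the max-new-tokens slider) as extra positional arguments to generate();
# temperature and top_p keep their defaults since no controls are exposed for them.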
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6, value=system_message),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
    ],
    stop_btn=None,
    examples=[
        ["How can you determine if a restaurant is popular among locals or mainly attracts tourists, and why might this information be useful?"],
        ["The eighth-grade class held a bake-off. Kelsie made two times more cookies than Josh. Josh made one-fourth the number of cookies that Suzanne made. If Suzanne made 36 cookies, how many did Kelsie make?"],
        ["Read the following web search snippets carefully and then answer the question below:\nWashington state remains near the top of the list for the most expensive average. According to the AAA, the current average price for a gallon of gas in Washington state is $5.01.\nToday's average price of gas in the U.S. is $3.82 per gallon, unchanged from yesterday, down $0.01 from last week and down $0.02 from last month.\n\nAnswer the following question:\n\nHow does the gas price in Washington compare to the national average, and what is the exact difference?"],
        ["The ages of New Haven's residents are 25.4% under the age of 18, 16.4% from 18 to 24, 31.2% from 25 to 44, 16.7% from 45 to 64, and 10.2% who were 65 years of age or older. The median age is 29 years, which is significantly lower than the national average. There are 91.8 males per 100 females. For every 100 females age 18 and over, there are 87.6 males.\n\nWhich gender group is larger: females or males?"],
    ],
)

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()

if __name__ == "__main__":
    demo.queue(max_size=20).launch()