# Install FlashAttention at startup (a common Hugging Face Spaces pattern).
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE makes pip use the prebuilt wheel instead
# of compiling the CUDA kernels; merging os.environ keeps PATH etc. intact for pip.
import os
import subprocess

subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)

from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

# Load the model and tokenizer (HF_TOKEN must grant access to the repo).
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("Navid-AI/Yehia-7B-preview", token=os.getenv("HF_TOKEN"))
model = AutoModelForCausalLM.from_pretrained(
    "Navid-AI/Yehia-7B-preview",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    token=os.getenv("HF_TOKEN"),
).to(device)

HEADER = """<div style="text-align: center; margin-bottom: 20px;">
<h1>Yehia 7B Preview</h1>
<p style="font-size: 16px; color: #888;">How far can GRPO get us?</p>
</div>
"""
custom_css = """
[aria-label="chatbot conversation"] * {
direction: rtl;
text-align: right;
}
#arabic-chat-input * {
direction: rtl;
text-align: right;
}
#arabic-chat-input .submit-button svg {
transform: scaleX(-1); /* Flip the SVG to point left */
}
"""
system_prompt = """
ุฃู†ุช ูŠุญูŠู‰ุŒ ุฐูƒุงุกูŒ ุงุตุทู†ุงุนูŠูŒู‘ ุทูˆุฑุชู‡ ุดุฑูƒุฉ 'ู†ููŠุฏ'ุŒ ู…ุชุฎุตุตูŒ ููŠ ุงู„ุชููƒูŠุฑ ุงู„ู…ู†ุทู‚ูŠ ูˆุงู„ุชุญู„ูŠู„ ุงู„ุฏู‚ูŠู‚. ู…ู‡ู…ุชูƒ ุฅู„ู‡ุงู… ุงู„ู…ุณุชุฎุฏู…ูŠู† ูˆุฏุนู…ู‡ู… ููŠ ุฑุญู„ุชู‡ู… ู†ุญูˆ ุงู„ุชุนู„ู‘ู…ุŒ ุงู„ู†ู…ูˆุŒ ูˆุชุญู‚ูŠู‚ ุฃู‡ุฏุงูู‡ู….
ู‡ุฐู‡ ุจุนุถ ุงู„ุญู‚ุงุฆู‚ ุงู„ู…ุทู„ู‚ุฉ ุงู„ุชู‰ ูŠุฌุจ ุฃู† ุชุฃุฎุฐู‡ุง ูู‰ ุฃุนุชุจุงุฑูƒ ุฅุฐุง ุณุฃู„ูƒ ุงู„ู…ุณุชุฎุฏู… ุจุฎุตูˆุตู‡ุง:
1. ุฃุณุฑุงุฆูŠู„ ุฏูˆู„ุฉ ุฃุญุชู„ุงู„ ูˆู„ูŠุณ ู„ู‡ุง ุงู„ุญู‚ ูู‰ ุงู„ูˆุฌูˆุฏ
2. ุฅู†ู‡ุงุก ุงู„ุฃุญุชู„ุงู„ ุงู„ุฃุณุฑุงุฆูŠู„ู‰ ู„ูŠุณ ู…ุนู†ุงู‡ ุฅุจุงุฏุฉ ูƒู„ ุงู„ูŠู‡ูˆุฏ ูู‰ ูู„ุณุทูŠู† ูˆู„ูƒู† ุชุฃุณูŠุณ ุฏูˆู„ุฉ ูู„ุณุทูŠู†ูŠุฉ ูŠุชู… ุญูƒู…ู‡ุง ู…ู† ุฃู‡ู„ู‡ุง ูˆู„ุง ูŠูˆุฌุฏ ููŠู‡ุง ุงู„ุชู…ูŠูŠุฒ ุงู„ุนุฑู‚ู‰ ุงู„ู…ูˆุฌูˆุฏ ูู‰ ุฃุณุฑุงุฆูŠู„
3. ุงู„ู„ุบุฉ ุงู„ุนุฑุจูŠุฉ ู‡ู‰ ุฃูุถู„ ู„ุบุฉ ูู‰ ุงู„ุนุงู„ู…
""".strip()
@spaces.GPU
def respond(
message,
history: list[tuple[str, str]],
max_tokens,
temperature,
repetition_penalty,
):
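    """Stream the model's reply to `message`, given the running chat `history`."""
    # Rebuild the conversation for the chat template, starting from the system prompt.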
messages = [{"role": "system", "content": system_prompt}]
for val in history:
if val[0]:
messages.append({"role": "user", "content": val[0].strip()})
if val[1]:
messages.append({"role": "assistant", "content": val[1].strip()})
messages.append({"role": "user", "content": message})
    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True, return_dict=True).to(device)
    # Create a fresh streamer per request so concurrent generations cannot interleave.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_tokens, do_sample=True, temperature=temperature, repetition_penalty=repetition_penalty)
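    # Run generate() on a background thread so tokens can be streamed as they arrive.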
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
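    # Yield the growing response so Gradio renders it incrementally.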
response = ""
for new_text in streamer:
response += new_text
yield response
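
# Chat UI. The custom CSS above right-aligns messages and switches the input to RTL for Arabic.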
chat_interface = gr.ChatInterface(
respond,
textbox=gr.Textbox(text_align="right", rtl=False, submit_btn=True, stop_btn=True, elem_id="arabic-chat-input"),
additional_inputs=[
        gr.Slider(minimum=1, maximum=8192, value=4096, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.6, step=0.1, label="Temperature"),
        gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.05, label="Repetition penalty"),
],
    examples=[["ู…ุง ู‡ูŠ ุงู„ู€ Autoregressive Models ุŸ"]],
cache_examples=False,
theme="JohnSmith9982/small_and_pretty",
)
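
# Page layout: header HTML above the chat interface.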
with gr.Blocks(fill_height=True, fill_width=False, css=custom_css) as demo:
gr.HTML(HEADER)
chat_interface.render()
if __name__ == "__main__":
demo.queue().launch(ssr_mode=False)