ZoroaStrella committed on
Commit
31a061f
·
1 Parent(s): f3f292e

Update for simplicity

Browse files
Files changed (1) hide show
  1. app.py +23 -52
app.py CHANGED
@@ -1,65 +1,36 @@
1
  import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
3
- import torch
4
 
5
- # Configuration
6
- MODEL_NAME = "RekaAI/reka-flash-3"
7
- DEFAULT_MAX_LENGTH = 256
8
- DEFAULT_TEMPERATURE = 0.7
9
- SYSTEM_PROMPT = """You are Reka Flash-3, a helpful AI assistant created by Reka AI."""
10
 
11
- # Load model and tokenizer
12
- quantization_config = BitsAndBytesConfig(
13
- load_in_4bit=True,
14
- bnb_4bit_compute_dtype=torch.float16,
15
- bnb_4bit_use_double_quant=True,
16
- bnb_4bit_quant_type="nf4"
17
- )
18
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
19
- model = AutoModelForCausalLM.from_pretrained(
20
- MODEL_NAME,
21
- quantization_config=quantization_config,
22
- device_map="auto",
23
- torch_dtype=torch.float16,
24
- low_cpu_mem_usage=True
25
- )
26
- tokenizer.pad_token = tokenizer.eos_token
27
-
28
- def generate_response(message, chat_history, system_prompt, max_length, temperature, top_p, top_k, repetition_penalty):
29
- prompt = f"{system_prompt} <sep> human: {message} <sep> assistant: "
30
- inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
31
- outputs = model.generate(
32
- **inputs,
33
  max_new_tokens=max_length,
34
  temperature=temperature,
35
  top_p=top_p,
36
  top_k=top_k,
37
  repetition_penalty=repetition_penalty,
38
- do_sample=True,
39
- pad_token_id=tokenizer.eos_token_id
40
  )
41
- response = tokenizer.decode(outputs[0], skip_special_tokens=True).split("<sep>")[2].strip()
42
- chat_history.append({"user": message, "assistant": response})
 
 
43
  return "", chat_history
44
 
45
- # Gradio Interface
46
- with gr.Blocks(title="Reka Flash-3 Chat") as demo:
47
- gr.Markdown("# Reka Flash-3 Chat Interface")
48
- chatbot = gr.Chatbot(type="messages", height=400, label="Conversation")
49
- with gr.Row():
50
- message = gr.Textbox(label="Your Message", placeholder="Ask me anything...")
51
- submit_btn = gr.Button("Send")
52
- with gr.Accordion("Options", open=False):
53
- max_length = gr.Slider(128, 512, value=DEFAULT_MAX_LENGTH, label="Max Length")
54
- temperature = gr.Slider(0.1, 2.0, value=DEFAULT_TEMPERATURE, label="Temperature")
55
- top_p = gr.Slider(0.0, 1.0, value=0.95, label="Top-p")
56
- top_k = gr.Slider(1, 100, value=50, label="Top-k")
57
- repetition_penalty = gr.Slider(0.1, 2.0, value=1.1, label="Repetition Penalty")
58
- system_prompt = gr.Textbox(label="System Prompt", value=SYSTEM_PROMPT, lines=4)
59
-
60
- inputs = [message, chatbot, system_prompt, max_length, temperature, top_p, top_k, repetition_penalty]
61
- outputs = [message, chatbot]
62
- submit_btn.click(generate_response, inputs=inputs, outputs=outputs)
63
- message.submit(generate_response, inputs=inputs, outputs=outputs)
64
 
65
  demo.launch()
 
1
import gradio as gr
from huggingface_hub import InferenceClient
import os

# Serverless inference client for the hosted Reka Flash-3 model.
# Authenticates with the HF_TOKEN environment variable (a Space secret);
# no model weights are loaded locally — generation happens over the HF
# Inference API.
client = InferenceClient(model="RekaAI/reka-flash-3", token=os.getenv("HF_TOKEN"))
 
 
 
 
6
 
7
def generate_response(message, chat_history, system_prompt="You are a helpful assistant.",
                      max_length=512, temperature=0.7, top_p=0.9, top_k=50, repetition_penalty=1.0):
    """Generate one assistant reply via the HF Inference API and append the turn.

    Args:
        message: The user's new message text.
        chat_history: Conversation so far as a list of
            ``{"role": ..., "content": ...}`` dicts (Gradio "messages" format).
            May be ``None`` right after the Clear button fires.
        system_prompt: Instruction text prepended to the prompt.
        max_length: Maximum number of new tokens to generate.
        temperature: Sampling temperature forwarded to ``text_generation``.
        top_p: Nucleus-sampling cutoff forwarded to ``text_generation``.
        top_k: Top-k sampling cutoff forwarded to ``text_generation``.
        repetition_penalty: Repetition penalty forwarded to ``text_generation``.

    Returns:
        ``("", chat_history)`` — empties the textbox and shows the updated
        history in the Chatbot.
    """
    # The Clear button resets the Chatbot value to None; normalize so the
    # history iteration below is always safe.
    chat_history = chat_history or []

    # Label every turn with the same tags the stop sequences watch for.
    # The previous .capitalize() emitted "User:" for user turns, which
    # conflicted with the live turn's "Human:" tag and was never matched by
    # stop_sequences, so multi-turn prompts were inconsistently formatted.
    role_labels = {"user": "Human", "assistant": "Assistant"}
    lines = [f"{system_prompt}\n"]
    for turn in chat_history:
        label = role_labels.get(turn["role"], turn["role"].capitalize())
        lines.append(f"{label}: {turn['content']}")
    lines.append(f"Human: {message}\nAssistant:")
    full_prompt = "\n".join(lines)

    response = client.text_generation(
        full_prompt,
        max_new_tokens=max_length,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        # Stop as soon as the model starts inventing the next turn.
        stop_sequences=["\nHuman:", "\nAssistant:"],
    )

    generated_text = response.strip()
    # Record both sides of the exchange in Gradio "messages" format.
    chat_history.append({"role": "user", "content": message})
    chat_history.append({"role": "assistant", "content": generated_text})
    return "", chat_history
28
 
29
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    # Pressing Enter submits the message; the remaining generation
    # parameters fall back to generate_response's defaults.
    msg.submit(
        fn=generate_response,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
    )
    # Reset the conversation pane (returning None empties the Chatbot).
    clear.click(fn=lambda: None, inputs=None, outputs=chatbot, queue=False)

demo.launch()