ZoroaStrella committed on
Commit
f3f292e
·
1 Parent(s): e970aef

Add accelerate dependencies

Browse files
Files changed (2) hide show
  1. app.py +49 -123
  2. requirements.txt +2 -1
app.py CHANGED
@@ -4,136 +4,62 @@ import torch
4
 
5
  # Configuration
6
  MODEL_NAME = "RekaAI/reka-flash-3"
7
- DEFAULT_MAX_LENGTH = 4096 # Reduced for CPU efficiency
8
  DEFAULT_TEMPERATURE = 0.7
9
-
10
- # System prompt with reasoning instructions
11
- SYSTEM_PROMPT = """You are Reka Flash-3, a helpful AI assistant created by Reka AI.
12
- When responding, think step-by-step within <thinking> tags and conclude your answer after </thinking>.
13
- For example:
14
- User: What is 2+2?
15
- Assistant: <thinking>Let me calculate that. 2 plus 2 equals 4.</thinking> The answer is 4."""
16
-
17
- # Load model and tokenizer with 4-bit quantization
18
- try:
19
- quantization_config = BitsAndBytesConfig(
20
- load_in_4bit=True,
21
- bnb_4bit_compute_dtype=torch.float16,
22
- bnb_4bit_use_double_quant=True,
23
- bnb_4bit_quant_type="nf4"
24
- )
25
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
26
- model = AutoModelForCausalLM.from_pretrained(
27
- MODEL_NAME,
28
- quantization_config=quantization_config,
29
- device_map="auto", # Maps to CPU
30
- torch_dtype=torch.float16
 
 
 
 
 
 
 
 
 
31
  )
32
- tokenizer.pad_token = tokenizer.eos_token # Ensure padding works
33
- except Exception as e:
34
- raise Exception(f"Failed to load model: {str(e)}. Ensure access to {MODEL_NAME} and sufficient CPU memory.")
35
-
36
- def generate_response(
37
- message,
38
- chat_history,
39
- system_prompt,
40
- max_length,
41
- temperature,
42
- top_p,
43
- top_k,
44
- repetition_penalty,
45
- show_reasoning
46
- ):
47
- """Generate a response from Reka Flash-3 with reasoning tags."""
48
- try:
49
- # Format chat history and prompt (multi-round conversation)
50
- history_str = ""
51
- for user_msg, assistant_msg in chat_history:
52
- history_str += f"human: {user_msg} <sep> assistant: {assistant_msg} <sep> "
53
- prompt = f"{system_prompt} <sep> human: {message} <sep> assistant: <thinking>\n"
54
-
55
- # Tokenize input
56
- inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
57
-
58
- # Generate response with budget forcing
59
- outputs = model.generate(
60
- **inputs,
61
- max_new_tokens=max_length,
62
- temperature=temperature,
63
- top_p=top_p,
64
- top_k=top_k,
65
- repetition_penalty=repetition_penalty,
66
- do_sample=True,
67
- eos_token_id=tokenizer.convert_tokens_to_ids("<sep>"), # Stop at <sep>
68
- pad_token_id=tokenizer.eos_token_id
69
- )
70
-
71
- # Decode and clean response
72
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
73
- response = response[len(prompt):].split("<sep>")[0].strip() # Extract assistant response
74
-
75
- # Parse reasoning and final answer
76
- if "</thinking>" in response:
77
- reasoning, final_answer = response.split("</thinking>", 1)
78
- reasoning = reasoning.replace("<thinking>", "").strip()
79
- final_answer = final_answer.strip()
80
- else:
81
- reasoning = ""
82
- final_answer = response
83
-
84
- # Update chat history (drop reasoning to save tokens)
85
- chat_history.append({"role": "user", "content": message})
86
- chat_history.append({"role": "assistant", "content": final_answer})
87
-
88
- # Display reasoning if requested
89
- reasoning_display = f"**Reasoning:**\n{reasoning}" if show_reasoning and reasoning else ""
90
- return "", chat_history, reasoning_display
91
-
92
- except Exception as e:
93
- error_msg = f"Error: {str(e)}"
94
- gr.Warning(error_msg)
95
- return "", chat_history, error_msg
96
 
97
  # Gradio Interface
98
- with gr.Blocks(title="Reka Flash-3 Chat", theme=gr.themes.Soft()) as demo:
99
- gr.Markdown("""
100
- # Reka Flash-3 Chat Interface
101
- *Powered by [Reka AI](https://www.reka.ai/)* - A 21B parameter reasoning model optimized for CPU.
102
- """)
103
-
104
- with gr.Accordion("Deployment Instructions", open=True):
105
- gr.Textbox(
106
- value="""To deploy on Hugging Face Spaces:
107
- 1. Request access to RekaAI/reka-flash-3 from Reka AI.
108
- 2. Use a Pro subscription with zero-GPU (CPU-only) hardware.
109
- 3. Ensure 32GB+ CPU memory for 4-bit quantization.
110
- 4. Install dependencies: gradio, transformers, torch, bitsandbytes.""",
111
- label="How to Deploy",
112
- interactive=False
113
- )
114
-
115
  with gr.Row():
116
- chatbot = gr.Chatbot(type="messages", height=400, label="Conversation")
117
- reasoning_display = gr.Textbox(label="Model Reasoning", interactive=False, lines=8)
118
-
119
- with gr.Row():
120
- message = gr.Textbox(label="Your Message", placeholder="Ask me anything...", lines=2)
121
- submit_btn = gr.Button("Send", variant="primary")
122
-
123
- with gr.Accordion("Options", open=True):
124
- max_length = gr.Slider(128, 512, value=DEFAULT_MAX_LENGTH, label="Max Length", step=64)
125
- temperature = gr.Slider(0.1, 2.0, value=DEFAULT_TEMPERATURE, label="Temperature", step=0.1)
126
- top_p = gr.Slider(0.0, 1.0, value=0.95, label="Top-p", step=0.05)
127
- top_k = gr.Slider(1, 100, value=50, label="Top-k", step=1)
128
- repetition_penalty = gr.Slider(0.1, 2.0, value=1.1, label="Repetition Penalty", step=0.1)
129
-
130
  system_prompt = gr.Textbox(label="System Prompt", value=SYSTEM_PROMPT, lines=4)
131
- show_reasoning = gr.Checkbox(label="Show Reasoning", value=True)
132
 
133
- # Event handling
134
- inputs = [message, chatbot, system_prompt, max_length, temperature, top_p, top_k, repetition_penalty, show_reasoning]
135
- outputs = [message, chatbot, reasoning_display]
136
  submit_btn.click(generate_response, inputs=inputs, outputs=outputs)
137
  message.submit(generate_response, inputs=inputs, outputs=outputs)
138
 
139
- demo.launch(debug=True)
 
4
 
5
  # Configuration
6
  MODEL_NAME = "RekaAI/reka-flash-3"
7
+ DEFAULT_MAX_LENGTH = 256
8
  DEFAULT_TEMPERATURE = 0.7
9
+ SYSTEM_PROMPT = """You are Reka Flash-3, a helpful AI assistant created by Reka AI."""
10
+
11
+ # Load model and tokenizer
12
+ quantization_config = BitsAndBytesConfig(
13
+ load_in_4bit=True,
14
+ bnb_4bit_compute_dtype=torch.float16,
15
+ bnb_4bit_use_double_quant=True,
16
+ bnb_4bit_quant_type="nf4"
17
+ )
18
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
19
+ model = AutoModelForCausalLM.from_pretrained(
20
+ MODEL_NAME,
21
+ quantization_config=quantization_config,
22
+ device_map="auto",
23
+ torch_dtype=torch.float16,
24
+ low_cpu_mem_usage=True
25
+ )
26
+ tokenizer.pad_token = tokenizer.eos_token
27
+
28
+ def generate_response(message, chat_history, system_prompt, max_length, temperature, top_p, top_k, repetition_penalty):
29
+ prompt = f"{system_prompt} <sep> human: {message} <sep> assistant: "
30
+ inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
31
+ outputs = model.generate(
32
+ **inputs,
33
+ max_new_tokens=max_length,
34
+ temperature=temperature,
35
+ top_p=top_p,
36
+ top_k=top_k,
37
+ repetition_penalty=repetition_penalty,
38
+ do_sample=True,
39
+ pad_token_id=tokenizer.eos_token_id
40
  )
41
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True).split("<sep>")[2].strip()
42
+ chat_history.append({"user": message, "assistant": response})
43
+ return "", chat_history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  # Gradio Interface
46
+ with gr.Blocks(title="Reka Flash-3 Chat") as demo:
47
+ gr.Markdown("# Reka Flash-3 Chat Interface")
48
+ chatbot = gr.Chatbot(type="messages", height=400, label="Conversation")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  with gr.Row():
50
+ message = gr.Textbox(label="Your Message", placeholder="Ask me anything...")
51
+ submit_btn = gr.Button("Send")
52
+ with gr.Accordion("Options", open=False):
53
+ max_length = gr.Slider(128, 512, value=DEFAULT_MAX_LENGTH, label="Max Length")
54
+ temperature = gr.Slider(0.1, 2.0, value=DEFAULT_TEMPERATURE, label="Temperature")
55
+ top_p = gr.Slider(0.0, 1.0, value=0.95, label="Top-p")
56
+ top_k = gr.Slider(1, 100, value=50, label="Top-k")
57
+ repetition_penalty = gr.Slider(0.1, 2.0, value=1.1, label="Repetition Penalty")
 
 
 
 
 
 
58
  system_prompt = gr.Textbox(label="System Prompt", value=SYSTEM_PROMPT, lines=4)
 
59
 
60
+ inputs = [message, chatbot, system_prompt, max_length, temperature, top_p, top_k, repetition_penalty]
61
+ outputs = [message, chatbot]
 
62
  submit_btn.click(generate_response, inputs=inputs, outputs=outputs)
63
  message.submit(generate_response, inputs=inputs, outputs=outputs)
64
 
65
+ demo.launch()
requirements.txt CHANGED
@@ -2,4 +2,5 @@ gradio>=3.50
2
  huggingface_hub==0.25.2
3
  torch
4
  transformers
5
- bitsandbytes
 
 
2
  huggingface_hub==0.25.2
3
  torch
4
  transformers
5
+ bitsandbytes
6
+ accelerate>=0.26.0