ZoroaStrella committed on
Commit
646a0c2
·
1 Parent(s): ce9b3a4

correct the model loading

Browse files
Files changed (2) hide show
  1. app.py +113 -99
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,15 +1,30 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
 
3
 
4
  # Configuration
5
  MODEL_NAME = "RekaAI/reka-flash-3"
6
  DEFAULT_MAX_LENGTH = 1024
7
  DEFAULT_TEMPERATURE = 0.7
8
 
9
- # System prompt
10
  SYSTEM_PROMPT = """You are Reka Flash-3, a helpful AI assistant created by Reka AI.
11
  Provide detailed, helpful answers while maintaining safety.
12
- Format responses clearly using markdown when appropriate."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  def generate_response(
15
  message,
@@ -24,69 +39,87 @@ def generate_response(
24
  frequency_penalty,
25
  show_reasoning
26
  ):
27
- # Format the prompt
28
- formatted_prompt = f"System: {system_prompt}\n\nUser: {message}\n\nAssistant:"
29
-
30
- # Create client
31
- client = InferenceClient()
32
-
33
- # Generate response
34
- response = client.text_generation(
35
- MODEL_NAME,
36
- prompt=formatted_prompt,
37
- max_new_tokens=max_length,
38
- temperature=temperature,
39
- top_p=top_p,
40
- top_k=top_k,
41
- repetition_penalty=repetition_penalty,
42
- presence_penalty=presence_penalty,
43
- frequency_penalty=frequency_penalty,
44
- details=show_reasoning,
45
- )
 
 
 
 
 
 
 
 
46
 
47
- # Extract reasoning and final answer if available
48
- reasoning = ""
49
- final_answer = response
50
- if show_reasoning and hasattr(response, 'details'):
51
- reasoning = response.details.get('reasoning', '')
52
- final_answer = response.generated_text
 
 
53
 
54
- # Update chat history
55
- chat_history.append((message, final_answer))
56
-
57
- # Create full history with reasoning
58
- full_history = list(chat_history)
59
- if show_reasoning and reasoning:
60
- full_history[-1] = (full_history[-1][0], f"{final_answer}\n\nREASONING:\n{reasoning}")
61
 
62
- return "", chat_history, reasoning if show_reasoning else ""
 
 
 
 
 
 
 
 
 
 
63
 
64
  # UI Components
65
  with gr.Blocks(title="Reka Flash-3 Chat Demo", theme=gr.themes.Soft()) as demo:
66
  # Header Section
67
- gr.Markdown(f"""
68
  # Reka Flash-3 Chat Interface
69
  *Powered by [Reka Core AI](https://www.reka.ai/)*
70
  """)
71
 
72
  # Deployment Notice
73
  with gr.Accordion("Important Deployment Notice", open=True):
74
- gr.Markdown(f"""
75
- **To deploy this model on Hugging Face Spaces:**
76
- 1. Request access to Reka Flash-3 from [Hugging Face Hub](https://huggingface.co/{MODEL_NAME})
77
- 2. Ensure you have Hugging Face PRO subscription
78
- 3. Add your HF token in Space settings
79
- 4. Set `GPU_SMALL` or higher in Space hardware settings
80
- """)
 
 
 
81
 
82
  # Chat Interface
83
  with gr.Row():
84
- chatbot = gr.Chatbot(height=500)
85
  reasoning_display = gr.Textbox(
86
  label="Model Reasoning",
87
  interactive=False,
88
  visible=True,
89
- lines=20,
90
  max_lines=20
91
  )
92
 
@@ -100,70 +133,51 @@ with gr.Blocks(title="Reka Flash-3 Chat Demo", theme=gr.themes.Soft()) as demo:
100
  )
101
  submit_btn = gr.Button("Send", variant="primary")
102
 
103
- # Parameters
104
- with gr.Accordion("Normal Options", open=False):
105
  with gr.Row():
106
- max_length = gr.Slider(128, 4096, value=DEFAULT_MAX_LENGTH, label="Max Length")
107
- temperature = gr.Slider(0.1, 2.0, value=DEFAULT_TEMPERATURE, label="Temperature")
108
 
 
109
  with gr.Accordion("Advanced Options", open=False):
110
  with gr.Row():
111
- top_p = gr.Slider(0.0, 1.0, value=0.95, label="Top-p")
112
- top_k = gr.Slider(1, 100, value=50, label="Top-k")
113
- repetition_penalty = gr.Slider(0.1, 2.0, value=1.1, label="Repetition Penalty")
114
  with gr.Row():
115
- presence_penalty = gr.Slider(-2.0, 2.0, value=0.0, label="Presence Penalty")
116
- frequency_penalty = gr.Slider(-2.0, 2.0, value=0.0, label="Frequency Penalty")
117
-
118
  # System Prompt
119
  system_prompt = gr.Textbox(
120
  label="System Prompt",
121
  value=SYSTEM_PROMPT,
122
- lines=3
 
123
  )
124
-
125
  # Debug Options
126
- show_reasoning = gr.Checkbox(
127
- label="Show Model Reasoning",
128
- value=True
129
- )
130
 
131
  # Event Handling
132
- submit_btn.click(
133
- generate_response,
134
- inputs=[
135
- message,
136
- chatbot,
137
- system_prompt,
138
- max_length,
139
- temperature,
140
- top_p,
141
- top_k,
142
- repetition_penalty,
143
- presence_penalty,
144
- frequency_penalty,
145
- show_reasoning
146
- ],
147
- outputs=[message, chatbot, reasoning_display]
148
- )
149
-
150
- message.submit(
151
- generate_response,
152
- inputs=[
153
- message,
154
- chatbot,
155
- system_prompt,
156
- max_length,
157
- temperature,
158
- top_p,
159
- top_k,
160
- repetition_penalty,
161
- presence_penalty,
162
- frequency_penalty,
163
- show_reasoning
164
- ],
165
- outputs=[message, chatbot, reasoning_display]
166
- )
167
-
168
- # Deployment instructions
169
  demo.launch(debug=True)
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
 
5
  # Configuration
6
  MODEL_NAME = "RekaAI/reka-flash-3"
7
  DEFAULT_MAX_LENGTH = 1024
8
  DEFAULT_TEMPERATURE = 0.7
9
 
10
+ # System prompt with instructions for reasoning
11
  SYSTEM_PROMPT = """You are Reka Flash-3, a helpful AI assistant created by Reka AI.
12
  Provide detailed, helpful answers while maintaining safety.
13
+ Format responses clearly using markdown when appropriate.
14
+ When asked a question, think step by step inside <thinking> tags, then provide your final answer after </thinking> tags. For example:
15
+
16
+ User: What is 2+2?
17
+ Assistant: <thinking>
18
+ Let me calculate that. 2 plus 2 equals 4.
19
+ </thinking>
20
+ The answer is 4."""
21
+
22
+ # Load model and tokenizer (assuming CPU-only for zero GPU)
23
+ try:
24
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
25
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="cpu", torch_dtype=torch.float32)
26
+ except Exception as e:
27
+ raise Exception(f"Failed to load model: {str(e)}. Ensure you have access to {MODEL_NAME} and sufficient CPU memory.")
28
 
29
  def generate_response(
30
  message,
 
39
  frequency_penalty,
40
  show_reasoning
41
  ):
42
+ """
43
+ Generate a response from Reka Flash-3, parsing reasoning and final answer.
44
+ """
45
+ try:
46
+ # Format the prompt with thinking tags
47
+ formatted_prompt = f"{system_prompt}\n\nUser: {message}\n\nAssistant: <thinking>\n"
48
+
49
+ # Tokenize input
50
+ inputs = tokenizer(formatted_prompt, return_tensors="pt").to("cpu")
51
+
52
+ # Generate response
53
+ outputs = model.generate(
54
+ **inputs,
55
+ max_new_tokens=max_length,
56
+ temperature=temperature,
57
+ top_p=top_p,
58
+ top_k=top_k,
59
+ repetition_penalty=repetition_penalty,
60
+ presence_penalty=presence_penalty,
61
+ frequency_penalty=frequency_penalty,
62
+ do_sample=True,
63
+ pad_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id else 0
64
+ )
65
+
66
+ # Decode the generated text
67
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
68
+ response = response[len(formatted_prompt):] # Remove the prompt from the output
69
 
70
+ # Parse reasoning and final answer
71
+ if "</thinking>" in response:
72
+ reasoning, final_answer = response.split("</thinking>", 1)
73
+ reasoning = reasoning.strip()
74
+ final_answer = final_answer.strip()
75
+ else:
76
+ reasoning = ""
77
+ final_answer = response.strip()
78
 
79
+ # Update chat history with final answer
80
+ chat_history.append((message, final_answer))
 
 
 
 
 
81
 
82
+ # Display reasoning if requested
83
+ reasoning_display = reasoning if show_reasoning and reasoning else ""
84
+ if reasoning_display:
85
+ reasoning_display = f"**Reasoning:**\n{reasoning_display}"
86
+
87
+ return "", chat_history, reasoning_display
88
+
89
+ except Exception as e:
90
+ error_msg = f"Error generating response: {str(e)}"
91
+ gr.Warning(error_msg)
92
+ return "", chat_history, error_msg
93
 
94
  # UI Components
95
  with gr.Blocks(title="Reka Flash-3 Chat Demo", theme=gr.themes.Soft()) as demo:
96
  # Header Section
97
+ gr.Markdown("""
98
  # Reka Flash-3 Chat Interface
99
  *Powered by [Reka Core AI](https://www.reka.ai/)*
100
  """)
101
 
102
  # Deployment Notice
103
  with gr.Accordion("Important Deployment Notice", open=True):
104
+ gr.Textbox(
105
+ value="""To deploy this model on Hugging Face Spaces:
106
+ 1. Request the Reka Flash-3 OSS model from Reka AI (https://www.reka.ai/).
107
+ 2. Use a Hugging Face Pro subscription for deployment.
108
+ 3. Configure your Space with zero GPU (CPU-only) hardware.
109
+ 4. Ensure sufficient CPU memory for the 3B parameter model.""",
110
+ label="Deployment Instructions",
111
+ lines=5,
112
+ interactive=False
113
+ )
114
 
115
  # Chat Interface
116
  with gr.Row():
117
+ chatbot = gr.Chatbot(height=500, label="Conversation")
118
  reasoning_display = gr.Textbox(
119
  label="Model Reasoning",
120
  interactive=False,
121
  visible=True,
122
+ lines=10,
123
  max_lines=20
124
  )
125
 
 
133
  )
134
  submit_btn = gr.Button("Send", variant="primary")
135
 
136
+ # Normal Options
137
+ with gr.Accordion("Normal Options", open=True):
138
  with gr.Row():
139
+ max_length = gr.Slider(128, 4096, value=DEFAULT_MAX_LENGTH, label="Max Length", step=128)
140
+ temperature = gr.Slider(0.1, 2.0, value=DEFAULT_TEMPERATURE, label="Temperature", step=0.1)
141
 
142
+ # Advanced Options
143
  with gr.Accordion("Advanced Options", open=False):
144
  with gr.Row():
145
+ top_p = gr.Slider(0.0, 1.0, value=0.95, label="Top-p", step=0.05)
146
+ top_k = gr.Slider(1, 100, value=50, label="Top-k", step=1)
147
+ repetition_penalty = gr.Slider(0.1, 2.0, value=1.1, label="Repetition Penalty", step=0.1)
148
  with gr.Row():
149
+ presence_penalty = gr.Slider(-2.0, 2.0, value=0.0, label="Presence Penalty", step=0.1)
150
+ frequency_penalty = gr.Slider(-2.0, 2.0, value=0.0, label="Frequency Penalty", step=0.1)
151
+
152
  # System Prompt
153
  system_prompt = gr.Textbox(
154
  label="System Prompt",
155
  value=SYSTEM_PROMPT,
156
+ lines=5,
157
+ max_lines=10
158
  )
159
+
160
  # Debug Options
161
+ show_reasoning = gr.Checkbox(label="Show Model Reasoning", value=True)
 
 
 
162
 
163
  # Event Handling
164
+ inputs = [
165
+ message,
166
+ chatbot,
167
+ system_prompt,
168
+ max_length,
169
+ temperature,
170
+ top_p,
171
+ top_k,
172
+ repetition_penalty,
173
+ presence_penalty,
174
+ frequency_penalty,
175
+ show_reasoning
176
+ ]
177
+ outputs = [message, chatbot, reasoning_display]
178
+
179
+ submit_btn.click(generate_response, inputs=inputs, outputs=outputs)
180
+ message.submit(generate_response, inputs=inputs, outputs=outputs)
181
+
182
+ # Launch the interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  demo.launch(debug=True)
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  gradio>=3.50
2
- huggingface_hub==0.25.2
 
 
1
  gradio>=3.50
2
+ huggingface_hub==0.25.2
3
+ torch