Sephfox commited on
Commit
02c3e34
·
verified ·
1 Parent(s): 6a25926

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -17
app.py CHANGED
@@ -8,18 +8,24 @@ from huggingface_hub import HfApi
8
  import plotly.graph_objects as go
9
  import time
10
  from datetime import datetime
 
11
 
12
  # Cyberpunk and Loading Animation Styling
13
  def setup_cyberpunk_style():
14
  st.markdown("""
15
  <style>
16
- @import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;500;700&display=swap');
17
- @import url('https://fonts.googleapis.com/css2?family=Share+Tech+Mono&display=swap');
18
-
 
19
  .stApp {
20
  background: radial-gradient(circle, rgba(0, 0, 0, 0.95) 20%, rgba(0, 50, 80, 0.95) 90%);
21
  color: #00ff9d;
22
  font-family: 'Orbitron', sans-serif;
 
 
 
 
23
  }
24
 
25
  .main-title {
@@ -145,7 +151,9 @@ def initialize_model(model_name="gpt2"):
145
  # Load Dataset Function with Uploaded File Option
146
  def load_dataset(data_source="demo", tokenizer=None, uploaded_file=None):
147
  if data_source == "demo":
148
- data = ["Sample text data for model training. This can be replaced with actual data for better performance."]
 
 
149
  elif uploaded_file is not None:
150
  if uploaded_file.name.endswith(".txt"):
151
  data = [uploaded_file.read().decode("utf-8")]
@@ -160,7 +168,7 @@ def load_dataset(data_source="demo", tokenizer=None, uploaded_file=None):
160
  return dataset
161
 
162
  # Train Model Function with Customized Progress Bar
163
- def train_model(model, train_dataset, tokenizer, epochs=3, batch_size=4):
164
  training_args = TrainingArguments(
165
  output_dir="./results",
166
  overwrite_output_dir=True,
@@ -179,14 +187,26 @@ def train_model(model, train_dataset, tokenizer, epochs=3, batch_size=4):
179
  args=training_args,
180
  data_collator=data_collator,
181
  train_dataset=train_dataset,
 
182
  )
183
 
184
  trainer.train()
185
 
 
 
 
 
 
 
 
 
 
 
 
186
  # Main App Logic
187
  def main():
188
  setup_cyberpunk_style()
189
- st.markdown('<h1 class="main-title">Cyberpunk Neural Training Hub</h1>', unsafe_allow_html=True)
190
 
191
  # Initialize model and tokenizer
192
  model, tokenizer = initialize_model()
@@ -225,6 +245,15 @@ def main():
225
  # Load Dataset
226
  train_dataset = load_dataset(data_source, tokenizer, uploaded_file=uploaded_file)
227
 
 
 
 
 
 
 
 
 
 
228
  # Go Button to Start Training
229
  if st.button("Go"):
230
  progress_placeholder = st.empty()
@@ -233,22 +262,21 @@ def main():
233
 
234
  dashboard = TrainingDashboard()
235
 
236
- for epoch in range(training_epochs):
237
- loading_animation.markdown("""
238
- <div class="loading-animation"></div>
239
- """, unsafe_allow_html=True)
240
-
241
- train_model(model, train_dataset, tokenizer, epochs=1, batch_size=batch_size)
242
-
243
- # Update Progress Bar
244
- progress = (epoch + 1) / training_epochs * 100
245
  progress_placeholder.markdown(f"""
246
  <div class="progress-bar-container">
247
  <div class="progress-bar" style="width: {progress}%;"></div>
248
  </div>
249
  """, unsafe_allow_html=True)
250
-
251
- dashboard.update(loss=0, generation=epoch + 1, individual=batch_size)
 
 
 
 
 
 
252
 
253
  loading_animation.empty()
254
  st.success("Training Complete!")
 
8
  import plotly.graph_objects as go
9
  import time
10
  from datetime import datetime
11
+ import threading
12
 
13
  # Cyberpunk and Loading Animation Styling
14
  def setup_cyberpunk_style():
15
  st.markdown("""
16
  <style>
17
+ body, button, input, select, textarea {
18
+ font-family: 'Orbitron', sans-serif !important;
19
+ color: #00ff9d !important;
20
+ }
21
  .stApp {
22
  background: radial-gradient(circle, rgba(0, 0, 0, 0.95) 20%, rgba(0, 50, 80, 0.95) 90%);
23
  color: #00ff9d;
24
  font-family: 'Orbitron', sans-serif;
25
+ font-size: 16px;
26
+ line-height: 1.6;
27
+ padding: 20px;
28
+ box-sizing: border-box;
29
  }
30
 
31
  .main-title {
 
151
  # Load Dataset Function with Uploaded File Option
152
  def load_dataset(data_source="demo", tokenizer=None, uploaded_file=None):
153
  if data_source == "demo":
154
+ data = ["In the neon-lit streets of Neo-Tokyo, a lone hacker fights against the oppressive megacorporations.",
155
+ "The rain falls in sheets, washing away the bloodstains from the alleyways.",
156
+ "She plugs into the matrix, seeking answers to questions that have haunted her for years."]
157
  elif uploaded_file is not None:
158
  if uploaded_file.name.endswith(".txt"):
159
  data = [uploaded_file.read().decode("utf-8")]
 
168
  return dataset
169
 
170
  # Train Model Function with Customized Progress Bar
171
+ def train_model(model, train_dataset, tokenizer, epochs=3, batch_size=4, progress_callback=None):
172
  training_args = TrainingArguments(
173
  output_dir="./results",
174
  overwrite_output_dir=True,
 
187
  args=training_args,
188
  data_collator=data_collator,
189
  train_dataset=train_dataset,
190
+ callbacks=[ProgressCallback(progress_callback)]
191
  )
192
 
193
  trainer.train()
194
 
195
+ class ProgressCallback(TrainerCallback):
196
+ def __init__(self, progress_callback):
197
+ super().__init__()
198
+ self.progress_callback = progress_callback
199
+
200
+ def on_epoch_end(self, args, state, control, **kwargs):
201
+ loss = state.log_history[-1]['loss']
202
+ generation = state.global_step // args.gradient_accumulation_steps + 1
203
+ individual = args.gradient_accumulation_steps
204
+ self.progress_callback(loss, generation, individual)
205
+
206
  # Main App Logic
207
  def main():
208
  setup_cyberpunk_style()
209
+ st.markdown('<h1 class="main-title">Neural Training Hub</h1>', unsafe_allow_html=True)
210
 
211
  # Initialize model and tokenizer
212
  model, tokenizer = initialize_model()
 
245
  # Load Dataset
246
  train_dataset = load_dataset(data_source, tokenizer, uploaded_file=uploaded_file)
247
 
248
+ # Chatbot Interaction
249
+ if st.checkbox("Enable Chatbot"):
250
+ user_input = st.text_input("You:", placeholder="Type your message here...")
251
+ if user_input:
252
+ inputs = tokenizer(user_input, return_tensors="pt")
253
+ outputs = model.generate(inputs['input_ids'], max_length=100, num_return_sequences=1)
254
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
255
+ st.write("Bot:", response)
256
+
257
  # Go Button to Start Training
258
  if st.button("Go"):
259
  progress_placeholder = st.empty()
 
262
 
263
  dashboard = TrainingDashboard()
264
 
265
+ def train_progress(loss, generation, individual):
266
+ progress = (generation + 1) / dashboard.metrics['training_epochs'] * 100
 
 
 
 
 
 
 
267
  progress_placeholder.markdown(f"""
268
  <div class="progress-bar-container">
269
  <div class="progress-bar" style="width: {progress}%;"></div>
270
  </div>
271
  """, unsafe_allow_html=True)
272
+ dashboard.update(loss=loss, generation=generation, individual=individual)
273
+
274
+ thread = threading.Thread(target=train_model, args=(model, train_dataset, tokenizer, training_epochs, batch_size, train_progress))
275
+ thread.start()
276
+ loading_animation.markdown("""
277
+ <div class="loading-animation"></div>
278
+ """, unsafe_allow_html=True)
279
+ thread.join()
280
 
281
  loading_animation.empty()
282
  st.success("Training Complete!")