Kevin Fink committed on
Commit
d1da5ff
·
1 Parent(s): b994095
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -25,7 +25,7 @@ model = get_peft_model(model, lora_config)
25
  model.gradient_checkpointing_enable()
26
 
27
  @spaces.GPU(duration=120)
28
- def fine_tune_model(dataset_name, hub_id, api_key, num_epochs, batch_size, lr, grad):
29
  try:
30
  torch.cuda.empty_cache()
31
  def compute_metrics(eval_pred):
@@ -145,11 +145,15 @@ def predict(text):
145
  predictions = outputs.logits.argmax(dim=-1)
146
  return predictions.item()
147
  '''
 
 
 
 
148
  # Create Gradio interface
149
  try:
150
  model = AutoModelForSeq2SeqLM.from_pretrained('google/t5-efficient-tiny-nh8'.strip(), num_labels=2, force_download=True)
151
  iface = gr.Interface(
152
- fn=fine_tune_model,
153
  inputs=[
154
  gr.Textbox(label="Dataset Name (e.g., 'imdb')"),
155
  gr.Textbox(label="HF hub to push to after training"),
 
25
  model.gradient_checkpointing_enable()
26
 
27
  @spaces.GPU(duration=120)
28
+ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size, lr, grad):
29
  try:
30
  torch.cuda.empty_cache()
31
  def compute_metrics(eval_pred):
 
145
  predictions = outputs.logits.argmax(dim=-1)
146
  return predictions.item()
147
  '''
148
+
149
+ def run_train(dataset_name, hub_id, api_key, num_epochs, batch_size, lr, grad):
150
+ result = fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size, lr, grad)
151
+ return result
152
  # Create Gradio interface
153
  try:
154
  model = AutoModelForSeq2SeqLM.from_pretrained('google/t5-efficient-tiny-nh8'.strip(), num_labels=2, force_download=True)
155
  iface = gr.Interface(
156
+ fn=run_train,
157
  inputs=[
158
  gr.Textbox(label="Dataset Name (e.g., 'imdb')"),
159
  gr.Textbox(label="HF hub to push to after training"),