Kevin Fink commited on
Commit
451a63d
·
1 Parent(s): 0ee2b72
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -28,8 +28,9 @@ model_save_path = '/data/lora_finetuned_model' # Specify your desired save path
28
  model.save_pretrained(model_save_path)
29
  '''
30
 
31
- def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size, lr, grad):
32
  try:
 
33
  torch.cuda.empty_cache()
34
  torch.nn.CrossEntropyLoss()
35
  #rouge_metric = evaluate.load("rouge", cache_dir='/data/cache')
@@ -335,7 +336,7 @@ except Exception as e:
335
  # Create Gradio interface
336
  try:
337
  iface = gr.Interface(
338
- fn=run_train,
339
  inputs=[
340
  gr.Textbox(label="Dataset Name (e.g., 'imdb')"),
341
  gr.Textbox(label="HF hub to push to after training"),
 
28
  model.save_pretrained(model_save_path)
29
  '''
30
 
31
+ def fine_tune_model(dataset_name, hub_id, api_key, num_epochs, batch_size, lr, grad):
32
  try:
33
+ model = AutoModelForSeq2SeqLM.from_pretrained("google/t5-efficient-tiny-nh8")
34
  torch.cuda.empty_cache()
35
  torch.nn.CrossEntropyLoss()
36
  #rouge_metric = evaluate.load("rouge", cache_dir='/data/cache')
 
336
  # Create Gradio interface
337
  try:
338
  iface = gr.Interface(
339
+ fn=fine_tune_model,
340
  inputs=[
341
  gr.Textbox(label="Dataset Name (e.g., 'imdb')"),
342
  gr.Textbox(label="HF hub to push to after training"),