Kevin Fink
commited on
Commit
·
451a63d
1
Parent(s):
0ee2b72
dev
Browse files
app.py
CHANGED
@@ -28,8 +28,9 @@ model_save_path = '/data/lora_finetuned_model' # Specify your desired save path
|
|
28 |
model.save_pretrained(model_save_path)
|
29 |
'''
|
30 |
|
31 |
-
def fine_tune_model(
|
32 |
try:
|
|
|
33 |
torch.cuda.empty_cache()
|
34 |
torch.nn.CrossEntropyLoss()
|
35 |
#rouge_metric = evaluate.load("rouge", cache_dir='/data/cache')
|
@@ -335,7 +336,7 @@ except Exception as e:
|
|
335 |
# Create Gradio interface
|
336 |
try:
|
337 |
iface = gr.Interface(
|
338 |
-
fn=
|
339 |
inputs=[
|
340 |
gr.Textbox(label="Dataset Name (e.g., 'imdb')"),
|
341 |
gr.Textbox(label="HF hub to push to after training"),
|
|
|
28 |
model.save_pretrained(model_save_path)
|
29 |
'''
|
30 |
|
31 |
+
def fine_tune_model(dataset_name, hub_id, api_key, num_epochs, batch_size, lr, grad):
|
32 |
try:
|
33 |
+
model = AutoModelForSeq2SeqLM.from_pretrained("google/t5-efficient-tiny-nh8")
|
34 |
torch.cuda.empty_cache()
|
35 |
torch.nn.CrossEntropyLoss()
|
36 |
#rouge_metric = evaluate.load("rouge", cache_dir='/data/cache')
|
|
|
336 |
# Create Gradio interface
|
337 |
try:
|
338 |
iface = gr.Interface(
|
339 |
+
fn=fine_tune_model,
|
340 |
inputs=[
|
341 |
gr.Textbox(label="Dataset Name (e.g., 'imdb')"),
|
342 |
gr.Textbox(label="HF hub to push to after training"),
|