Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -140,14 +140,14 @@ batch_size = 2
|
|
140 |
|
141 |
training_args = TrainingArguments(
|
142 |
#output_dir="alexkueck/test-tis-1",
|
143 |
-
output_dir="
|
144 |
overwrite_output_dir = 'True',
|
145 |
per_device_train_batch_size=batch_size, #batch_size = 2 for full training
|
146 |
per_device_eval_batch_size=batch_size,
|
147 |
evaluation_strategy = "steps", #oder
|
148 |
logging_strategy="steps", #oder epoch
|
149 |
logging_steps=10,
|
150 |
-
logging_dir='
|
151 |
learning_rate=2e-5,
|
152 |
weight_decay=0.01,
|
153 |
save_total_limit = 2,
|
@@ -195,7 +195,7 @@ print("Done Eval")
|
|
195 |
print("Save to ???")
|
196 |
login(token=os.environ["HF_WRITE_TOKEN"])
|
197 |
#trainer.save_model("test-tis-1")
|
198 |
-
trainer.save_model("
|
199 |
print("done")
|
200 |
|
201 |
#####################################
|
@@ -251,7 +251,7 @@ print("Output:\n" )
|
|
251 |
########################################
|
252 |
#mit der predict Funktion
|
253 |
print("Predict")
|
254 |
-
antwort = predict(model_neu, tokenizer, device_neu, prompt, [[
|
255 |
temperature=0.8,
|
256 |
max_length_tokens=1024,
|
257 |
max_context_length_tokens=2048,)
|
|
|
140 |
|
141 |
training_args = TrainingArguments(
|
142 |
#output_dir="alexkueck/test-tis-1",
|
143 |
+
output_dir="model",
|
144 |
overwrite_output_dir = 'True',
|
145 |
per_device_train_batch_size=batch_size, #batch_size = 2 for full training
|
146 |
per_device_eval_batch_size=batch_size,
|
147 |
evaluation_strategy = "steps", #oder
|
148 |
logging_strategy="steps", #oder epoch
|
149 |
logging_steps=10,
|
150 |
+
logging_dir='logs',
|
151 |
learning_rate=2e-5,
|
152 |
weight_decay=0.01,
|
153 |
save_total_limit = 2,
|
|
|
195 |
print("Save to ???")
|
196 |
login(token=os.environ["HF_WRITE_TOKEN"])
|
197 |
#trainer.save_model("test-tis-1")
|
198 |
+
trainer.save_model("model")
|
199 |
print("done")
|
200 |
|
201 |
#####################################
|
|
|
251 |
########################################
|
252 |
#mit der predict Funktion
|
253 |
print("Predict")
|
254 |
+
antwort = predict(model_neu, tokenizer, device_neu, prompt, [['Tis', '']], top_p=5,
|
255 |
temperature=0.8,
|
256 |
max_length_tokens=1024,
|
257 |
max_context_length_tokens=2048,)
|