marinone94
commited on
Commit
·
1ced76b
1
Parent(s):
cc26ffa
add git init manual command in script
Browse files- run_speech_recognition_seq2seq_streaming.py +5 -0
- sm.py +3 -3
run_speech_recognition_seq2seq_streaming.py
CHANGED
@@ -23,6 +23,7 @@ with 🤗 Datasets' streaming mode.
|
|
23 |
import json
|
24 |
import logging
|
25 |
import os
|
|
|
26 |
import sys
|
27 |
from dataclasses import dataclass, field
|
28 |
from typing import Any, Dict, List, Optional, Union
|
@@ -74,6 +75,10 @@ else:
|
|
74 |
wandb.login(key=wandb_token, relogin=True, timeout=5)
|
75 |
wandb.init(project="whisper", entity="pn-aa")
|
76 |
|
|
|
|
|
|
|
|
|
77 |
logger.info("Wandb API key set, logging to wandb")
|
78 |
|
79 |
@dataclass
|
|
|
23 |
import json
|
24 |
import logging
|
25 |
import os
|
26 |
+
import subprocess
|
27 |
import sys
|
28 |
from dataclasses import dataclass, field
|
29 |
from typing import Any, Dict, List, Optional, Union
|
|
|
75 |
wandb.login(key=wandb_token, relogin=True, timeout=5)
|
76 |
wandb.init(project="whisper", entity="pn-aa")
|
77 |
|
78 |
+
cmd = 'git init && git remote add origin && git pull origin main'
|
79 |
+
output = subprocess.run(cmd.split(), stdout=subprocess.PIPE)
|
80 |
+
print(output.stdout.decode())
|
81 |
+
|
82 |
logger.info("Wandb API key set, logging to wandb")
|
83 |
|
84 |
@dataclass
|
sm.py
CHANGED
@@ -32,7 +32,7 @@ sm_instances = test_sm_instances if TEST else full_sm_instances
|
|
32 |
|
33 |
ENTRY_POINT = "run_speech_recognition_seq2seq_streaming.py"
|
34 |
RUN_SCRIPT = "test_run.sh" if TEST else "run.sh"
|
35 |
-
IMAGE_URI = "116817510867.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-training:whisper-finetuning-
|
36 |
if IMAGE_URI is None:
|
37 |
raise ValueError("IMAGE_URI variable not set, please update script.")
|
38 |
|
@@ -68,7 +68,6 @@ def parse_run_script():
|
|
68 |
line = line.split("=")
|
69 |
# remove '\t--'
|
70 |
key = str(line[0])
|
71 |
-
assert 0 < len(key) < 256, f"Key {key} is not allowed, len must be between 0 and 256"
|
72 |
try:
|
73 |
value = line[1]
|
74 |
except IndexError:
|
@@ -93,7 +92,7 @@ env_vars = {
|
|
93 |
"WANDB_TOKEN": os.environ.get("WANDB_TOKEN")
|
94 |
}
|
95 |
pprint(env_vars)
|
96 |
-
|
97 |
for sm_instance_name, sm_instance_values in sm_instances.items():
|
98 |
num_instances: int = \
|
99 |
int(sm_instance_values["num_instances"])
|
@@ -110,6 +109,7 @@ for sm_instance_name, sm_instance_values in sm_instances.items():
|
|
110 |
image_uri=IMAGE_URI,
|
111 |
hyperparameters=hyperparameters,
|
112 |
environment=env_vars,
|
|
|
113 |
)
|
114 |
hf_estimator.fit()
|
115 |
break
|
|
|
32 |
|
33 |
ENTRY_POINT = "run_speech_recognition_seq2seq_streaming.py"
|
34 |
RUN_SCRIPT = "test_run.sh" if TEST else "run.sh"
|
35 |
+
IMAGE_URI = "116817510867.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-training:whisper-finetuning-0223e276db78adf4ea4dc5f874793cb2"
|
36 |
if IMAGE_URI is None:
|
37 |
raise ValueError("IMAGE_URI variable not set, please update script.")
|
38 |
|
|
|
68 |
line = line.split("=")
|
69 |
# remove '\t--'
|
70 |
key = str(line[0])
|
|
|
71 |
try:
|
72 |
value = line[1]
|
73 |
except IndexError:
|
|
|
92 |
"WANDB_TOKEN": os.environ.get("WANDB_TOKEN")
|
93 |
}
|
94 |
pprint(env_vars)
|
95 |
+
repo = f"https://huggingface.co/marinone94/{os.getcwd().split('/')[-1]}"
|
96 |
for sm_instance_name, sm_instance_values in sm_instances.items():
|
97 |
num_instances: int = \
|
98 |
int(sm_instance_values["num_instances"])
|
|
|
109 |
image_uri=IMAGE_URI,
|
110 |
hyperparameters=hyperparameters,
|
111 |
environment=env_vars,
|
112 |
+
git_config={"repo": repo, "branch": "main"},
|
113 |
)
|
114 |
hf_estimator.fit()
|
115 |
break
|