Spaces:
Build error
Build error
PeteBleackley
committed on
Commit
·
679a7b2
1
Parent(s):
1aab673
Download training data from S3
Browse files- app.py +7 -3
- scripts.py +11 -0
app.py
CHANGED
@@ -10,8 +10,10 @@ import gradio as gr
|
|
10 |
import scripts
|
11 |
import pandas
|
12 |
|
13 |
-
def
|
14 |
-
|
|
|
|
|
15 |
|
16 |
def train():
|
17 |
history = scripts.train_models('PlayfulTechnology')
|
@@ -19,8 +21,10 @@ def train():
|
|
19 |
|
20 |
|
21 |
with gr.Blocks() as trainer:
|
22 |
-
|
|
|
23 |
loss_plot = gr.Plot()
|
|
|
24 |
training_button.click(train,inputs=[],outputs=[loss_plot])
|
25 |
|
26 |
trainer.launch()
|
|
|
10 |
import scripts
|
11 |
import pandas
|
12 |
|
13 |
+
def download(button):
    """Fetch the training data, then return an update that enables a button.

    `button` is whatever Gradio supplies for the callback's input component;
    it is not used here. The returned update sets ``interactive=True`` on the
    component wired up as this callback's output.
    """
    scripts.download_training_data()
    enable_update = gr.Button.update(interactive=True)
    return enable_update
|
16 |
+
|
17 |
|
18 |
def train():
|
19 |
history = scripts.train_models('PlayfulTechnology')
|
|
|
21 |
|
22 |
|
23 |
with gr.Blocks() as trainer:
    # The training button starts disabled; the download callback's return
    # value is routed to it so it becomes clickable once the data is fetched.
    download_button = gr.Button(value='Download training data')
    training_button = gr.Button(value="Train models", interactive=False)
    loss_plot = gr.Plot()
    download_button.click(download, inputs=download_button, outputs=training_button)
    # train() takes no inputs; its result is rendered in the plot.
    training_button.click(train, inputs=[], outputs=[loss_plot])
|
29 |
|
30 |
trainer.launch()
|
scripts.py
CHANGED
@@ -19,6 +19,7 @@ import scipy.spatial
|
|
19 |
import seaborn
|
20 |
import tqdm
|
21 |
import gradio
|
|
|
22 |
|
23 |
class SequenceCrossEntropyLoss(torch.nn.Module):
|
24 |
def __init__(self):
|
@@ -55,6 +56,16 @@ def clean_question(doc):
|
|
55 |
words.append('?')
|
56 |
return ''.join(words)
|
57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
def prepare_wiki_qa(filename,outfilename):
|
59 |
data = pandas.read_csv(filename,sep='\t')
|
60 |
data['QNum']=data['QuestionID'].apply(lambda x: int(x[1:]))
|
|
|
19 |
import seaborn
|
20 |
import tqdm
|
21 |
import gradio
|
22 |
+
import boto3
|
23 |
|
24 |
class SequenceCrossEntropyLoss(torch.nn.Module):
|
25 |
def __init__(self):
|
|
|
56 |
words.append('?')
|
57 |
return ''.join(words)
|
58 |
|
59 |
+
def download_training_data():
    """Download the objects in the 'qarac' S3 bucket into ./corpora.

    Credentials are read from the AWS_KEY and AWS_SECRET environment
    variables. Each object is saved as corpora/<key>.

    Raises:
        KeyError: if AWS_KEY or AWS_SECRET is not set in the environment.
    """
    os.makedirs('corpora', exist_ok=True)
    s3 = boto3.client('s3',
                      aws_access_key_id=os.environ['AWS_KEY'],
                      aws_secret_access_key=os.environ['AWS_SECRET'])
    # 'Contents' is omitted from the response when the bucket is empty.
    # NOTE(review): a single list_objects call is capped by S3 (1000 keys);
    # switch to a paginator if the bucket can grow beyond that.
    for obj in s3.list_objects(Bucket='qarac').get('Contents', []):
        filename = obj['Key']
        if filename.endswith('/'):
            # Folder placeholder object: nothing to download.
            continue
        target = 'corpora/{}'.format(filename)
        # Keys may contain '/', so make sure the local sub-directory exists.
        os.makedirs(os.path.dirname(target), exist_ok=True)
        s3.download_file('qarac', filename, target)
|
68 |
+
|
69 |
def prepare_wiki_qa(filename,outfilename):
|
70 |
data = pandas.read_csv(filename,sep='\t')
|
71 |
data['QNum']=data['QuestionID'].apply(lambda x: int(x[1:]))
|