add gpt data generation and analysis
Changed files:
- build_openprompt.py +19 -13
- central_finetuning.py +0 -0
- corenlp_openie.py +104 -0
- generation_test.py +101 -0
- gpt2_generation.py +32 -7
- gpt_api.py +27 -0
- monitor.sh +15 -0
- sft.py +4 -6
- trible.py +56 -0

build_openprompt.py
CHANGED
@@ -3,6 +3,9 @@ import pandas as pd
 import json
 import random
 
+from torch.nn.utils.rnn import pad_sequence
+
+
 from tqdm import tqdm
 
 
@@ -12,31 +15,34 @@ samples = {
 }
 little = False
 all_loaded_sample = 400000
-
+normal = True  # read everything, not by sampling
 s_pro = all_loaded_sample / 1e+7
 # probability used when sampling rows
-with open("./data/…
+with open("./data/cleaned_oie_prompts.csv") as f:
     csv_reader = csv.DictReader(f)
     process_reader = tqdm(enumerate(csv_reader))
     for row_number, row in process_reader:
         num_samples = len(samples['x'])
         process_reader.set_description(f"got data num: {num_samples}")
-        if …
-        …
+        if not normal:
+            if random.uniform(0, 1) > s_pro:
+                continue
+            if len(samples["x"]) > all_loaded_sample:
+                break
+        else:
+            if row['prompt'] == "":
+                continue
         if little:
             if len(samples["x"]) > 100:
                 break
-        if len(samples["x"]) > all_loaded_sample:
-            break
 
         datum = row
-        prompt = datum['prompt']
-
-        if …
-        …
-        …
-        …
+        # prompt = datum['prompt']
+        prompt = ",".join(eval(datum['raw_data'])['modifiers'])
+        if not normal:
+            modifiers = eval(datum['raw_data'])['modifiers']
+            if len(modifiers) < 4:
+                continue
         label = prompt
         x = prompt
        # small text maps to large text, so x is the smaller side; x is split in a 6:3:1 ratio
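
For context, each row of cleaned_oie_prompts.csv stores raw_data as a stringified dict (csv flattens it on write), which is why the loader above has to eval it back before joining the modifiers. A minimal sketch of that round trip, with an invented row:

import csv  # the real loader iterates rows via csv.DictReader

# Invented row, in the shape cleaning_dataset() (corenlp_openie.py) writes.
row = {
    "prompt": "living room,a large grey sofa",
    "raw_data": "{'modifiers': ['bright living room', 'wooden floor', 'large windows']}",
}

# What the loader does: recover the dict, then join the modifiers.
modifiers = eval(row["raw_data"])["modifiers"]
prompt = ",".join(modifiers)
print(prompt)  # bright living room,wooden floor,large windows

ast.literal_eval would parse the same strings without executing arbitrary code, if this is ever tightened up.
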
central_finetuning.py
ADDED
File without changes

corenlp_openie.py
ADDED
@@ -0,0 +1,104 @@
+import os
+import re
+import csv
+import json
+import jsonlines
+from tqdm import tqdm
+from stanfordcorenlp import StanfordCoreNLP
+
+import concurrent.futures
+
+
+nlp = StanfordCoreNLP('./stanford-corenlp-4.5.5')
+
+SOURCE_FILE = "./data/raw_oie_source.jsonl"
+
+def oie_extract(sentence):
+    output = nlp.annotate(sentence, properties={
+        'annotators': 'tokenize, ssplit, pos, depparse, parse, openie',
+        'outputFormat': 'json'
+    })
+    try:
+        data = json.loads(output)
+        sentences_ie = [i['openie'] for i in data['sentences'] if len(i['openie']) > 0]
+        oie_result = [max([sub["object"] for sub in sen], key=len) for sen in sentences_ie]
+        central_result = [sen[0]["subject"] for sen in sentences_ie][1:]
+
+        result = central_result + oie_result
+        result = ",".join(result)
+    except Exception as e:
+        print(f"An error occurred output: {output}")
+        result = ""
+    return result
+
+def process_sentence(sentence):
+    row_data = {'raw_data': {'modifiers': sentence.split(".")}, 'prompt': ''}
+    oie_prompt = oie_extract(sentence)
+    row_data['prompt'] = oie_prompt
+    return row_data
+
+def get_sentences(path):
+    if not os.path.exists(SOURCE_FILE):
+        raise FileNotFoundError(f"{SOURCE_FILE} not found.")
+
+    with jsonlines.open(path) as reader:
+        for obj in reader:
+            yield obj['description']
+
+def main():
+    file_name = "./data/oie_prompts.csv"
+    fieldnames = ['prompt', 'raw_data']
+    csvfile = open(file_name, 'w', newline='', encoding='utf-8')
+    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+    writer.writeheader()
+
+    # for sentence in tqdm(get_sentences(SOURCE_FILE), desc="extracting oie prompts"):
+    #     row_data = {'raw_data': {'modifiers': sentence.split(".")}, "prompt": ""}
+    #     oie_prompt = oie_extract(sentence)
+    #     row_data['prompt'] = oie_prompt
+    #     writer.writerow(row_data)
+
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        results = list(tqdm(executor.map(process_sentence, get_sentences(SOURCE_FILE)),
+                            total=len(list(get_sentences(SOURCE_FILE))),
+                            desc="extracting oie prompts"))
+
+    for result in results:
+        writer.writerow(result)
+
+def remove_chinese(text):
+    pattern = re.compile(r'[\u4e00-\u9fa5]')
+    result = re.sub(pattern, '', text)
+    return result
+
+
+def remove_special_chars(text):
+    pattern = re.compile(r'[^\w\s.,]')
+    result = re.sub(pattern, '', text)
+    return result
+
+def cleaning_dataset():
+    """Clean only oie_prompts.csv; save the result in cleaned_oie_prompts.csv."""
+    file_name = "./data/cleaned_oie_prompts.csv"
+    fieldnames = ['prompt', 'raw_data']
+    csvfile = open(file_name, 'w', newline='', encoding='utf-8')
+    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+    writer.writeheader()
+    with open("./data/oie_prompts.csv") as f:
+        csv_reader = csv.DictReader(f)
+        process_reader = tqdm(enumerate(csv_reader))
+        for row_number, row in process_reader:
+            datum = row
+
+            cleaned_prompts = remove_special_chars(remove_chinese(datum['prompt']))
+            joined_modifiers = ",".join(eval(datum['raw_data'])['modifiers'])
+            cleaned_modifiers = remove_special_chars(remove_chinese(joined_modifiers))
+            row_data = {'raw_data': {'modifiers': cleaned_modifiers.split(",")}, "prompt": cleaned_prompts}
+            writer.writerow(row_data)
+
+
+if __name__ == '__main__':
+    # main()
+    cleaning_dataset()
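
oie_extract leans on the shape of CoreNLP's openie annotation: per sentence, a list of {subject, relation, object} triples. Its selection rule (the longest object from each sentence, plus the leading subject of every sentence after the first) can be sanity-checked in isolation on a hand-written annotation; the triples below are invented for illustration:

# Invented CoreNLP-style OpenIE annotations for two sentences.
sentences_ie = [
    [{"subject": "room", "relation": "has", "object": "sofa"},
     {"subject": "room", "relation": "has", "object": "a large grey sofa"}],
    [{"subject": "walls", "relation": "are", "object": "white"}],
]

# Longest object per sentence.
oie_result = [max([sub["object"] for sub in sen], key=len) for sen in sentences_ie]
# Leading subject of every sentence except the first.
central_result = [sen[0]["subject"] for sen in sentences_ie][1:]

print(",".join(central_result + oie_result))  # walls,a large grey sofa,white
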

generation_test.py
ADDED
@@ -0,0 +1,101 @@
+import os
+import spacy
+from accelerate import PartialState
+from accelerate.utils import set_seed
+
+from gpt2_generation import Translator
+from gpt2_generation import generate_prompt, MODEL_CLASSES
+
+os.environ["http_proxy"] = "http://127.0.0.1:7890"
+os.environ["https_proxy"] = "http://127.0.0.1:7890"
+
+
+path_for_model = "./output/gpt2_openprompt/checkpoint-4500"
+
+args = {
+    "model_type": "gpt2",
+    "model_name_or_path": path_for_model,
+    "length": 80,
+    "length_penalty": 1.2,
+    "stop_token": None,
+    "temperature": 1.0,
+    "repetition_penalty": 1.2,
+    "k": 3,
+    "p": 0.9,
+    "prefix": "",
+    "padding_text": "",
+    "xlm_language": "",
+    "seed": 42,
+    "use_cpu": False,
+    "num_return_sequences": 4,
+    "fp16": False,
+    "jit": False,
+}
+
+distributed_state = PartialState(cpu=args["use_cpu"])
+
+if args["seed"] is not None:
+    set_seed(args["seed"])
+
+tokenizer = None
+model = None
+zh_en_translator = None
+nlp = None
+
+def load_model_and_components():
+    global tokenizer, model, zh_en_translator, nlp
+
+    # Initialize the model and tokenizer
+    try:
+        args["model_type"] = args["model_type"].lower()
+        model_class, tokenizer_class = MODEL_CLASSES[args["model_type"]]
+    except KeyError:
+        raise KeyError("the model {} you specified is not supported. You are welcome to add it and open a PR :)")
+
+    tokenizer = tokenizer_class.from_pretrained(args["model_name_or_path"], padding_side='left')
+    tokenizer.pad_token = tokenizer.eos_token
+    tokenizer.mask_token = tokenizer.eos_token
+    model = model_class.from_pretrained(args["model_name_or_path"])
+    print("Model loaded!")
+
+    # translator
+    zh_en_translator = Translator("Helsinki-NLP/opus-mt-zh-en")
+    print("Translator loaded!")
+
+    # filter
+    nlp = spacy.load('en_core_web_sm')
+    print("Filter loaded!")
+
+    # Set the model to the right device
+    model.to(distributed_state.device)
+
+    if args["fp16"]:
+        model.half()
+
+def chat():
+    phrase = input("Input Prompt >>")
+
+    if tokenizer is None or model is None or zh_en_translator is None or nlp is None:
+        load_model_and_components()
+
+    messages = generate_prompt(
+        prompt_text=phrase,
+        args=args,
+        zh_en_translator=zh_en_translator,
+        nlp=nlp,
+        model=model,
+        tokenizer=tokenizer,
+        distributed_state=distributed_state,
+    )
+
+    for n, m in enumerate(messages):
+        print(f"-----generated sequence {n} -----")
+        print(m)
+        print("*"*60)
+
+
+
+if __name__ == '__main__':
+    load_model_and_components()
+    while True:
+        chat()
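
The script assumes MODEL_CLASSES in gpt2_generation.py maps a model-type string to a (model class, tokenizer class) pair, in the style of the transformers run_generation example. The exact dict is not shown in this commit; for the "gpt2" path it would presumably look like:

from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Presumed gpt2 entry; the actual dict in the repo may carry other model types.
MODEL_CLASSES = {
    "gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
}

model_class, tokenizer_class = MODEL_CLASSES["gpt2"]
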

gpt2_generation.py
CHANGED
@@ -2,6 +2,7 @@
 # coding=utf-8
 import inspect
 import logging
+import nltk
 from typing import Tuple
 
 import torch
@@ -261,6 +262,26 @@ class _ModelFallbackWrapper(GenerationMixin):
         return self._default._reorder_cache(past_key_values, beam_idx)
 
 
+def remove_tokens_before_copula(text):
+    sentences = text.split(",")
+    result = [sentences[0]]
+    for sentence in sentences[1:]:
+        tokens = nltk.word_tokenize(sentence)
+
+        target_indices = [i for i, token in enumerate(tokens) if token.lower() in ["is", "are", "am"]]
+
+        if target_indices:
+            last_target_index = target_indices[-1]
+            result.append(tokens[last_target_index + 1:])
+        else:
+            result.append(tokens)
+
+    all_sentences = [" ".join(sen) for sen in result[1:]]
+    all_sentences.insert(0, result[0])
+    result_text = ",".join(all_sentences)
+    return result_text
+
+
 def generate_prompt(
     prompt_text,
     args,
@@ -326,6 +347,7 @@ def generate_prompt(
         repeat_gen_time = repeat_gen_time + 1
         generated_sequence = model.generate(
             input_ids=input_ids,
+            length_penalty=args["length_penalty"],
             max_length=args["length"] + len(encoded_prompt[0]),
             temperature=args["temperature"],
             top_k=args["k"],
@@ -352,13 +374,16 @@ def generate_prompt(
             prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
         )
         # no checking for prompt_text.
-        docs = nlp(text)
-        nouns = [token.text for token in docs if token.pos_ == 'NOUN']
-        nouns = set(nouns)
-        if nouns.intersection(FORBIDDEN_NOUN) and repeat_gen_time < 10:
-            continue
-        else:
-            break
+        # temporarily drop the keyword check
+        # docs = nlp(text)
+        # nouns = [token.text for token in docs if token.pos_ == 'NOUN']
+        # nouns = set(nouns)
+        # if nouns.intersection(FORBIDDEN_NOUN) and repeat_gen_time < 10:
+        #     continue
+        # else:
+        #     break
+        break
+    total_sequence = remove_tokens_before_copula(total_sequence)
     generated_sequences.append(total_sequence)
 
     return generated_sequences
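
Traced on a made-up generation, remove_tokens_before_copula keeps the first comma-separated chunk intact and, in each later chunk, drops everything up to the last "is"/"are"/"am". Note that nltk.word_tokenize needs the punkt data, and the rejoin puts single spaces between tokens but none after commas:

import nltk
nltk.download("punkt", quiet=True)  # tokenizer data used by word_tokenize

from gpt2_generation import remove_tokens_before_copula

text = "modern living room, the walls are painted white, there is a sofa"
print(remove_tokens_before_copula(text))
# -> "modern living room,painted white,a sofa"
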

gpt_api.py
ADDED
@@ -0,0 +1,27 @@
+import openai
+
+
+def get_response_create_data(cn_text):
+    openai.api_type = "azure"
+    openai.api_base = "https://poster-pku-gpt4.openai.azure.com/"
+    openai.api_version = "2023-07-01-preview"
+    openai.api_key = '788c2b57f1954ddc92bb27786fbcdd6e'
+
+    response = openai.ChatCompletion.create(
+        engine="dragon",
+        messages=[{"role": "system", "content": "Now you are a home improvement designer,\
+            I give you some keywords, generate a brief interior design in English, no more than words: "},
+                  {"role": "user", "content": cn_text}],
+        temperature=0.7,
+        max_tokens=800,
+        top_p=0.95,
+        frequency_penalty=0,
+        presence_penalty=0,
+        stop=None)
+    return response['choices'][0]["message"]["content"]
+
+
+if __name__ == '__main__':
+    while (1):
+        input_text = input("输入:")
+        get_response_create_data(input_text)
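
gpt_api.py talks to an Azure OpenAI deployment through the pre-1.0 openai SDK (module-level config plus openai.ChatCompletion with engine=). For reference, the equivalent call with the openai>=1.0 SDK would go through AzureOpenAI; this sketch assumes the same endpoint and deployment name, with the key left as a placeholder:

from openai import AzureOpenAI

client = AzureOpenAI(
    azure_endpoint="https://poster-pku-gpt4.openai.azure.com/",
    api_version="2023-07-01-preview",
    api_key="<api-key>",  # placeholder; the repo hardcodes its key
)
resp = client.chat.completions.create(
    model="dragon",  # Azure deployment name, passed as engine= in the legacy SDK
    messages=[{"role": "user", "content": "minimalist bedroom, wood tones"}],
    temperature=0.7,
    max_tokens=800,
)
print(resp.choices[0].message.content)
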

monitor.sh
ADDED
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+while true; do
+
+    seed=$(date +%s)
+
+    python trible.py ${seed}
+
+    if [ $? -eq 0 ]; then
+        echo "program complete, no need to restart..."
+        break
+    else
+        echo "program crashed, restarting"
+    fi
+done

sft.py
CHANGED
@@ -14,7 +14,7 @@ from utils import (
     get_dict_dataset,
     get_advance_dataset,)
 
-base_model = "…
+base_model = "gpt2"
 tokenizer, model = get_tok_and_model(f"./models/{base_model}")
 tokenizer.pad_token = tokenizer.eos_token
 rouge = evaluate.load("rouge")
@@ -53,18 +53,16 @@ print(f"data tokenize done. process time : {t2 - t1}")
 
 
 training_args = TrainingArguments(
-    output_dir=f"./output/{base_model}…
+    output_dir=f"./output/{base_model}_openprompt",
     evaluation_strategy="steps",
     eval_steps=20000,
-    learning_rate=…
+    learning_rate=3e-5,
     lr_scheduler_type="constant",
     report_to="tensorboard",
     per_device_train_batch_size=64,
     per_device_eval_batch_size=32,
-    adam_beta1=0.9,
-    adam_beta2=0.98,
     save_total_limit=1,
-    num_train_epochs=…
+    num_train_epochs=60,
     fp16=True,
     push_to_hub=False,
 )
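
With the explicit adam_beta1/adam_beta2 overrides removed, training falls back to the TrainingArguments defaults, so beta2 moves from the previously pinned 0.98 back to 0.999. A quick check:

from transformers import TrainingArguments

# Defaults that now apply after the overrides were dropped.
args = TrainingArguments(output_dir="./output/gpt2_openprompt")
print(args.adam_beta1, args.adam_beta2)  # 0.9 0.999
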

trible.py
ADDED
@@ -0,0 +1,56 @@
+import os
+import click
+import random
+import jsonlines
+
+from tqdm import tqdm
+from gpt_api import get_response_create_data
+
+
+KEYWORDS_PATH = "/data/aigc/zw/task2/pg_distilgpt/data/raw_keywords.txt"
+TARGET_PATH = "/data/aigc/zw/task2/pg_distilgpt/data/raw_discriptions.jsonl"
+
+if not os.path.exists(TARGET_PATH):
+    with open(TARGET_PATH, "w") as f:
+        pass
+
+
+def read_keywords(path=KEYWORDS_PATH):
+
+    keywords = []
+
+    with open(path, 'r', encoding='utf-8') as file:
+        for line in tqdm(file, desc="reading keywords"):
+            parts = line.strip().split('\t')
+            result = parts[0]
+            keywords.append(result)
+
+    return keywords
+
+def keywords_sampler(num, key_words):
+    random.seed()
+    while(1):
+        sampled_words = random.sample(key_words, num)
+        yield sampled_words
+
+def create_data(keywords, total_num=10000, n=4, seed=42):
+    random.seed(seed)
+    for n, key_words in tqdm(enumerate(keywords_sampler(n, keywords)), desc="generating data"):
+
+        res = get_response_create_data(" ".join(key_words))
+
+        with jsonlines.open(TARGET_PATH, mode='a') as writer:
+            writer.write({"keywrods": key_words, "description": res})
+
+        if n >= total_num:
+            print("generation data done.")
+            break
+
+@click.command()
+@click.argument('seed', type=int)
+def main(seed):
+    keywords = read_keywords()
+    create_data(keywords, seed=seed)
+
+if __name__ == '__main__':
+    main()
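
One subtlety worth noting: create_data seeds the module RNG with the CLI seed, but keywords_sampler immediately re-seeds from OS entropy via random.seed() with no argument, so the draws are effectively non-deterministic. That works in monitor.sh's favor, since each restart samples fresh keyword combinations instead of repeating the ones already written. (The loop also rebinds n as a counter, but keywords_sampler(n, keywords) is constructed before the loop starts, so the sample size stays at 4.) A small sketch of the seeding semantics:

import random

random.seed(42)   # what create_data(seed=...) sets
random.seed()     # what keywords_sampler() then does: re-seed from OS entropy
print(random.sample(range(100), 4))  # differs from run to run despite seed=42
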