File size: 2,383 Bytes
fa3faa9
 
b829268
fa3faa9
 
 
 
045f1d7
 
fa3faa9
7def859
 
 
 
045f1d7
b829268
fa3faa9
7def859
045f1d7
fa3faa9
 
045f1d7
fa3faa9
 
 
 
045f1d7
fa3faa9
582d6c9
fa3faa9
045f1d7
fa3faa9
7def859
045f1d7
fa3faa9
582d6c9
fa3faa9
045f1d7
 
 
fa3faa9
b829268
 
 
 
fa3faa9
 
 
 
 
 
045f1d7
fa3faa9
 
be41edb
 
 
d26fe45
be41edb
 
 
 
 
 
 
 
 
d26fe45
 
 
 
be41edb
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import csv
import pickle
from random import randint

### NOTICE: CSV input must have two columns (term, translation). The same term may appear on several rows to accumulate multiple translations.


# 1_2_3, 1 is action, 2 is supply object, 3 is source object
def update_dict_csv(term_dict:dict, f):
    """Merge (term, translation) pairs read from CSV *f* into *term_dict* in place.

    For each row, the first column is lower-cased and used as the key, and the
    second column is added to that key's list of translations unless it is
    already present (duplicates are reported with ``print``). Extra columns are
    ignored; blank rows and rows with fewer than two columns are skipped.
    After merging, *term_dict* is reordered in place so that longer keys come
    first (same ordering as ``sort_dict``).

    Args:
        term_dict: dict mapping term -> list of translations; mutated in place.
        f: iterable of CSV text lines, typically a file opened with newline=''.
    """
    for row in csv.reader(f):
        if len(row) < 2:
            # csv.reader yields [] for blank lines; indexing it would raise.
            continue
        word = row[0].lower()
        translation = row[1]
        if word not in term_dict:
            term_dict[word] = [translation]
        elif translation in term_dict[word]:
            # "已存在" = "already exists".
            print("{},{} 已存在".format(word, translation))
        else:
            # Build a new list instead of appending: value lists may be shared
            # between keys (form_dict maps every term in a row to the same list),
            # so mutating one would silently change the others.
            term_dict[word] = term_dict[word] + [translation]

    # Reorder the caller's dict itself. Rebinding the local name to a sorted
    # copy (as before) left the caller's dict unsorted.
    ordered = sorted(term_dict.items(), key=lambda item: len(item[0]), reverse=True)
    term_dict.clear()
    term_dict.update(ordered)

def export_csv_dict(term_dict:dict, f):
    """Write *term_dict* to the text file object *f* as CSV, one row per key.

    Each row is ``[key, value]``. csv.writer stringifies non-string values with
    ``str()``, so a list of translations is written as its repr, e.g.
    ``"['x', 'y']"``. NOTE(review): that is not the two-column, one-translation-
    per-row format update_dict_csv reads back; confirm whether this is intended.

    Args:
        term_dict: mapping of term -> translations.
        f: writable text file object; it should be opened with newline='' as
           the csv module requires.
    """
    # One writer for the whole export instead of one per row.
    writer = csv.writer(f)
    writer.writerows([key, val] for key, val in term_dict.items())

def save_pickle_dict(term_dict:dict, f):
    """Serialize *term_dict* into the binary file object *f*.

    Uses the highest pickle protocol available to this interpreter.
    """
    newest_protocol = pickle.HIGHEST_PROTOCOL
    pickle.dump(term_dict, f, protocol=newest_protocol)

def update_pickel_csv(pickle_f, csv_f):
    """Merge CSV term pairs into a pickled term dict and rewrite the pickle in place.

    Args:
        pickle_f: binary file object opened for both reading and writing
            (e.g. mode 'rb+'), positioned at the start of the pickled dict.
        csv_f: CSV text source passed to update_dict_csv.

    SECURITY: pickle.load can execute arbitrary code. Only use this on pickle
    files you trust.
    """
    start = pickle_f.tell()
    term_dict = pickle.load(pickle_f)
    update_dict_csv(term_dict, csv_f)
    # pickle.load leaves the stream just past the old pickle. Dumping from
    # there would append a second pickle that a later pickle.load never reads,
    # so rewind to where the old pickle began.
    pickle_f.seek(start)
    # Highest protocol for better performance.
    pickle.dump(term_dict, pickle_f, pickle.HIGHEST_PROTOCOL)
    # Drop leftover bytes of the old pickle if the new one is shorter.
    pickle_f.truncate()

def sort_dict(term_dict:dict):
    """Return a new plain dict with the same entries, longest keys first.

    Keys of equal length keep their original relative order (the sort is stable).
    """
    pairs = list(term_dict.items())
    pairs.sort(key=lambda pair: len(pair[0]), reverse=True)
    return dict(pairs)

def get_word(term_dict:dict, key:str) -> str:
    """Return one entry of ``term_dict[key]`` chosen uniformly at random.

    Raises KeyError if *key* is absent.
    """
    options = term_dict[key]
    last_index = len(options) - 1
    return options[randint(0, last_index)]

# Demo: build a dict from the en->zh CSV and export it again.
# newline='' is required by the csv module for files it reads and writes.
# Without it, written rows end in "\r\r\n" on Windows and quoted fields
# containing newlines are mis-parsed on read.
term_dict_sc2 = {}
with open("./finetune_data/dict_enzh.csv", 'r', encoding='utf-8', newline='') as f:
    update_dict_csv(term_dict_sc2, f)

with open("../test.csv", "w", encoding='utf-8', newline='') as w:
    export_csv_dict(term_dict_sc2, w)

## To load a saved pickle, just call:
# pickle.load(f)  # with f opened in binary mode ('rb')


def form_dict(src_dict:list, tgt_dict:list) -> dict:
    """Map every term in source row *i* to target row ``tgt_dict[i]``.

    Args:
        src_dict: list of rows, each a list of source-language terms.
        tgt_dict: list of rows aligned by index with *src_dict*.

    Returns:
        dict of term -> target row. If a term appears in several rows, the
        later row wins. Raises IndexError if a non-empty source row has no
        matching target row.
    """
    return {
        term: tgt_dict[row_index]
        for row_index, terms in enumerate(src_dict)
        for term in terms
    }


class term_dict(dict):
    """Term dictionary built from two index-aligned CSV files.

    Row *i* of ``{path}/{src_lang}.csv`` holds source-language terms, and row
    *i* of ``{path}/{tgt_lang}.csv`` holds the corresponding target-language
    terms. Each source term maps to the whole target row (a list of strings);
    see ``form_dict``.
    """

    # Private sentinel so get() can tell "no default given" apart from default=None.
    _MISSING = object()

    def __init__(self, path, src_lang, tgt_lang) -> None:
        # newline='' is what the csv module expects for files it reads.
        with open(f"{path}/{src_lang}.csv", 'r', encoding="utf-8", newline="") as file:
            src_dict = list(csv.reader(file, delimiter=","))
        with open(f"{path}/{tgt_lang}.csv", 'r', encoding="utf-8", newline="") as file:
            tgt_dict = list(csv.reader(file, delimiter=","))
        super().__init__(form_dict(src_dict, tgt_dict))

    def get(self, key:str, default=_MISSING) -> str:
        """Return a random translation of *key*.

        Unlike ``dict.get``, a missing key raises KeyError when no *default* is
        passed, which preserves this class's original behavior. Passing
        *default* gives ``dict.get``-style fallback.
        """
        try:
            candidates = self[key]
        except KeyError:
            if default is self._MISSING:
                raise
            return default
        return candidates[randint(0, len(candidates) - 1)]