import json
import re

import fire
from tqdm import tqdm


def tokenize_caption(input_json: str,
                     keep_punctuation: bool = False,
                     host_address: str = None,
                     character_level: bool = False,
                     zh: bool = True,
                     output_json: str = None):
    """Tokenize the captions in a preprocessed json file, adding a
    "tokens" field to every caption and writing the result back to disk.

    Args:
        input_json (str): Preprocessed json file. Structure like this:
            {
                'audios': [
                    {
                        'audio_id': 'xxx',
                        'captions': [
                            {
                                'caption': 'xxx',
                                'cap_id': 'xxx'
                            }
                        ]
                    },
                    ...
                ]
            }
        keep_punctuation (bool): Include or strip punctuation (only used
            when `zh` is True).
        host_address (str): Address of a running Stanford CoreNLP server,
            e.g. "http://localhost:9000" (only used when `zh` is True).
        character_level (bool): Split Chinese captions into single
            characters instead of words.
        zh (bool): Whether the captions are in Chinese.
        output_json (str): File to write the tokenized data to. If not
            given, `input_json` is overwritten in place.
    """
    data = json.load(open(input_json, "r"))["audios"]

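    # Tokenize each caption and store the result as a space-separated
    # token string under a new "tokens" key.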
    if zh:
        # Chinese: tokenize with a Stanford CoreNLP server, optionally
        # stripping punctuation and/or splitting into single characters.
        from nltk.parse.corenlp import CoreNLPParser
        from zhon.hanzi import punctuation
        parser = CoreNLPParser(host_address)
        for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
            for cap_idx in range(len(data[audio_idx]["captions"])):
                caption = data[audio_idx]["captions"][cap_idx]["caption"]
                if not keep_punctuation:
                    caption = re.sub("[{}]".format(punctuation), "", caption)
                if character_level:
                    tokens = list(caption)
                else:
                    tokens = list(parser.tokenize(caption))
                data[audio_idx]["captions"][cap_idx]["tokens"] = " ".join(tokens)
    else:
        # English: tokenize with the PTB tokenizer from pycocoevalcap,
        # which expects {key: [{"caption": ...}, ...]} as input.
        from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
        captions = {}
        for audio_idx in range(len(data)):
            audio_id = data[audio_idx]["audio_id"]
            captions[audio_id] = []
            for cap_idx in range(len(data[audio_idx]["captions"])):
                caption = data[audio_idx]["captions"][cap_idx]["caption"]
                captions[audio_id].append({
                    "audio_id": audio_id,
                    "id": cap_idx,
                    "caption": caption
                })
        tokenizer = PTBTokenizer()
        captions = tokenizer.tokenize(captions)
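        # PTBTokenizer.tokenize returns, for each audio_id, a list of
        # tokenized caption strings in insertion order, so cap_idx below
        # indexes the tokenized form of the same caption.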
        for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
            audio_id = data[audio_idx]["audio_id"]
            for cap_idx in range(len(data[audio_idx]["captions"])):
                tokens = captions[audio_id][cap_idx]
                data[audio_idx]["captions"][cap_idx]["tokens"] = tokens

    # Write back to output_json, or overwrite input_json if none is given.
    # ensure_ascii is disabled for Chinese so the characters stay readable.
    if output_json is None:
        output_json = input_json
    json.dump({"audios": data}, open(output_json, "w"),
              indent=4, ensure_ascii=not zh)


if __name__ == "__main__":
    fire.Fire(tokenize_caption)
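

# Example invocation (hypothetical script name and paths); python-fire
# exposes the function arguments as command line flags:
#   python tokenize_caption.py --input_json data/caption.json \
#       --host_address http://localhost:9000 --output_json data/caption_tok.json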