import json |
from tqdm import tqdm |
import re |
import fire |
def tokenize_caption(input_json: str,
                     keep_punctuation: bool = False,
                     host_address: str = None,
                     character_level: bool = False,
                     zh: bool = True,
                     output_json: str = None):
    """Tokenize every caption in a caption json file and store the result.

    Each caption entry gains a ``"tokens"`` field, after which the whole
    structure is written back as json.

    Args:
        input_json (str): Preprocessed json file. Structure like this:
            {
                'audios': [
                    {
                        'audio_id': 'xxx',
                        'captions': [
                            {
                                'caption': 'xxx',
                                'cap_id': 'xxx'
                            }
                        ]
                    },
                    ...
                ]
            }
        keep_punctuation (bool): When False, characters listed in
            ``zhon.hanzi.punctuation`` are removed before tokenizing.
            Only consulted when ``zh`` is True.
        host_address (str): Address handed to ``CoreNLPParser``. Only used
            when ``zh`` is True and ``character_level`` is False. When left
            as None the parser is built with its own default address.
        character_level (bool): When True (and ``zh`` is True) each
            character of the caption becomes one token and no parser is
            needed.
        zh (bool): True selects the CoreNLP/character branch, False selects
            the ``PTBTokenizer`` branch from ``pycocoevalcap``.
        output_json (str): Destination file. When not given, ``input_json``
            is overwritten in place.

    Returns:
        None. The result is written to ``output_json`` (or ``input_json``).
    """
    # Context managers make sure the handles are closed (and flushed).
    with open(input_json, "r") as reader:
        data = json.load(reader)["audios"]

    if zh:
        # Third-party modules are imported only when the chosen options
        # actually need them.
        punct_pattern = None
        if not keep_punctuation:
            from zhon.hanzi import punctuation
            # Compile once instead of rebuilding the pattern per caption.
            punct_pattern = re.compile("[{}]".format(punctuation))

        parser = None
        if not character_level:
            from nltk.parse.corenlp import CoreNLPParser
            # Passing None would override the parser's built-in default
            # address, so only forward an address that was really given.
            if host_address is None:
                parser = CoreNLPParser()
            else:
                parser = CoreNLPParser(host_address)

        for audio in tqdm(data, leave=False, ascii=True):
            for cap_item in audio["captions"]:
                caption = cap_item["caption"]
                if punct_pattern is not None:
                    caption = punct_pattern.sub("", caption)
                if character_level:
                    tokens = list(caption)
                else:
                    tokens = list(parser.tokenize(caption))
                cap_item["tokens"] = " ".join(tokens)
    else:
        from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer

        # Collect all captions first so the tokenizer is invoked once.
        captions = {}
        for audio in data:
            audio_id = audio["audio_id"]
            captions[audio_id] = [
                {
                    "audio_id": audio_id,
                    "id": cap_idx,
                    "caption": cap_item["caption"],
                }
                for cap_idx, cap_item in enumerate(audio["captions"])
            ]

        tokenizer = PTBTokenizer()
        captions = tokenizer.tokenize(captions)

        # The tokenizer output is looked up by audio id, then by the
        # caption's position within that audio.
        for audio in tqdm(data, leave=False, ascii=True):
            tokenized = captions[audio["audio_id"]]
            for cap_idx, cap_item in enumerate(audio["captions"]):
                cap_item["tokens"] = tokenized[cap_idx]

    # Without an explicit destination the input file is overwritten.
    target_path = output_json if output_json else input_json
    with open(target_path, "w") as writer:
        # Non-ASCII characters are written unescaped for the zh branch.
        json.dump({"audios": data}, writer, indent=4, ensure_ascii=not zh)
# Command-line entry point: python-fire turns the function's parameters
# into command-line arguments.
if __name__ == "__main__":
    fire.Fire(tokenize_caption)