Upload 396 files
This view is limited to 50 files because it contains too many changes.
- .gitattributes +22 -0
- fengshen/API/main.py +76 -0
- fengshen/API/text_classification.json +46 -0
- fengshen/API/utils.py +167 -0
- fengshen/README.md +105 -0
- fengshen/__init__.py +19 -0
- fengshen/cli/fengshen_pipeline.py +34 -0
- fengshen/data/__init__.py +1 -0
- fengshen/data/bert_dataloader/auto_split.sh +10 -0
- fengshen/data/bert_dataloader/load.py +200 -0
- fengshen/data/bert_dataloader/preprocessing.py +110 -0
- fengshen/data/clip_dataloader/flickr.py +105 -0
- fengshen/data/data_utils/common_utils.py +4 -0
- fengshen/data/data_utils/mask_utils.py +285 -0
- fengshen/data/data_utils/sentence_split.py +35 -0
- fengshen/data/data_utils/sop_utils.py +32 -0
- fengshen/data/data_utils/token_type_utils.py +25 -0
- fengshen/data/data_utils/truncate_utils.py +19 -0
- fengshen/data/dreambooth_datasets/dreambooth_datasets.py +183 -0
- fengshen/data/hubert/hubert_dataset.py +361 -0
- fengshen/data/megatron_dataloader/Makefile +9 -0
- fengshen/data/megatron_dataloader/__init__.py +1 -0
- fengshen/data/megatron_dataloader/bart_dataset.py +443 -0
- fengshen/data/megatron_dataloader/bert_dataset.py +196 -0
- fengshen/data/megatron_dataloader/blendable_dataset.py +64 -0
- fengshen/data/megatron_dataloader/dataset_utils.py +788 -0
- fengshen/data/megatron_dataloader/helpers.cpp +794 -0
- fengshen/data/megatron_dataloader/indexed_dataset.py +585 -0
- fengshen/data/megatron_dataloader/utils.py +24 -0
- fengshen/data/mmap_dataloader/mmap_datamodule.py +68 -0
- fengshen/data/mmap_dataloader/mmap_index_dataset.py +53 -0
- fengshen/data/preprocess.py +1 -0
- fengshen/data/sequence_tagging_dataloader/sequence_tagging_collator.py +274 -0
- fengshen/data/sequence_tagging_dataloader/sequence_tagging_datasets.py +116 -0
- fengshen/data/t5_dataloader/t5_datasets.py +562 -0
- fengshen/data/t5_dataloader/t5_gen_datasets.py +391 -0
- fengshen/data/taiyi_stable_diffusion_datasets/taiyi_datasets.py +173 -0
- fengshen/data/task_dataloader/__init__.py +3 -0
- fengshen/data/task_dataloader/medicalQADataset.py +137 -0
- fengshen/data/task_dataloader/task_datasets.py +206 -0
- fengshen/data/universal_datamodule/__init__.py +4 -0
- fengshen/data/universal_datamodule/universal_datamodule.py +165 -0
- fengshen/data/universal_datamodule/universal_sampler.py +125 -0
- fengshen/examples/DAVAE/generate.py +36 -0
- fengshen/examples/FastDemo/README.md +105 -0
- fengshen/examples/FastDemo/YuyuanQA.py +71 -0
- fengshen/examples/FastDemo/image/demo.png +0 -0
- fengshen/examples/GAVAE/generate.py +23 -0
- fengshen/examples/PPVAE/generate.py +24 -0
- fengshen/examples/classification/demo_classification_afqmc_erlangshen_offload.sh +103 -0
.gitattributes
CHANGED
```
@@ -32,3 +32,25 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+fengshen/examples/finetune_taiyi_stable_diffusion/demo_dataset/part_0/00000003.jpg filter=lfs diff=lfs merge=lfs -text
+fengshen/examples/stable_diffusion_chinese_EN/result_examples/cat_eating_guoqiao_noodle.png filter=lfs diff=lfs merge=lfs -text
+fengshen/examples/stable_diffusion_chinese_EN/result_examples/huskiy_wearing_space_suit.png filter=lfs diff=lfs merge=lfs -text
+fengshen/examples/stable_diffusion_chinese_EN/result_examples/xiaoqiao_oil_painting.png filter=lfs diff=lfs merge=lfs -text
+fengshen/examples/stable_diffusion_chinese_EN/result_examples/xiaoqiao_vangogh.png filter=lfs diff=lfs merge=lfs -text
+fengshen/examples/stable_diffusion_chinese/img/日出,海面上英文逗号.png filter=lfs diff=lfs merge=lfs -text
+fengshen/examples/stable_diffusion_chinese/img/日出,海面上英文逗号4k壁纸.png filter=lfs diff=lfs merge=lfs -text
+fengshen/examples/stable_diffusion_chinese/img/日出,海面上英文逗号4k壁纸384.png filter=lfs diff=lfs merge=lfs -text
+fengshen/examples/stable_diffusion_chinese/img/日出,海面上英文逗号4k壁纸复杂.png filter=lfs diff=lfs merge=lfs -text
+fengshen/examples/stable_diffusion_chinese/img/日出,海面上英文逗号4k壁纸高清.png filter=lfs diff=lfs merge=lfs -text
+fengshen/examples/stable_diffusion_chinese/img/日出,海面上英文逗号4k壁纸精细.png filter=lfs diff=lfs merge=lfs -text
+fengshen/examples/stable_diffusion_chinese/img/日出,海面上英文逗号插画.png filter=lfs diff=lfs merge=lfs -text
+fengshen/examples/stable_diffusion_chinese/img/日出,海面上英文逗号水彩.png filter=lfs diff=lfs merge=lfs -text
+fengshen/examples/stable_diffusion_chinese/img/日出,海面上英文逗号素描.png filter=lfs diff=lfs merge=lfs -text
+fengshen/examples/stable_diffusion_chinese/img/日出,海面上英文逗号油画.png filter=lfs diff=lfs merge=lfs -text
+fengshen/examples/stable_diffusion_chinese/img/日出,海面上中文逗号.png filter=lfs diff=lfs merge=lfs -text
+fengshen/examples/stable_diffusion_chinese/img/日出,海面上中文感叹号.png filter=lfs diff=lfs merge=lfs -text
+fengshen/examples/stable_diffusion_chinese/img/日出,海面上中文句号.png filter=lfs diff=lfs merge=lfs -text
+fengshen/examples/stable_diffusion_chinese/img/日出,海面上nega广告.png filter=lfs diff=lfs merge=lfs -text
+fengshen/examples/stable_diffusion_chinese/img/日出,海面上nega广告符号.png filter=lfs diff=lfs merge=lfs -text
+fengshen/examples/stable_diffusion_chinese/img/日出,海面上nega广告符号词汇.png filter=lfs diff=lfs merge=lfs -text
+fengshen/examples/stable_diffusion_dreambooth/duck_result.png filter=lfs diff=lfs merge=lfs -text
```
fengshen/API/main.py
ADDED
@@ -0,0 +1,76 @@
```python
import uvicorn
import click
import argparse
import json
from importlib import import_module
from fastapi import FastAPI, WebSocket
from starlette.middleware.cors import CORSMiddleware
from utils import user_config, api_logger, setup_logger, RequestDataStructure

# The command line takes a single argument: the name of the config file, e.g. text_classification.json.
# Every other setting is defined in that config file rather than on the command line.
total_parser = argparse.ArgumentParser("API")
total_parser.add_argument("config_path", type=str)
args = total_parser.parse_args()

# set up user config
user_config.setup_config(args)

# set up logger
setup_logger(api_logger, user_config)

# load pipeline
pipeline_class = getattr(import_module('fengshen.pipelines.' + user_config.pipeline_type), 'Pipeline')
model_settings = user_config.model_settings
model_args = argparse.Namespace(**model_settings)
pipeline = pipeline_class(
    args=model_args,
    model=user_config.model_name
)

# initialize app
app = FastAPI(
    title=user_config.PROJECT_NAME,
    openapi_url=f"{user_config.API_PREFIX_STR}/openapi.json"
)

# api
# TODO
# Different request methods still need their own handling; only the fairly generic POST method works so far.
# POST covers most "text in, result out" request patterns.
if user_config.API_method == "POST":
    @app.post(user_config.API_path, tags=user_config.API_tags)
    async def fengshen_post(data: RequestDataStructure):
        # logging
        api_logger.info(data.input_text)

        input_text = data.input_text
        result = pipeline(input_text)
        return result
else:
    print("only support POST method")

# Set all CORS enabled origins
if user_config.BACKEND_CORS_ORIGINS:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=[str(origin) for origin in user_config.BACKEND_CORS_ORIGINS],
        allow_credentials=user_config.allow_credentials,
        allow_methods=user_config.allow_methods,
        allow_headers=user_config.allow_headers,
    )

if __name__ == '__main__':
    # After startup, open host:port/docs in a browser to inspect the endpoints and run quick tests,
    # e.g. 127.0.0.1:8990/docs
    uvicorn.run(app, host=user_config.SERVER_HOST, port=user_config.SERVER_PORT)
```
fengshen/API/text_classification.json
ADDED
@@ -0,0 +1,46 @@
```json
{
    "SERVER": {
        "SERVER_HOST": "127.0.0.1",
        "SERVER_PORT": 8990,
        "SERVER_NAME": "fengshen_demo",
        "PROJECT_NAME": "fengshen_demo",
        "API_PREFIX_STR": "/api",

        "API_method": "POST",
        "API_path": "/TextClassification",
        "API_tags": ["TextClassification"],

        "BACKEND_CORS_ORIGINS": ["*"],
        "allow_credentials": true,
        "allow_methods": ["*"],
        "allow_headers": ["*"]
    },
    "LOGGING": {
        "log_file_path": "",
        "log_level": "INFO"
    },

    "PIPELINE": {
        "pipeline_type": "text_classification",
        "model_name": "IDEA-CCNL/Erlangshen-Roberta-110M-Similarity",
        "model_settings": {
            "device": -1,
            "texta_name": "sentence",
            "textb_name": "sentence2",
            "label_name": "label",
            "max_length": 512,
            "return_tensors": "pt",
            "padding": "longest",
            "truncation": true,
            "skip_special_tokens": true,
            "clean_up_tkenization_spaces": true,

            "skip_steps": 10,
            "clip_guidance_scale": 7500,
            "init_scale": 10
        }
    }
}
```
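A minimal client sketch for the service defined by the two files above, assuming the API was started with `python fengshen/API/main.py text_classification.json` and is reachable at the default host, port, and route from that config. The request fields follow `RequestDataStructure` in `utils.py` below; the `requests` dependency and the example sentences are illustrative, not part of the repository.

```python
import requests

# Host, port, and route taken from text_classification.json above (assuming the defaults are unchanged).
URL = "http://127.0.0.1:8990/TextClassification"

# input_text is a list of strings, matching RequestDataStructure in fengshen/API/utils.py.
payload = {"input_text": ["今天天气真不错", "今天天气很好"], "uuid": 1}

response = requests.post(URL, json=payload)
print(response.status_code)
print(response.json())  # the pipeline's prediction for the submitted text
```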
fengshen/API/utils.py
ADDED
@@ -0,0 +1,167 @@
```python
from dataclasses import dataclass, field
import os
import json
import logging
from argparse import Namespace
from typing import List, Literal, Optional, Union
from pydantic import AnyHttpUrl, BaseSettings, HttpUrl, validator, BaseModel


CURRENT_DIR_PATH = os.path.dirname(os.path.abspath(__file__))

# request body
# validated with pydantic
class RequestDataStructure(BaseModel):
    input_text: List[str] = [""]
    uuid: Optional[int]

    # parameters for text2image model
    input_image: Optional[str]
    skip_steps: Optional[int]
    clip_guidance_scale: Optional[int]
    init_scale: Optional[int]

# API config
@dataclass
class APIConfig:

    # server config
    SERVER_HOST: AnyHttpUrl = "127.0.0.1"
    SERVER_PORT: int = 8990
    SERVER_NAME: str = ""
    PROJECT_NAME: str = ""
    API_PREFIX_STR: str = "/api"

    # api config
    API_method: Literal["POST", "GET", "PUT", "OPTIONS", "WEBSOCKET", "PATCH", "DELETE", "TRACE", "CONNECT"] = "POST"
    API_path: str = "/TextClassification"
    API_tags: List[str] = field(default_factory=lambda: [""])

    # CORS config
    BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = field(default_factory=lambda: ["*"])
    allow_credentials: bool = True
    allow_methods: List[str] = field(default_factory=lambda: ["*"])
    allow_headers: List[str] = field(default_factory=lambda: ["*"])

    # log config
    log_file_path: str = ""
    log_level: str = "INFO"

    # pipeline config
    pipeline_type: str = ""
    model_name: str = ""

    # model config
    # device: int = -1
    # texta_name: Optional[str] = "sentence"
    # textb_name: Optional[str] = "sentence2"
    # label_name: Optional[str] = "label"
    # max_length: int = 512
    # return_tensors: str = "pt"
    # padding: str = "longest"
    # truncation: bool = True
    # skip_special_tokens: bool = True
    # clean_up_tkenization_spaces: bool = True

    # # parameters for text2image model
    # skip_steps: Optional[int] = 0
    # clip_guidance_scale: Optional[int] = 0
    # init_scale: Optional[int] = 0

    def setup_config(self, args: Namespace) -> None:

        # load config file
        with open(CURRENT_DIR_PATH + "/" + args.config_path, "r") as jsonfile:
            config = json.load(jsonfile)

        server_config = config["SERVER"]
        logging_config = config["LOGGING"]
        pipeline_config = config["PIPELINE"]

        # server config
        self.SERVER_HOST: AnyHttpUrl = server_config["SERVER_HOST"]
        self.SERVER_PORT: int = server_config["SERVER_PORT"]
        self.SERVER_NAME: str = server_config["SERVER_NAME"]
        self.PROJECT_NAME: str = server_config["PROJECT_NAME"]
        self.API_PREFIX_STR: str = server_config["API_PREFIX_STR"]

        # api config
        self.API_method: Literal["POST", "GET", "PUT", "OPTIONS", "WEBSOCKET", "PATCH", "DELETE", "TRACE", "CONNECT"] = server_config["API_method"]
        self.API_path: str = server_config["API_path"]
        self.API_tags: List[str] = server_config["API_tags"]

        # CORS config
        self.BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = server_config["BACKEND_CORS_ORIGINS"]
        self.allow_credentials: bool = server_config["allow_credentials"]
        self.allow_methods: List[str] = server_config["allow_methods"]
        self.allow_headers: List[str] = server_config["allow_headers"]

        # log config
        self.log_file_path: str = logging_config["log_file_path"]
        self.log_level: str = logging_config["log_level"]

        # pipeline config
        self.pipeline_type: str = pipeline_config["pipeline_type"]
        self.model_name: str = pipeline_config["model_name"]

        # general model config
        self.model_settings: dict = pipeline_config["model_settings"]

        # Since the pipeline parses its own arguments, the per-field settings below are no longer needed:
        # converting the model_settings dict to a Namespace and passing it as the pipeline's args is enough.

        # self.device: int = self.model_settings["device"]
        # self.texta_name: Optional[str] = self.model_settings["texta_name"]
        # self.textb_name: Optional[str] = self.model_settings["textb_name"]
        # self.label_name: Optional[str] = self.model_settings["label_name"]
        # self.max_length: int = self.model_settings["max_length"]
        # self.return_tensors: str = self.model_settings["return_tensors"]
        # self.padding: str = self.model_settings["padding"]
        # self.truncation: bool = self.model_settings["truncation"]
        # self.skip_special_tokens: bool = self.model_settings["skip_special_tokens"]
        # self.clean_up_tkenization_spaces: bool = self.model_settings["clean_up_tkenization_spaces"]

        # # specific parameters for text2image model
        # self.skip_steps: Optional[int] = self.model_settings["skip_steps"]
        # self.clip_guidance_scale: Optional[int] = self.model_settings["clip_guidance_scale"]
        # self.init_scale: Optional[int] = self.model_settings["init_scale"]


def setup_logger(logger, user_config: APIConfig):

    # default level: INFO
    logger.setLevel(getattr(logging, user_config.log_level, "INFO"))
    ch = logging.StreamHandler()

    if user_config.log_file_path == "":
        fh = logging.FileHandler(filename=CURRENT_DIR_PATH + "/" + user_config.SERVER_NAME + ".log")
    elif ".log" not in user_config.log_file_path[-5:-1]:
        fh = logging.FileHandler(filename=user_config.log_file_path + "/" + user_config.SERVER_NAME + ".log")
    else:
        fh = logging.FileHandler(filename=user_config.log_file_path)

    formatter = logging.Formatter(
        "%(asctime)s - %(module)s - %(funcName)s - line:%(lineno)d - %(levelname)s - %(message)s"
    )

    ch.setFormatter(formatter)
    fh.setFormatter(formatter)
    logger.addHandler(ch)  # export logs to the screen
    logger.addHandler(fh)  # export logs to a file

    return logger

user_config = APIConfig()
api_logger = logging.getLogger()
```
fengshen/README.md
ADDED
@@ -0,0 +1,105 @@
````markdown
## Latest updates

* \[2022.09.13\] [Updated the Erlangshen DeBERTa pre-training code](https://huggingface.co/IDEA-CCNL/Erlangshen-DeBERTa-v2-97M-Chinese)
* \[2022.09.13\] [Updated the Randeng BART pre-training code](https://huggingface.co/IDEA-CCNL/Randeng-BART-139M)
* \[2022.09.13\] [Updated the Erlangshen BERT pre-training code](https://huggingface.co/IDEA-CCNL/Erlangshen-MegatronBert-1.3B)
* \[2022.05.11\] [Updated the Taiyi ViT multimodal models and downstream-task examples](https://fengshenbang-doc.readthedocs.io/zh/latest/docs/太乙系列/Taiyi-vit-87M-D.html)
* \[2022.05.11\] [Updated the BiGan Transformer-XL denoising model and downstream-task examples](https://fengshenbang-doc.readthedocs.io/zh/latest/docs/比干系列/Bigan-Transformer-XL-denoise-1.1B.html)
* \[2022.05.11\] [Updated the Erlangshen downstream-task examples](https://fengshenbang-doc.readthedocs.io/zh/latest/docs/二郎神系列/Erlangshen-Roberta-110M-NLI.html)

# Navigation

- [Navigation](#navigation)
- [Framework overview](#framework-overview)
- [Requirements](#requirements)
- [Project structure](#project-structure)
- [Design approach](#design-approach)
- [Classification downstream tasks](#classification-downstream-tasks)

## Framework overview

The FengShen training framework is a key component of the Fengshenbang open-source large-model project and plays a crucial role in producing and applying large models. FengShen can be used for pre-training on massive data as well as for fine-tuning all kinds of downstream tasks. Fengshenbang focuses on open-sourcing large NLP models, but growing model size creates problems not only for training but also for everyday use. To address both, FengShen draws on the best existing open-source solutions and redesigns the pipeline, so users can pick from the rich set of Fengshenbang pre-trained models and quickly fine-tune downstream tasks with FengShen.

All examples and documentation are available in our [Wiki](https://fengshenbang-doc.readthedocs.io/zh/latest/index.html).
All models can be found on our [Huggingface page](https://huggingface.co/IDEA-CCNL).

With our framework you get:

1. Better performance than native torch, with training speed-ups of up to <font color=#0000FF>**300%**</font>
2. Support for larger models: training and fine-tuning up to the <font color=#0000FF>**tens-of-billions**</font> parameter scale
3. Support for <font color=#0000FF>**TB-scale and larger**</font> datasets, bringing the gains of pre-trained models even to a home machine
4. Rich pre-training and downstream-task examples that start training with a single command
5. Support for different hardware environments, running on CPU, GPU, TPU, and more
6. Built-in mainstream distributed training logic, supporting DDP, Zero Optimizer, and other distributed optimizations without code changes

![avatar](../pics/fengshen_pic.png)

## Requirements

* Python >= 3.8
* torch >= 1.8
* transformers >= 3.2.0
* pytorch-lightning >= 1.5.10

From the Fengshenbang-LM root directory:

    pip install --editable ./

## Project structure

```
├── data                     # multiple data-processing approaches and datasets
│   ├── cbart_dataloader
|   ├── fs_datasets          # wrapper around transformers datasets, adding Chinese datasets (open-sourcing in progress)
|   ├── universal_datamodule # bridges fs_datasets and the lightning datamodule to cut duplicated work
│   ├── megatron_dataloader  # TB-scale dataset processing and training based on Megatron
│   ├── mmap_dataloader      # generic memmap-style data loading
│   └── task_dataloader      # loaders for various downstream tasks
├── examples                 # rich examples covering everything from pre-training to downstream tasks
├── metric                   # metric implementations, with support for user-defined metrics
├── losses                   # custom losses for bespoke needs
├── tokenizer                # custom tokenizers, e.g. our SentencePiece training code
├── models                   # model zoo
│   ├── auto                 # automatic loading of the matching model
│   ├── bart
│   ├── longformer
│   ├── megatron_t5
│   └── roformer
└── utils                    # utility functions
```

## Design approach

The FengShen framework is currently built on Pytorch-Lightning & Transformers. On top of this foundation we keep open-sourcing Chinese pre-trained models, together with rich examples: every Fengshenbang model has corresponding pre-training and downstream-task code.

Development on FengShen generally follows three steps:

1. Wrap the data-processing flow in a pytorch_lightning.LightningDataModule
2. Wrap the model in a pytorch_lightning.LightningModule
3. Configure plugins such as log_monitor, checkpoint_callback, and so on.

For a complete demo, see the Randeng-BART examples -> [docs](https://fengshenbang-doc.readthedocs.io/zh/latest/docs/燃灯系列/BART-139M.html) [code](https://github.com/IDEA-CCNL/Fengshenbang-LM/tree/hf-ds/fengshen/examples/pretrain_bart)

## Classification downstream tasks

Under examples/classification we provide a rich set of classification examples, including three one-click runnable scripts:

* demo_classification_afqmc_roberta.sh — fine-tune roberta with DDP
* demo_classification_afqmc_roberta_deepspeed.sh — fine-tune roberta with deepspeed for faster training
* demo_classification_afqmc_erlangshen_offload.sh — fine-tune our best-performing Erlangshen models with only 7 GB of GPU memory

All of the above use the AFQMC dataset; an introduction to the dataset can be found [here](https://www.cluebenchmarks.com/introduce.html).
Our processed data files are also hosted on Huggingface; click [here](https://huggingface.co/datasets/IDEA-CCNL/AFQMC) to go straight to the source files.
With only light preprocessing into our format, the scripts adapt to other downstream classification tasks.
In the example scripts, only the following arguments need to change to work with local files:

```
--dataset_name IDEA-CCNL/AFQMC \

-------> change to

--data_dir $DATA_DIR \        # data directory
--train_data train.json \     # data files
--valid_data dev.json \
--test_data test.json \

```
````
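The three-step recipe in the README's design section maps directly onto standard pytorch_lightning pieces. The sketch below is illustrative only and not from the repository: `MyDataModule`, `MyModel`, the random tensors, and the tiny linear layer are placeholder assumptions; a real FengShen example would plug in a Fengshenbang model and the universal_datamodule instead.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


# Step 1: wrap data processing in a LightningDataModule
class MyDataModule(pl.LightningDataModule):
    def setup(self, stage=None):
        x = torch.randn(64, 8)
        y = torch.randn(64, 1)
        self.ds = TensorDataset(x, y)

    def train_dataloader(self):
        return DataLoader(self.ds, batch_size=16, shuffle=True)


# Step 2: wrap the model in a LightningModule
class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# Step 3: configure plugins/callbacks (checkpointing, monitoring) and train
if __name__ == "__main__":
    checkpoint_callback = pl.callbacks.ModelCheckpoint()
    trainer = pl.Trainer(max_epochs=1, callbacks=[checkpoint_callback])
    trainer.fit(MyModel(), datamodule=MyDataModule())
```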
fengshen/__init__.py
ADDED
@@ -0,0 +1,19 @@
```python
# coding=utf-8
# Copyright 2021 The IDEA Authors. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .models.longformer import LongformerConfig, LongformerModel
from .models.roformer import RoFormerConfig, RoFormerModel
from .models.megatron_t5 import T5Config, T5EncoderModel
from .models.ubert import UbertPipelines, UbertModel
```
fengshen/cli/fengshen_pipeline.py
ADDED
@@ -0,0 +1,34 @@
```python
import sys
from importlib import import_module
from datasets import load_dataset
import argparse


def main():
    if len(sys.argv) < 3:
        raise Exception(
            'args len < 3, example: fengshen_pipeline text_classification predict xxxxx')
    pipeline_name = sys.argv[1]
    method = sys.argv[2]
    pipeline_class = getattr(import_module('fengshen.pipelines.' + pipeline_name), 'Pipeline')

    total_parser = argparse.ArgumentParser("FengShen Pipeline")
    total_parser.add_argument('--model', default='', type=str)
    total_parser.add_argument('--datasets', default='', type=str)
    total_parser.add_argument('--text', default='', type=str)
    total_parser = pipeline_class.add_pipeline_specific_args(total_parser)
    args = total_parser.parse_args(args=sys.argv[3:])
    pipeline = pipeline_class(args=args, model=args.model)

    if method == 'predict':
        print(pipeline(args.text))
    elif method == 'train':
        datasets = load_dataset(args.datasets)
        pipeline.train(datasets)
    else:
        raise Exception(
            'cmd not support, now only support {predict, train}')


if __name__ == '__main__':
    main()
```
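A hedged sketch of driving the CLI above from Python, assuming the fengshen package is installed. The subcommand layout (pipeline name, then `predict` or `train`, then options) follows `main()` as written; the model name is reused from the API config earlier in this diff, and the sample sentence is illustrative. Extra pipeline-specific arguments added by `add_pipeline_specific_args` may also be required in practice.

```python
import sys
from fengshen.cli.fengshen_pipeline import main

# Equivalent to running on the command line:
#   fengshen_pipeline text_classification predict \
#       --model IDEA-CCNL/Erlangshen-Roberta-110M-Similarity --text "今天天气真不错"
sys.argv = [
    "fengshen_pipeline",          # argv[0]: program name
    "text_classification",        # argv[1]: pipeline module under fengshen.pipelines
    "predict",                    # argv[2]: method, either predict or train
    "--model", "IDEA-CCNL/Erlangshen-Roberta-110M-Similarity",
    "--text", "今天天气真不错",
]
main()  # prints the pipeline's prediction for --text
```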
fengshen/data/__init__.py
ADDED
@@ -0,0 +1 @@
```python
# coding=utf-8
```
fengshen/data/bert_dataloader/auto_split.sh
ADDED
@@ -0,0 +1,10 @@
```sh
files=`find $1 -type f -size +1024M`

for p in $files
do
echo "processing $p"
name=`basename $p .json`
file=`dirname $p`
split -a 2 -C 300M $p $file/$name- && ls|grep -E "(-[a-zA-Z]{2})" |xargs -n1 -i{} mv {} {}.json
rm -f $p
done
```
fengshen/data/bert_dataloader/load.py
ADDED
@@ -0,0 +1,200 @@
```python
import os
import re
from pathlib import Path
import glob
from tqdm import tqdm
from contextlib import ExitStack
import datasets
import multiprocessing
from typing import cast, TextIO
from itertools import chain
import json
from concurrent.futures import ProcessPoolExecutor
from random import shuffle
from pytorch_lightning import LightningDataModule
from typing import Optional

from torch.utils.data import DataLoader


# _SPLIT_DATA_PATH = '/data1/datas/wudao_180g_split/test'
_SPLIT_DATA_PATH = '/data1/datas/wudao_180g_split'
_CACHE_SPLIT_DATA_PATH = '/data1/datas/wudao_180g_FSData'

# feats = datasets.Features({"text": datasets.Value('string')})


class BertDataGenerate(object):

    def __init__(self,
                 data_files=_SPLIT_DATA_PATH,
                 save_path=_CACHE_SPLIT_DATA_PATH,
                 train_test_validation='950,49,1',
                 num_proc=1,
                 cache=True):
        self.data_files = Path(data_files)
        if save_path:
            self.save_path = Path(save_path)
        else:
            self.save_path = self.file_check(
                Path(self.data_files.parent, self.data_files.name + '_FSDataset'),
                'save')
        self.num_proc = num_proc
        self.cache = cache
        self.split_idx = self.split_train_test_validation_index(train_test_validation)
        if cache:
            self.cache_path = self.file_check(
                Path(self.save_path.parent, 'FSDataCache', self.data_files.name), 'cache')
        else:
            self.cache_path = None

    @staticmethod
    def file_check(path, path_type):
        print(path)
        if not path.exists():
            path.mkdir(parents=True)
        print(f"Since no {path_type} directory is specified, the program will automatically create it in {path} directory.")
        return str(path)

    @staticmethod
    def split_train_test_validation_index(train_test_validation):
        split_idx_ = [int(i) for i in train_test_validation.split(',')]
        idx_dict = {
            'train_rate': split_idx_[0] / sum(split_idx_),
            'test_rate': split_idx_[1] / sum(split_idx_[1:])
        }
        return idx_dict

    def process(self, index, path):
        print('saving dataset shard {}'.format(index))

        ds = (datasets.load_dataset('json', data_files=str(path),
                                    cache_dir=self.cache_path,
                                    features=None))
        # ds = ds.map(self.cut_sent, input_columns='text')
        # print(d)
        # print('!!!', ds)
        ds = ds['train'].train_test_split(train_size=self.split_idx['train_rate'])
        ds_ = ds['test'].train_test_split(train_size=self.split_idx['test_rate'])
        ds = datasets.DatasetDict({
            'train': ds['train'],
            'test': ds_['train'],
            'validation': ds_['test']
        })
        # print('!!!!', ds)
        ds.save_to_disk(Path(self.save_path, path.name))
        return 'saving dataset shard {} done'.format(index)

    def generate_cache_arrow(self) -> None:
        '''
        Generate the cache files supported by HF datasets to speed up later loading.
        '''
        data_dict_paths = self.data_files.rglob('*')
        p = ProcessPoolExecutor(max_workers=self.num_proc)
        res = list()

        for index, path in enumerate(data_dict_paths):
            res.append(p.submit(self.process, index, path))

        p.shutdown(wait=True)
        for future in res:
            print(future.result(), flush=True)


def load_dataset(num_proc=4, **kargs):
    cache_dict_paths = Path(_CACHE_SPLIT_DATA_PATH).glob('*')
    ds = []
    res = []
    p = ProcessPoolExecutor(max_workers=num_proc)
    for path in cache_dict_paths:
        res.append(p.submit(datasets.load_from_disk,
                            str(path), **kargs))

    p.shutdown(wait=True)
    for future in res:
        ds.append(future.result())
        # print(future.result())
    train = []
    test = []
    validation = []
    for ds_ in ds:
        train.append(ds_['train'])
        test.append(ds_['test'])
        validation.append(ds_['validation'])
    # ds = datasets.concatenate_datasets(ds)
    # print(ds)
    return datasets.DatasetDict({
        'train': datasets.concatenate_datasets(train),
        'test': datasets.concatenate_datasets(test),
        'validation': datasets.concatenate_datasets(validation)
    })


class BertDataModule(LightningDataModule):
    @staticmethod
    def add_data_specific_args(parent_args):
        parser = parent_args.add_argument_group('Universal DataModule')
        parser.add_argument('--num_workers', default=8, type=int)
        parser.add_argument('--train_batchsize', default=32, type=int)
        parser.add_argument('--val_batchsize', default=32, type=int)
        parser.add_argument('--test_batchsize', default=32, type=int)
        parser.add_argument('--datasets_name', type=str)
        # parser.add_argument('--datasets_name', type=str)
        parser.add_argument('--train_datasets_field', type=str, default='train')
        parser.add_argument('--val_datasets_field', type=str, default='validation')
        parser.add_argument('--test_datasets_field', type=str, default='test')
        return parent_args

    def __init__(
        self,
        tokenizer,
        collate_fn,
        args,
        **kwargs,
    ):
        super().__init__()
        self.datasets = load_dataset(num_proc=args.num_workers)
        self.tokenizer = tokenizer
        self.collate_fn = collate_fn
        self.save_hyperparameters(args)

    def setup(self, stage: Optional[str] = None) -> None:
        self.train = DataLoader(
            self.datasets[self.hparams.train_datasets_field],
            batch_size=self.hparams.train_batchsize,
            shuffle=True,
            num_workers=self.hparams.num_workers,
            collate_fn=self.collate_fn,
        )
        self.val = DataLoader(
            self.datasets[self.hparams.val_datasets_field],
            batch_size=self.hparams.val_batchsize,
            shuffle=False,
            num_workers=self.hparams.num_workers,
            collate_fn=self.collate_fn,
        )
        self.test = DataLoader(
            self.datasets[self.hparams.test_datasets_field],
            batch_size=self.hparams.test_batchsize,
            shuffle=False,
            num_workers=self.hparams.num_workers,
            collate_fn=self.collate_fn,
        )
        return

    def train_dataloader(self):
        return self.train

    def val_dataloader(self):
        return self.val

    def test_dataloader(self):
        return self.test


if __name__ == '__main__':
    # pre = PreProcessing(_SPLIT_DATA_PATH)
    # pre.processing()

    dataset = BertDataGenerate(_SPLIT_DATA_PATH, num_proc=16)
    dataset.generate_cache_arrow()
```
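A hedged sketch of wiring `BertDataModule` into a training script, assuming the Arrow cache under `_CACHE_SPLIT_DATA_PATH` has already been produced by `BertDataGenerate` above. The tokenizer choice and the trivial collate function are illustrative stand-ins for whatever the actual pre-training script uses.

```python
import argparse
from transformers import BertTokenizer
from fengshen.data.bert_dataloader.load import BertDataModule

parser = argparse.ArgumentParser()
parser = BertDataModule.add_data_specific_args(parser)
args = parser.parse_args([])  # use the defaults declared in add_data_specific_args

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")  # illustrative choice

def collate_fn(batch):
    # Stand-in collator: tokenize the raw 'text' field of each sample into a padded batch.
    return tokenizer([s["text"] for s in batch], padding=True, truncation=True,
                     max_length=512, return_tensors="pt")

dm = BertDataModule(tokenizer=tokenizer, collate_fn=collate_fn, args=args)
dm.setup()
print(next(iter(dm.train_dataloader()))["input_ids"].shape)
```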
fengshen/data/bert_dataloader/preprocessing.py
ADDED
@@ -0,0 +1,110 @@
```python
import re
import json
import multiprocessing
from tqdm import tqdm
from pathlib import Path
from itertools import chain

_SPLIT_DATA_PATH = '/data1/datas/wudao_180g'


def cut_sent(path):
    """
    Chinese sentence splitting. By default the text is split on ?, 。, ! and ellipses,
    taking sentences wrapped in double quotes into account.
    Implemented by inserting markers and splitting on them.
    """
    path = Path(path)
    # print(path)
    save_path = str(Path('/data1/datas/wudao_180g_split', path.name))
    print('处理文件:', save_path)
    with open(save_path, 'wt', encoding='utf-8') as w:
        with open(path, 'rt', encoding='utf-8') as f:
            for para in tqdm(f):
                para = json.loads(para)
                para_ = para['text'] + ' '
                # print('sentence piece......')
                # under pep8 the regex cannot be written as \?; it has to be \\?
                para_ = re.sub('([?。!\\?\\!…]+)([^”’]|[”’])',
                               r'\1#####\2', para_)
                para_ = re.sub('([\\.]{3,})([^”’])', r'\1#####\2', para_)

                # match \1: a sentence terminator right before ’”, \2: a non-terminator, i.e. a quoted sentence
                para_ = re.sub(
                    '([。!?\\?\\!…][”’])([^,。!?\\?\\!]|\\s)', r'\1#####\2', para_)
                para_ = re.sub(
                    '([\\.]{3,}[”’])([^,。!?\\?\\!]|\\s)', r'\1#####\2', para_)
                para_ = re.sub(
                    '([#]{5})([”’])([^,。!?\\?\\!])', r'\2#####\3', para_)
                para_ = para_.strip()
                # pack several samples into one 512-character chunk
                line_ = ''
                for line in para_.split('#####'):
                    line = line.strip()
                    if len(line_) < 512 and len(line) > 0:
                        line_ += line
                    else:
                        w.writelines(json.dumps(
                            {'text': line_}, ensure_ascii=False) + '\n')
                        line_ = line
                w.writelines(json.dumps(
                    {'text': line_}, ensure_ascii=False) + '\n')


def chain_iter(*filenames):
    """
    Read multiple files as a single iterator.
    """
    reader = [open(file, 'r') for file in filenames]
    return chain(*reader)


class Config(object):

    def __init__(self, data_path=_SPLIT_DATA_PATH, num_worker=16, split_numb=600000, cut_sentence=True, output_file=None) -> None:
        self.data_path = Path(data_path)
        self.num_worker = num_worker
        self.split_numb = split_numb
        self.cut_sentence = cut_sentence


def processing1():
    args = Config()
    p_ = [str(i) for i in args.data_path.glob('*')]
    fin = chain_iter(*p_)
    pool = multiprocessing.Pool(args.num_worker)
    docs = pool.imap(cut_sent, fin, chunksize=args.num_worker)

    if not Path(args.data_path.parent, args.data_path.name + '_split').exists():
        Path(args.data_path.parent, args.data_path.name + '_split').mkdir()
    writer = open(str(Path(args.data_path.parent, args.data_path.name +
                           '_split', 'sentence_level.json')), 'wt', encoding='utf-8')
    for doc in tqdm(docs):
        for sentence in doc:
            writer.writelines(json.dumps(
                {"text": sentence}, ensure_ascii=False) + '\n')
    pool.close()
    pool.join()
    writer.close()


if __name__ == '__main__':
    from time import process_time, perf_counter
    from random import shuffle
    st = process_time()
    args = Config(num_worker=16)

    if not Path(args.data_path.parent, args.data_path.name + '_split').exists():
        Path(args.data_path.parent, args.data_path.name +
             '_split').mkdir(parents=True)

    p_ = [str(i) for i in args.data_path.glob('*')]
    # simple shuffle
    shuffle(p_)

    pool = multiprocessing.Pool(args.num_worker)
    for item in p_:
        pool.apply_async(func=cut_sent, args=(item,))
    pool.close()
    pool.join()
    cost_time = process_time() - st
    print('DONE!! cost time : %.5f' % cost_time)
```
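A small standalone demo of the marker-based splitting that `cut_sent` performs: a `#####` marker is inserted after each run of sentence terminators, the paragraph is then split on the marker, and the pieces are packed into chunks. The sample paragraph and the use of only the first substitution are illustrative; `cut_sent` itself applies several further substitutions for quoted sentences and ellipses.

```python
import re

para = '今天天气很好!我们出去玩吧。你明天有空吗?'

# Insert a ##### marker after each run of sentence terminators (the following character is kept via the second group).
marked = re.sub('([?。!\\?\\!…]+)([^”’]|[”’])', r'\1#####\2', para + ' ')
sentences = [s.strip() for s in marked.split('#####') if s.strip()]
print(sentences)
# ['今天天气很好!', '我们出去玩吧。', '你明天有空吗?']
```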
fengshen/data/clip_dataloader/flickr.py
ADDED
@@ -0,0 +1,105 @@
```python
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
    CenterCrop
from transformers import BertTokenizer
import pytorch_lightning as pl
from PIL import Image
import os


class flickr30k_CNA(Dataset):
    def __init__(self, img_root_path,
                 annot_path,
                 transform=None):
        self.images = []
        self.captions = []
        self.labels = []
        self.root = img_root_path
        with open(annot_path, 'r') as f:
            for line in f:
                line = line.strip().split('\t')
                key, caption = line[0].split('#')[0], line[1]
                img_path = key + '.jpg'
                self.images.append(img_path)
                self.captions.append(caption)
                self.labels.append(key)
        self.transforms = transform
        self.tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")

        # NOTE: context length of the large model
        self.context_length = 77

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = str(self.images[idx])
        image = self.transforms(Image.open(os.path.join(self.root, img_path)))
        text = self.tokenizer(str(self.captions[idx]), max_length=self.context_length,
                              padding='max_length', truncation=True, return_tensors='pt')['input_ids'][0]
        label = self.labels[idx]
        return image, text, label


def _convert_to_rgb(image):
    return image.convert('RGB')


def image_transform(
        image_size: int,
        is_train: bool,
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711)
):
    normalize = Normalize(mean=mean, std=std)
    if is_train:
        return Compose([
            RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
            _convert_to_rgb,
            ToTensor(),
            normalize,
        ])
    else:
        return Compose([
            Resize(image_size, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            _convert_to_rgb,
            ToTensor(),
            normalize,
        ])


class FlickrDataModule(pl.LightningDataModule):
    def __init__(self, args):
        self.batch_size = args.batch_size
        self.train_filename = args.train_filename  # NOTE: annotation path
        self.train_root = args.train_root  # NOTE: image directory
        self.val_filename = args.val_filename
        self.val_root = args.val_root
        self.test_filename = args.test_filename
        self.test_root = args.test_root

        self.pretrain_model = args.pretrain_model
        self.image_size = 224
        self.prepare_data_per_node = True
        self._log_hyperparams = False
        self.num_workers = args.num_workers

    def setup(self, stage=None):
        # dataset
        train_transform = image_transform(224, True)
        val_transform = image_transform(224, False)
        test_transform = image_transform(224, False)

        self.train_dataset = flickr30k_CNA(self.train_root, self.train_filename, transform=train_transform)
        self.val_dataset = flickr30k_CNA(self.val_root, self.val_filename, transform=val_transform)
        self.test_dataset = flickr30k_CNA(self.test_root, self.test_filename, transform=test_transform)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
```
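A hedged sketch of constructing the datamodule above. The annotation format implied by `flickr30k_CNA` is one `image_key#index<TAB>caption` pair per line; the paths and the plain Namespace are placeholders for the arguments a real fine-tuning script would parse.

```python
from argparse import Namespace
from fengshen.data.clip_dataloader.flickr import FlickrDataModule

# Placeholder paths: each *_filename is a TSV of "imgkey#0<TAB>caption" lines,
# and each *_root is the directory holding the corresponding .jpg files.
args = Namespace(
    batch_size=32,
    num_workers=4,
    pretrain_model="hfl/chinese-roberta-wwm-ext",
    train_filename="flickr30k_cna/train.txt", train_root="flickr30k_cna/images",
    val_filename="flickr30k_cna/val.txt", val_root="flickr30k_cna/images",
    test_filename="flickr30k_cna/test.txt", test_root="flickr30k_cna/images",
)

dm = FlickrDataModule(args)
dm.setup()
images, texts, labels = next(iter(dm.train_dataloader()))
print(images.shape, texts.shape)  # e.g. [32, 3, 224, 224] and [32, 77]
```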
fengshen/data/data_utils/common_utils.py
ADDED
@@ -0,0 +1,4 @@
```python
def padding_to_maxlength(ids, max_length, pad_id):
    cur_len = len(ids)
    len_diff = max_length - len(ids)
    return ids + [pad_id] * len_diff, [1] * cur_len + [0] * len_diff
```
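A quick illustration of what `padding_to_maxlength` returns: the padded id sequence plus the matching attention mask. The token ids below are arbitrary example values.

```python
from fengshen.data.data_utils.common_utils import padding_to_maxlength

ids, mask = padding_to_maxlength([101, 2769, 102], max_length=6, pad_id=0)
print(ids)   # [101, 2769, 102, 0, 0, 0]
print(mask)  # [1, 1, 1, 0, 0, 0]
```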
fengshen/data/data_utils/mask_utils.py
ADDED
@@ -0,0 +1,285 @@
```python
import collections

import numpy as np

MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                          ["index", "label"])


def is_start_piece(piece):
    """Check if the current word piece is the starting piece (BERT)."""
    # When a word has been split into
    # WordPieces, the first token does not have any marker and any subsequence
    # tokens are prefixed with ##. So whenever we see the ## token, we
    # append it to the previous set of word indexes.
    return not piece.startswith("##")


def create_masked_lm_predictions(tokens,
                                 vocab_id_list, vocab_id_to_token_dict,
                                 masked_lm_prob,
                                 cls_id, sep_id, mask_id,
                                 max_predictions_per_seq,
                                 np_rng,
                                 max_ngrams=3,
                                 do_whole_word_mask=True,
                                 favor_longer_ngram=False,
                                 do_permutation=False,
                                 geometric_dist=False,
                                 masking_style="bert",
                                 zh_tokenizer=None):
    """Creates the predictions for the masked LM objective.
    Note: Tokens here are vocab ids and not text tokens."""
    '''
    modified from Megatron-LM
    Args:
        tokens: input token ids
        vocab_id_list: list of vocabulary token ids
        vocab_id_to_token_dict: mapping from token id to token
        masked_lm_prob: masking probability
        cls_id, sep_id, mask_id: special tokens
        max_predictions_per_seq: maximum number of masked positions
        np_rng: random generator used for masking
        max_ngrams: maximum word (ngram) length
        do_whole_word_mask: whether to do whole-word masking
        favor_longer_ngram: prefer longer ngrams
        do_permutation: whether to permute selected tokens
        geometric_dist: sample span lengths with np_rng.geometric
        masking_style: type of masking
        zh_tokenizer: tokenizer for whole-word masking, e.g. jieba.lcut
    '''
    cand_indexes = []
    # Note(mingdachen): We create a list for recording if the piece is
    # the starting piece of current token, where 1 means true, so that
    # on-the-fly whole word masking is possible.
    token_boundary = [0] * len(tokens)
    # If no Chinese tokenizer is given, fall back to the plain ## convention.
    if zh_tokenizer is None:
        for (i, token) in enumerate(tokens):
            if token == cls_id or token == sep_id:
                token_boundary[i] = 1
                continue
            # Whole Word Masking means that if we mask all of the wordpieces
            # corresponding to an original word.
            #
            # Note that Whole Word Masking does *not* change the training code
            # at all -- we still predict each WordPiece independently, softmaxed
            # over the entire vocabulary.
            if (do_whole_word_mask and len(cand_indexes) >= 1 and
                    not is_start_piece(vocab_id_to_token_dict[token])):
                cand_indexes[-1].append(i)
            else:
                cand_indexes.append([i])
                if is_start_piece(vocab_id_to_token_dict[token]):
                    token_boundary[i] = 1
    else:
        # If a Chinese tokenizer is given, segment the text first and decide word boundaries from that.
        # Recover the raw text without CLS and SEP.
        raw_tokens = []
        for t in tokens:
            if t != cls_id and t != sep_id:
                raw_tokens.append(t)
        raw_tokens = [vocab_id_to_token_dict[i] for i in raw_tokens]
        # Segment, then record the length of the longest word starting with each character.
        word_list = set(zh_tokenizer(''.join(raw_tokens), HMM=True))
        word_length_dict = {}
        for w in word_list:
            if len(w) < 1:
                continue
            if w[0] not in word_length_dict:
                word_length_dict[w[0]] = len(w)
            elif word_length_dict[w[0]] < len(w):
                word_length_dict[w[0]] = len(w)
        i = 0
        # look up words from the segmentation
        while i < len(tokens):
            token_id = tokens[i]
            token = vocab_id_to_token_dict[token_id]
            if len(token) == 0 or token_id == cls_id or token_id == sep_id:
                token_boundary[i] = 1
                i += 1
                continue
            word_max_length = 1
            if token[0] in word_length_dict:
                word_max_length = word_length_dict[token[0]]
            j = 0
            word = ''
            word_end = i + 1
            # Stay compatible with the old ## style: if the following tokens start with ##,
            # append them directly to the current token as one word.
            old_style = False
            while word_end < len(tokens) and vocab_id_to_token_dict[tokens[word_end]].startswith('##'):
                old_style = True
                word_end += 1
            if not old_style:
                while j < word_max_length and i + j < len(tokens):
                    cur_token = tokens[i + j]
                    word += vocab_id_to_token_dict[cur_token]
                    j += 1
                    if word in word_list:
                        word_end = i + j
            cand_indexes.append([p for p in range(i, word_end)])
            token_boundary[i] = 1
            i = word_end

    output_tokens = list(tokens)

    masked_lm_positions = []
    masked_lm_labels = []

    if masked_lm_prob == 0:
        return (output_tokens, masked_lm_positions,
                masked_lm_labels, token_boundary)

    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(tokens) * masked_lm_prob))))

    ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
    if not geometric_dist:
        # Note(mingdachen):
        # By default, we set the probilities to favor shorter ngram sequences.
        pvals = 1. / np.arange(1, max_ngrams + 1)
        pvals /= pvals.sum(keepdims=True)
        if favor_longer_ngram:
            pvals = pvals[::-1]
    # Build the ngram index: for each word, record the ngrams it starts.
    ngram_indexes = []
    for idx in range(len(cand_indexes)):
        ngram_index = []
        for n in ngrams:
            ngram_index.append(cand_indexes[idx:idx + n])
        ngram_indexes.append(ngram_index)

    np_rng.shuffle(ngram_indexes)

    (masked_lms, masked_spans) = ([], [])
    covered_indexes = set()
    for cand_index_set in ngram_indexes:
        if len(masked_lms) >= num_to_predict:
            break
        if not cand_index_set:
            continue
        # Note(mingdachen):
        # Skip current piece if they are covered in lm masking or previous ngrams.
        for index_set in cand_index_set[0]:
            for index in index_set:
                if index in covered_indexes:
                    continue

        if not geometric_dist:
            n = np_rng.choice(ngrams[:len(cand_index_set)],
                              p=pvals[:len(cand_index_set)] /
                              pvals[:len(cand_index_set)].sum(keepdims=True))
        else:
            # Sampling "n" from the geometric distribution and clipping it to
            # the max_ngrams. Using p=0.2 default from the SpanBERT paper
            # https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1)
            n = min(np_rng.geometric(0.2), max_ngrams)

        index_set = sum(cand_index_set[n - 1], [])
        n -= 1
        # Note(mingdachen):
        # Repeatedly looking for a candidate that does not exceed the
        # maximum number of predictions by trying shorter ngrams.
        while len(masked_lms) + len(index_set) > num_to_predict:
            if n == 0:
                break
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1
        # If adding a whole-word mask would exceed the maximum number of
        # predictions, then just skip this candidate.
        if len(masked_lms) + len(index_set) > num_to_predict:
            continue
        is_any_index_covered = False
        for index in index_set:
            if index in covered_indexes:
                is_any_index_covered = True
                break
        if is_any_index_covered:
            continue
        for index in index_set:
            covered_indexes.add(index)
            masked_token = None
            token_id = tokens[index]
            if masking_style == "bert":
                # 80% of the time, replace with [MASK]
                if np_rng.random() < 0.8:
                    masked_token = mask_id
                else:
                    # 10% of the time, keep original
                    if np_rng.random() < 0.5:
                        masked_token = tokens[index]
                    # 10% of the time, replace with random word
                    else:
                        masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]
            elif masking_style == "t5":
                masked_token = mask_id
            else:
                raise ValueError("invalid value of masking style")

            output_tokens[index] = masked_token
            masked_lms.append(MaskedLmInstance(index=index, label=token_id))

        masked_spans.append(MaskedLmInstance(
            index=index_set,
            label=[tokens[index] for index in index_set]))

    assert len(masked_lms) <= num_to_predict
    np_rng.shuffle(ngram_indexes)

    select_indexes = set()
    if do_permutation:
        for cand_index_set in ngram_indexes:
            if len(select_indexes) >= num_to_predict:
                break
            if not cand_index_set:
                continue
            # Note(mingdachen):
            # Skip current piece if they are covered in lm masking or previous ngrams.
            for index_set in cand_index_set[0]:
                for index in index_set:
                    if index in covered_indexes or index in select_indexes:
                        continue

            n = np.random.choice(ngrams[:len(cand_index_set)],
                                 p=pvals[:len(cand_index_set)] /
                                 pvals[:len(cand_index_set)].sum(keepdims=True))
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1

            while len(select_indexes) + len(index_set) > num_to_predict:
                if n == 0:
                    break
                index_set = sum(cand_index_set[n - 1], [])
                n -= 1
            # If adding a whole-word mask would exceed the maximum number of
            # predictions, then just skip this candidate.
            if len(select_indexes) + len(index_set) > num_to_predict:
                continue
            is_any_index_covered = False
            for index in index_set:
                if index in covered_indexes or index in select_indexes:
                    is_any_index_covered = True
                    break
            if is_any_index_covered:
                continue
            for index in index_set:
                select_indexes.add(index)
        assert len(select_indexes) <= num_to_predict

        select_indexes = sorted(select_indexes)
        permute_indexes = list(select_indexes)
        np_rng.shuffle(permute_indexes)
        orig_token = list(output_tokens)

        for src_i, tgt_i in zip(select_indexes, permute_indexes):
            output_tokens[src_i] = orig_token[tgt_i]
            masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i]))

    masked_lms = sorted(masked_lms, key=lambda x: x.index)
    # Sort the spans by the index of the first span
    masked_spans = sorted(masked_spans, key=lambda x: x.index[0])

    for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)
    return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary, masked_spans)
```
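A toy, hedged run of `create_masked_lm_predictions` with a seven-entry vocabulary, just to show the shape of the returned values; the ids, wordpieces, and hyperparameters are made up for illustration.

```python
import numpy as np
from fengshen.data.data_utils.mask_utils import create_masked_lm_predictions

# Toy vocabulary: ids 0-2 are special tokens, the rest are ordinary wordpieces.
vocab_id_to_token_dict = {0: "[CLS]", 1: "[SEP]", 2: "[MASK]",
                          3: "we", 4: "love", 5: "fengshen", 6: "##bang"}
vocab_id_list = list(vocab_id_to_token_dict.keys())

tokens = [0, 3, 4, 5, 6, 1]  # [CLS] we love fengshen ##bang [SEP]
np_rng = np.random.RandomState(seed=42)

(output_tokens, masked_positions, masked_labels,
 token_boundary, masked_spans) = create_masked_lm_predictions(
    tokens, vocab_id_list, vocab_id_to_token_dict,
    masked_lm_prob=0.15, cls_id=0, sep_id=1, mask_id=2,
    max_predictions_per_seq=2, np_rng=np_rng)

print(output_tokens)     # the tokens with some positions replaced (mostly by mask_id 2)
print(masked_positions)  # indices that were masked
print(masked_labels)     # the original ids at those positions
```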
fengshen/data/data_utils/sentence_split.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import re


class ChineseSentenceSplitter(object):
    def merge_symmetry(self, sentences, symmetry=('“', '”')):
        # Merge symmetric symbols such as paired double quotes.
        effective_ = []
        merged = True
        for index in range(len(sentences)):
            if symmetry[0] in sentences[index] and symmetry[1] not in sentences[index]:
                merged = False
                effective_.append(sentences[index])
            elif symmetry[1] in sentences[index] and not merged:
                merged = True
                effective_[-1] += sentences[index]
            elif symmetry[0] not in sentences[index] and symmetry[1] not in sentences[index] and not merged:
                effective_[-1] += sentences[index]
            else:
                effective_.append(sentences[index])
        return [i.strip() for i in effective_ if len(i.strip()) > 0]

    def to_sentences(self, paragraph):
        # Split a paragraph into sentences.
        sentences = re.split(r"(?|。|[!]+|!|\…\…)", paragraph)
        sentences.append("")
        sentences = ["".join(i) for i in zip(sentences[0::2], sentences[1::2])]
        sentences = [i.strip() for i in sentences if len(i.strip()) > 0]
        for j in range(1, len(sentences)):
            if sentences[j][0] == '”':
                sentences[j-1] = sentences[j-1] + '”'
                sentences[j] = sentences[j][1:]
        return self.merge_symmetry(sentences)

    def tokenize(self, text):
        return self.to_sentences(text)
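A minimal usage sketch for the splitter above, assuming the module path shown in this diff; the exact split points depend on the punctuation regex in to_sentences.

from fengshen.data.data_utils.sentence_split import ChineseSentenceSplitter

splitter = ChineseSentenceSplitter()
paragraph = "今天天气很好。他说:“我们出去走走吧!”明天再去图书馆。"
for sentence in splitter.tokenize(paragraph):
    print(sentence)
# Intended behaviour: split on Chinese sentence-final punctuation and keep a
# quoted exclamation attached to the sentence that opened the quote.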
fengshen/data/data_utils/sop_utils.py
ADDED
@@ -0,0 +1,32 @@
# copy from megatron
def get_a_and_b_segments(sample, np_rng):
    """Divide sample into a and b segments."""

    # Number of sentences in the sample.
    n_sentences = len(sample)
    # Make sure we always have two sentences.
    assert n_sentences > 1, 'make sure each sample has at least two sentences.'

    # First part:
    # `a_end` is how many sentences go into the `A`.
    a_end = 1
    if n_sentences >= 3:
        # Note that randint in numpy is exclusive at the upper bound.
        a_end = np_rng.randint(1, n_sentences)
    tokens_a = []
    for j in range(a_end):
        tokens_a.extend(sample[j])

    # Second part:
    tokens_b = []
    for j in range(a_end, n_sentences):
        tokens_b.extend(sample[j])

    # Random next:
    is_next_random = False
    if np_rng.random() < 0.5:
        is_next_random = True
        tokens_a, tokens_b = tokens_b, tokens_a

    return tokens_a, tokens_b, is_next_random
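Illustrative call (not from this repo) showing the contract of get_a_and_b_segments on toy token-id sentences, assuming the import path below.

import numpy as np
from fengshen.data.data_utils.sop_utils import get_a_and_b_segments

# Three toy "sentences", each a list of token ids.
sample = [[101, 102], [103, 104, 105], [106]]
np_rng = np.random.RandomState(seed=1234)
tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng)
# tokens_a/tokens_b partition the sentences at a random boundary; with
# probability 0.5 they are swapped and is_next_random is True (the SOP label).
print(tokens_a, tokens_b, is_next_random)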
fengshen/data/data_utils/token_type_utils.py
ADDED
@@ -0,0 +1,25 @@
def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
    """Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""

    tokens = []
    tokentypes = []
    # [CLS].
    tokens.append(cls_id)
    tokentypes.append(0)
    # Segment A.
    for token in tokens_a:
        tokens.append(token)
        tokentypes.append(0)
    # [SEP].
    tokens.append(sep_id)
    tokentypes.append(0)
    # Segment B.
    for token in tokens_b:
        tokens.append(token)
        tokentypes.append(1)
    if tokens_b:
        # [SEP].
        tokens.append(sep_id)
        tokentypes.append(1)

    return tokens, tokentypes
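A toy call (illustrative only) with made-up special-token ids, assuming the import path below.

from fengshen.data.data_utils.token_type_utils import create_tokens_and_tokentypes

tokens_a = [7, 8, 9]
tokens_b = [15, 16]
tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id=101, sep_id=102)
print(tokens)      # [101, 7, 8, 9, 102, 15, 16, 102]
print(tokentypes)  # [0, 0, 0, 0, 0, 1, 1, 1]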
fengshen/data/data_utils/truncate_utils.py
ADDED
@@ -0,0 +1,19 @@
def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
    """Truncates a pair of sequences to a maximum sequence length."""
    # print(len_a, len_b, max_num_tokens)
    assert len_a > 0
    if len_a + len_b <= max_num_tokens:
        return False
    while len_a + len_b > max_num_tokens:
        if len_a > len_b:
            len_a -= 1
            tokens = tokens_a
        else:
            len_b -= 1
            tokens = tokens_b
        if np_rng.random() < 0.5:
            del tokens[0]
        else:
            tokens.pop()
    return True
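A short illustrative call, assuming the import path below; note that the function mutates tokens_a/tokens_b in place and only returns whether truncation happened.

import numpy as np
from fengshen.data.data_utils.truncate_utils import truncate_segments

tokens_a = list(range(10))
tokens_b = list(range(100, 106))
np_rng = np.random.RandomState(seed=0)
truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a), len(tokens_b),
                              max_num_tokens=12, np_rng=np_rng)
print(truncated)                      # True, because 10 + 6 > 12
print(len(tokens_a) + len(tokens_b))  # 12 after in-place truncation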
fengshen/data/dreambooth_datasets/dreambooth_datasets.py
ADDED
@@ -0,0 +1,183 @@
# -*- encoding: utf-8 -*-
'''
Copyright 2022 The International Digital Economy Academy (IDEA). CCNL team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@File    :   dreambooth_datasets.py
@Time    :   2022/11/10 00:20
@Author  :   Gan Ruyi
@Version :   1.0
@Contact :   ganruyi@idea.edu.cn
@License :   (C)Copyright 2022-2023, CCNL-IDEA
'''
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
from pathlib import Path


def add_data_args(parent_args):
    parser = parent_args.add_argument_group('taiyi stable diffusion data args')
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default=None,
        required=True,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default=None,
        required=False,
        help="A folder containing the training data of class images.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default=None,
        help="The prompt with identifier specifying the instance",
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default=None,
        help="The prompt to specify images in the same class as provided instance images.",
    )
    parser.add_argument(
        "--with_prior_preservation",
        default=False,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=100,
        help=(
            "Minimal class images for prior preservation loss. If there are not enough images, additional images"
            " will be sampled with class_prompt."
        ),
    )
    parser.add_argument(
        "--resolution", type=int, default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", default=False,
        help="Whether to center crop images before resizing to resolution"
    )
    parser.add_argument(
        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
    )
    return parent_args


class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.
    """

    def __init__(
        self,
        instance_data_dir,
        instance_prompt,
        tokenizer,
        class_data_dir=None,
        class_prompt=None,
        size=512,
        center_crop=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_dir = Path(instance_data_dir)
        if not self.instance_data_dir.exists():
            raise ValueError("Instance images root doesn't exist.")

        self.instance_images_path = list(Path(instance_data_dir).iterdir())
        print(self.instance_images_path)
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_dir is not None:
            self.class_data_dir = Path(class_data_dir)
            self.class_data_dir.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_dir.iterdir())
            self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_dir = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)
        example["instance_prompt_ids"] = self.tokenizer(
            self.instance_prompt,
            padding="do_not_pad",
            truncation=True,
            max_length=64,
            # max_length=self.tokenizer.model_max_length,
        ).input_ids

        if self.class_data_dir:
            class_image = Image.open(self.class_images_path[index % self.num_class_images])
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self.tokenizer(
                self.class_prompt,
                padding="do_not_pad",
                truncation=True,
                # max_length=self.tokenizer.model_max_length,
                max_length=64,
            ).input_ids

        return example


class PromptDataset(Dataset):
    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        example = {}
        example["prompt"] = self.prompt
        example["index"] = index
        return example
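A minimal wiring sketch (not taken from this repo's training scripts): any Hugging Face tokenizer that returns input_ids is enough for a quick check, so a generic Chinese BERT tokenizer is used as a placeholder; the directory and prompt below are also placeholders, and the real DreamBooth scripts would pass the Taiyi text encoder's tokenizer instead.

from transformers import BertTokenizer
from fengshen.data.dreambooth_datasets.dreambooth_datasets import DreamBoothDataset

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")  # placeholder tokenizer
dataset = DreamBoothDataset(
    instance_data_dir="./instance_images",  # a folder with a few photos of the subject
    instance_prompt="一张sks狗的照片",        # "a photo of sks dog"
    tokenizer=tokenizer,
    size=512,
    center_crop=False,
)
example = dataset[0]
print(example["instance_images"].shape)    # torch.Size([3, 512, 512])
print(example["instance_prompt_ids"][:5])  # first few token ids of the prompt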
fengshen/data/hubert/hubert_dataset.py
ADDED
@@ -0,0 +1,361 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import itertools
|
7 |
+
import logging
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
from typing import Any, List, Optional, Union
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn.functional as F
|
16 |
+
from fairseq.data import data_utils
|
17 |
+
from fairseq.data.fairseq_dataset import FairseqDataset
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
def add_data_specific_args(parent_args):
|
23 |
+
parser = parent_args.add_argument_group('Hubert Dataset')
|
24 |
+
parser.add_argument('--data', type=str)
|
25 |
+
parser.add_argument('--sample_rate', type=float, default=16000)
|
26 |
+
parser.add_argument('--label_dir', type=str)
|
27 |
+
parser.add_argument('--labels', type=str, nargs='+')
|
28 |
+
parser.add_argument('--label_rate', type=float)
|
29 |
+
parser.add_argument('--max_keep_size', type=int, default=None)
|
30 |
+
parser.add_argument('--min_sample_size', type=int)
|
31 |
+
parser.add_argument('--max_sample_size', type=int)
|
32 |
+
parser.add_argument('--pad_audio', type=bool)
|
33 |
+
parser.add_argument('--normalize', type=bool)
|
34 |
+
parser.add_argument('--random_crop', type=bool)
|
35 |
+
parser.add_argument('--single_target', type=bool, default=False)
|
36 |
+
return parent_args
|
37 |
+
|
38 |
+
|
39 |
+
def load_audio(manifest_path, max_keep, min_keep):
|
40 |
+
n_long, n_short = 0, 0
|
41 |
+
names, inds, sizes = [], [], []
|
42 |
+
with open(manifest_path) as f:
|
43 |
+
root = f.readline().strip()
|
44 |
+
for ind, line in enumerate(f):
|
45 |
+
items = line.strip().split("\t")
|
46 |
+
assert len(items) == 2, line
|
47 |
+
sz = int(items[1])
|
48 |
+
if min_keep is not None and sz < min_keep:
|
49 |
+
n_short += 1
|
50 |
+
elif max_keep is not None and sz > max_keep:
|
51 |
+
n_long += 1
|
52 |
+
else:
|
53 |
+
names.append(items[0])
|
54 |
+
inds.append(ind)
|
55 |
+
sizes.append(sz)
|
56 |
+
tot = ind + 1
|
57 |
+
logger.info(
|
58 |
+
(
|
59 |
+
f"max_keep={max_keep}, min_keep={min_keep}, "
|
60 |
+
f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
|
61 |
+
f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
|
62 |
+
)
|
63 |
+
)
|
64 |
+
return root, names, inds, tot, sizes
|
65 |
+
|
66 |
+
|
67 |
+
def load_label(label_path, inds, tot):
|
68 |
+
with open(label_path) as f:
|
69 |
+
labels = [line.rstrip() for line in f]
|
70 |
+
assert (
|
71 |
+
len(labels) == tot
|
72 |
+
), f"number of labels does not match ({len(labels)} != {tot})"
|
73 |
+
labels = [labels[i] for i in inds]
|
74 |
+
return labels
|
75 |
+
|
76 |
+
|
77 |
+
def load_label_offset(label_path, inds, tot):
|
78 |
+
with open(label_path) as f:
|
79 |
+
code_lengths = [len(line.encode("utf-8")) for line in f]
|
80 |
+
assert (
|
81 |
+
len(code_lengths) == tot
|
82 |
+
), f"number of labels does not match ({len(code_lengths)} != {tot})"
|
83 |
+
offsets = list(itertools.accumulate([0] + code_lengths))
|
84 |
+
offsets = [(offsets[i], offsets[i + 1]) for i in inds]
|
85 |
+
return offsets
|
86 |
+
|
87 |
+
|
88 |
+
def verify_label_lengths(
|
89 |
+
audio_sizes,
|
90 |
+
audio_rate,
|
91 |
+
label_path,
|
92 |
+
label_rate,
|
93 |
+
inds,
|
94 |
+
tot,
|
95 |
+
tol=0.1, # tolerance in seconds
|
96 |
+
):
|
97 |
+
if label_rate < 0:
|
98 |
+
logger.info(f"{label_path} is sequence label. skipped")
|
99 |
+
return
|
100 |
+
|
101 |
+
with open(label_path) as f:
|
102 |
+
lengths = [len(line.rstrip().split()) for line in f]
|
103 |
+
assert len(lengths) == tot
|
104 |
+
lengths = [lengths[i] for i in inds]
|
105 |
+
num_invalid = 0
|
106 |
+
for i, ind in enumerate(inds):
|
107 |
+
dur_from_audio = audio_sizes[i] / audio_rate
|
108 |
+
dur_from_label = lengths[i] / label_rate
|
109 |
+
if abs(dur_from_audio - dur_from_label) > tol:
|
110 |
+
logger.warning(
|
111 |
+
(
|
112 |
+
f"audio and label duration differ too much "
|
113 |
+
f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
|
114 |
+
f"in line {ind+1} of {label_path}. Check if `label_rate` "
|
115 |
+
f"is correctly set (currently {label_rate}). "
|
116 |
+
f"num. of samples = {audio_sizes[i]}; "
|
117 |
+
f"label length = {lengths[i]}"
|
118 |
+
)
|
119 |
+
)
|
120 |
+
num_invalid += 1
|
121 |
+
if num_invalid > 0:
|
122 |
+
logger.warning(
|
123 |
+
f"total {num_invalid} (audio, label) pairs with mismatched lengths"
|
124 |
+
)
|
125 |
+
|
126 |
+
|
127 |
+
class HubertDataset(FairseqDataset):
|
128 |
+
def __init__(
|
129 |
+
self,
|
130 |
+
manifest_path: str,
|
131 |
+
sample_rate: float,
|
132 |
+
label_paths: List[str],
|
133 |
+
label_rates: Union[List[float], float], # -1 for sequence labels
|
134 |
+
pad_list: List[str],
|
135 |
+
eos_list: List[str],
|
136 |
+
label_processors: Optional[List[Any]] = None,
|
137 |
+
max_keep_sample_size: Optional[int] = None,
|
138 |
+
min_keep_sample_size: Optional[int] = None,
|
139 |
+
max_sample_size: Optional[int] = None,
|
140 |
+
shuffle: bool = True,
|
141 |
+
pad_audio: bool = False,
|
142 |
+
normalize: bool = False,
|
143 |
+
store_labels: bool = True,
|
144 |
+
random_crop: bool = False,
|
145 |
+
single_target: bool = False,
|
146 |
+
):
|
147 |
+
self.audio_root, self.audio_names, inds, tot, self.sizes = load_audio(
|
148 |
+
manifest_path, max_keep_sample_size, min_keep_sample_size
|
149 |
+
)
|
150 |
+
self.sample_rate = sample_rate
|
151 |
+
self.shuffle = shuffle
|
152 |
+
self.random_crop = random_crop
|
153 |
+
|
154 |
+
self.num_labels = len(label_paths)
|
155 |
+
self.pad_list = pad_list
|
156 |
+
self.eos_list = eos_list
|
157 |
+
self.label_processors = label_processors
|
158 |
+
self.single_target = single_target
|
159 |
+
self.label_rates = (
|
160 |
+
[label_rates for _ in range(len(label_paths))]
|
161 |
+
if isinstance(label_rates, float)
|
162 |
+
else label_rates
|
163 |
+
)
|
164 |
+
self.store_labels = store_labels
|
165 |
+
if store_labels:
|
166 |
+
self.label_list = [load_label(p, inds, tot) for p in label_paths]
|
167 |
+
else:
|
168 |
+
self.label_paths = label_paths
|
169 |
+
self.label_offsets_list = [
|
170 |
+
load_label_offset(p, inds, tot) for p in label_paths
|
171 |
+
]
|
172 |
+
assert label_processors is None or len(label_processors) == self.num_labels
|
173 |
+
for label_path, label_rate in zip(label_paths, self.label_rates):
|
174 |
+
verify_label_lengths(
|
175 |
+
self.sizes, sample_rate, label_path, label_rate, inds, tot
|
176 |
+
)
|
177 |
+
|
178 |
+
self.max_sample_size = (
|
179 |
+
max_sample_size if max_sample_size is not None else sys.maxsize
|
180 |
+
)
|
181 |
+
self.pad_audio = pad_audio
|
182 |
+
self.normalize = normalize
|
183 |
+
logger.info(
|
184 |
+
f"pad_audio={pad_audio}, random_crop={random_crop}, "
|
185 |
+
f"normalize={normalize}, max_sample_size={self.max_sample_size}"
|
186 |
+
)
|
187 |
+
|
188 |
+
def get_audio(self, index):
|
189 |
+
import soundfile as sf
|
190 |
+
|
191 |
+
wav_path = os.path.join(self.audio_root, self.audio_names[index])
|
192 |
+
wav, cur_sample_rate = sf.read(wav_path)
|
193 |
+
wav = torch.from_numpy(wav).float()
|
194 |
+
wav = self.postprocess(wav, cur_sample_rate)
|
195 |
+
return wav
|
196 |
+
|
197 |
+
def get_label(self, index, label_idx):
|
198 |
+
if self.store_labels:
|
199 |
+
label = self.label_list[label_idx][index]
|
200 |
+
else:
|
201 |
+
with open(self.label_paths[label_idx]) as f:
|
202 |
+
offset_s, offset_e = self.label_offsets_list[label_idx][index]
|
203 |
+
f.seek(offset_s)
|
204 |
+
label = f.read(offset_e - offset_s)
|
205 |
+
|
206 |
+
if self.label_processors is not None:
|
207 |
+
label = self.label_processors[label_idx](label)
|
208 |
+
return label
|
209 |
+
|
210 |
+
def get_labels(self, index):
|
211 |
+
return [self.get_label(index, i) for i in range(self.num_labels)]
|
212 |
+
|
213 |
+
def __getitem__(self, index):
|
214 |
+
wav = self.get_audio(index)
|
215 |
+
labels = self.get_labels(index)
|
216 |
+
return {"id": index, "source": wav, "label_list": labels}
|
217 |
+
|
218 |
+
def __len__(self):
|
219 |
+
return len(self.sizes)
|
220 |
+
|
221 |
+
def crop_to_max_size(self, wav, target_size):
|
222 |
+
size = len(wav)
|
223 |
+
diff = size - target_size
|
224 |
+
if diff <= 0:
|
225 |
+
return wav, 0
|
226 |
+
|
227 |
+
start, end = 0, target_size
|
228 |
+
if self.random_crop:
|
229 |
+
start = np.random.randint(0, diff + 1)
|
230 |
+
end = size - diff + start
|
231 |
+
return wav[start:end], start
|
232 |
+
|
233 |
+
def collater(self, samples):
|
234 |
+
# target = max(sizes) -> random_crop not used
|
235 |
+
# target = max_sample_size -> random_crop used for long
|
236 |
+
samples = [s for s in samples if s["source"] is not None]
|
237 |
+
if len(samples) == 0:
|
238 |
+
return {}
|
239 |
+
|
240 |
+
audios = [s["source"] for s in samples]
|
241 |
+
audio_sizes = [len(s) for s in audios]
|
242 |
+
if self.pad_audio:
|
243 |
+
audio_size = min(max(audio_sizes), self.max_sample_size)
|
244 |
+
else:
|
245 |
+
audio_size = min(min(audio_sizes), self.max_sample_size)
|
246 |
+
collated_audios, padding_mask, audio_starts = self.collater_audio(
|
247 |
+
audios, audio_size
|
248 |
+
)
|
249 |
+
|
250 |
+
targets_by_label = [
|
251 |
+
[s["label_list"][i] for s in samples] for i in range(self.num_labels)
|
252 |
+
]
|
253 |
+
targets_list, lengths_list, ntokens_list = self.collater_label(
|
254 |
+
targets_by_label, audio_size, audio_starts
|
255 |
+
)
|
256 |
+
|
257 |
+
net_input = {"source": collated_audios, "padding_mask": padding_mask}
|
258 |
+
batch = {
|
259 |
+
"id": torch.LongTensor([s["id"] for s in samples]),
|
260 |
+
"net_input": net_input,
|
261 |
+
}
|
262 |
+
|
263 |
+
if self.single_target:
|
264 |
+
batch["target_lengths"] = lengths_list[0]
|
265 |
+
batch["ntokens"] = ntokens_list[0]
|
266 |
+
batch["target"] = targets_list[0]
|
267 |
+
else:
|
268 |
+
batch["target_lengths_list"] = lengths_list
|
269 |
+
batch["ntokens_list"] = ntokens_list
|
270 |
+
batch["target_list"] = targets_list
|
271 |
+
return batch
|
272 |
+
|
273 |
+
def collater_audio(self, audios, audio_size):
|
274 |
+
collated_audios = audios[0].new_zeros(len(audios), audio_size)
|
275 |
+
padding_mask = (
|
276 |
+
torch.BoolTensor(collated_audios.shape).fill_(False)
|
277 |
+
# if self.pad_audio else None
|
278 |
+
)
|
279 |
+
audio_starts = [0 for _ in audios]
|
280 |
+
for i, audio in enumerate(audios):
|
281 |
+
diff = len(audio) - audio_size
|
282 |
+
if diff == 0:
|
283 |
+
collated_audios[i] = audio
|
284 |
+
elif diff < 0:
|
285 |
+
assert self.pad_audio
|
286 |
+
collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
|
287 |
+
padding_mask[i, diff:] = True
|
288 |
+
else:
|
289 |
+
collated_audios[i], audio_starts[i] = self.crop_to_max_size(
|
290 |
+
audio, audio_size
|
291 |
+
)
|
292 |
+
return collated_audios, padding_mask, audio_starts
|
293 |
+
|
294 |
+
def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
|
295 |
+
assert label_rate > 0
|
296 |
+
s2f = label_rate / self.sample_rate
|
297 |
+
frm_starts = [int(round(s * s2f)) for s in audio_starts]
|
298 |
+
frm_size = int(round(audio_size * s2f))
|
299 |
+
if not self.pad_audio:
|
300 |
+
rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
|
301 |
+
frm_size = min(frm_size, *rem_size)
|
302 |
+
targets = [t[s: s + frm_size] for t, s in zip(targets, frm_starts)]
|
303 |
+
logger.debug(f"audio_starts={audio_starts}")
|
304 |
+
logger.debug(f"frame_starts={frm_starts}")
|
305 |
+
logger.debug(f"frame_size={frm_size}")
|
306 |
+
|
307 |
+
lengths = torch.LongTensor([len(t) for t in targets])
|
308 |
+
ntokens = lengths.sum().item()
|
309 |
+
targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
|
310 |
+
return targets, lengths, ntokens
|
311 |
+
|
312 |
+
def collater_seq_label(self, targets, pad):
|
313 |
+
lengths = torch.LongTensor([len(t) for t in targets])
|
314 |
+
ntokens = lengths.sum().item()
|
315 |
+
targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
|
316 |
+
return targets, lengths, ntokens
|
317 |
+
|
318 |
+
def collater_label(self, targets_by_label, audio_size, audio_starts):
|
319 |
+
targets_list, lengths_list, ntokens_list = [], [], []
|
320 |
+
itr = zip(targets_by_label, self.label_rates, self.pad_list)
|
321 |
+
for targets, label_rate, pad in itr:
|
322 |
+
if label_rate == -1.0:
|
323 |
+
targets, lengths, ntokens = self.collater_seq_label(targets, pad)
|
324 |
+
else:
|
325 |
+
targets, lengths, ntokens = self.collater_frm_label(
|
326 |
+
targets, audio_size, audio_starts, label_rate, pad
|
327 |
+
)
|
328 |
+
targets_list.append(targets)
|
329 |
+
lengths_list.append(lengths)
|
330 |
+
ntokens_list.append(ntokens)
|
331 |
+
return targets_list, lengths_list, ntokens_list
|
332 |
+
|
333 |
+
def num_tokens(self, index):
|
334 |
+
return self.size(index)
|
335 |
+
|
336 |
+
def size(self, index):
|
337 |
+
if self.pad_audio:
|
338 |
+
return self.sizes[index]
|
339 |
+
return min(self.sizes[index], self.max_sample_size)
|
340 |
+
|
341 |
+
def ordered_indices(self):
|
342 |
+
if self.shuffle:
|
343 |
+
order = [np.random.permutation(len(self))]
|
344 |
+
else:
|
345 |
+
order = [np.arange(len(self))]
|
346 |
+
|
347 |
+
order.append(self.sizes)
|
348 |
+
return np.lexsort(order)[::-1]
|
349 |
+
|
350 |
+
def postprocess(self, wav, cur_sample_rate):
|
351 |
+
if wav.dim() == 2:
|
352 |
+
wav = wav.mean(-1)
|
353 |
+
assert wav.dim() == 1, wav.dim()
|
354 |
+
|
355 |
+
if cur_sample_rate != self.sample_rate:
|
356 |
+
raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
|
357 |
+
|
358 |
+
if self.normalize:
|
359 |
+
with torch.no_grad():
|
360 |
+
wav = F.layer_norm(wav, wav.shape)
|
361 |
+
return wav
|
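To make the collation logic above easier to follow, here is a standalone sketch (not the class's own code) of the crop-or-pad decision applied to a single waveform, mirroring crop_to_max_size and the padding branch of collater_audio.

import torch

def crop_or_pad(wav: torch.Tensor, audio_size: int, random_crop: bool = True):
    diff = len(wav) - audio_size
    padding_mask = torch.zeros(audio_size, dtype=torch.bool)
    if diff >= 0:                                        # long enough: crop a window
        start = torch.randint(0, diff + 1, (1,)).item() if random_crop else 0
        return wav[start:start + audio_size], padding_mask
    collated = torch.cat([wav, wav.new_zeros(-diff)])    # too short: zero-pad the tail
    padding_mask[diff:] = True                           # last -diff samples are padding
    return collated, padding_mask

wav = torch.randn(12345)
out, mask = crop_or_pad(wav, audio_size=16000, random_crop=True)
print(out.shape, int(mask.sum()))  # torch.Size([16000]) 3655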
fengshen/data/megatron_dataloader/Makefile
ADDED
@@ -0,0 +1,9 @@
CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color
CPPFLAGS += $(shell python3 -m pybind11 --includes)
LIBNAME = helpers
LIBEXT = $(shell python3-config --extension-suffix)

default: $(LIBNAME)$(LIBEXT)

%$(LIBEXT): %.cpp
	$(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@
fengshen/data/megatron_dataloader/__init__.py
ADDED
@@ -0,0 +1 @@
from . import indexed_dataset
fengshen/data/megatron_dataloader/bart_dataset.py
ADDED
@@ -0,0 +1,443 @@
1 |
+
"""BART Style dataset. Modified from fairseq."""
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import math
|
6 |
+
import re
|
7 |
+
|
8 |
+
from fengshen.data.megatron_dataloader.dataset_utils import (
|
9 |
+
get_samples_mapping
|
10 |
+
)
|
11 |
+
|
12 |
+
|
13 |
+
class BartDataset(torch.utils.data.Dataset):
|
14 |
+
def __init__(self, name, indexed_dataset, data_prefix,
|
15 |
+
num_epochs, max_num_samples, masked_lm_prob,
|
16 |
+
max_seq_length, short_seq_prob, seed, tokenizer, zh_tokenizer):
|
17 |
+
|
18 |
+
# Params to store.
|
19 |
+
self.name = name
|
20 |
+
self.seed = seed
|
21 |
+
self.masked_lm_prob = masked_lm_prob
|
22 |
+
self.max_seq_length = max_seq_length
|
23 |
+
|
24 |
+
# Dataset.
|
25 |
+
self.indexed_dataset = indexed_dataset
|
26 |
+
|
27 |
+
# Build the samples mapping.
|
28 |
+
self.samples_mapping = get_samples_mapping(self.indexed_dataset,
|
29 |
+
data_prefix,
|
30 |
+
num_epochs,
|
31 |
+
max_num_samples,
|
32 |
+
self.max_seq_length - 3, # account for added tokens
|
33 |
+
short_seq_prob,
|
34 |
+
self.seed,
|
35 |
+
self.name,
|
36 |
+
False)
|
37 |
+
|
38 |
+
# Vocab stuff.
|
39 |
+
self.vocab_size = tokenizer.vocab_size
|
40 |
+
inv_vocab = {v: k for k, v in tokenizer.vocab.items()}
|
41 |
+
self.vocab_id_list = list(inv_vocab.keys())
|
42 |
+
self.vocab_id_to_token_dict = inv_vocab
|
43 |
+
self.cls_id = tokenizer.cls_token_id
|
44 |
+
self.sep_id = tokenizer.sep_token_id
|
45 |
+
self.mask_id = tokenizer.mask_token_id
|
46 |
+
self.pad_id = tokenizer.pad_token_id
|
47 |
+
self.tokenizer = tokenizer
|
48 |
+
|
49 |
+
seg_tokens = ['。', ';', ';', '!', '!', '?', '?']
|
50 |
+
seg_token_ids = []
|
51 |
+
for t in seg_tokens:
|
52 |
+
if t in tokenizer.vocab:
|
53 |
+
seg_token_ids.append(tokenizer.vocab[t])
|
54 |
+
else:
|
55 |
+
print('seg_token "{}" not in vocab'.format(t))
|
56 |
+
self.seg_token_ids = set(seg_token_ids)
|
57 |
+
|
58 |
+
self.zh_tokenizer = zh_tokenizer
|
59 |
+
|
60 |
+
# Denoising ratios
|
61 |
+
self.permute_sentence_ratio = 1.0
|
62 |
+
self.mask_ratio = masked_lm_prob # 0.15
|
63 |
+
self.random_ratio = 0.1
|
64 |
+
self.insert_ratio = 0.0
|
65 |
+
self.rotate_ratio = 0.0
|
66 |
+
self.mask_whole_word = 1
|
67 |
+
self.item_transform_func = None
|
68 |
+
|
69 |
+
self.mask_span_distribution = None
|
70 |
+
if False:
|
71 |
+
_lambda = 3 # Poisson lambda
|
72 |
+
|
73 |
+
lambda_to_the_k = 1
|
74 |
+
e_to_the_minus_lambda = math.exp(-_lambda)
|
75 |
+
k_factorial = 1
|
76 |
+
ps = []
|
77 |
+
for k in range(0, 128):
|
78 |
+
ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
|
79 |
+
lambda_to_the_k *= _lambda
|
80 |
+
k_factorial *= k + 1
|
81 |
+
if ps[-1] < 0.0000001:
|
82 |
+
break
|
83 |
+
ps = torch.FloatTensor(ps)
|
84 |
+
self.mask_span_distribution = torch.distributions.Categorical(ps)
|
85 |
+
|
86 |
+
def __len__(self):
|
87 |
+
return self.samples_mapping.shape[0]
|
88 |
+
|
89 |
+
def __getitem__(self, idx):
|
90 |
+
start_idx, end_idx, seq_length = self.samples_mapping[idx]
|
91 |
+
sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]
|
92 |
+
# Note that this rng state should be numpy and not python since
|
93 |
+
# python randint is inclusive whereas the numpy one is exclusive.
|
94 |
+
# We % 2**32 since numpy requres the seed to be between 0 and 2**32 - 1
|
95 |
+
np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32))
|
96 |
+
return self.build_training_sample(sample, self.max_seq_length, np_rng)
|
97 |
+
|
98 |
+
def build_training_sample(self, sample, max_seq_length, np_rng):
|
99 |
+
"""Biuld training sample.
|
100 |
+
|
101 |
+
Arguments:
|
102 |
+
sample: A list of sentences in which each sentence is a list token ids.
|
103 |
+
max_seq_length: Desired sequence length.
|
104 |
+
np_rng: Random number genenrator. Note that this rng state should be
|
105 |
+
numpy and not python since python randint is inclusive for
|
106 |
+
the opper bound whereas the numpy one is exclusive.
|
107 |
+
"""
|
108 |
+
# permute sentences
|
109 |
+
full_stops = []
|
110 |
+
tokens = [self.cls_id]
|
111 |
+
for sent in sample:
|
112 |
+
for t in sent:
|
113 |
+
token = self.vocab_id_to_token_dict[t]
|
114 |
+
if len(re.findall('##[\u4E00-\u9FA5]', token)) > 0:
|
115 |
+
# 兼容erlangshen ##的方式做whole word mask
|
116 |
+
t = self.tokenizer.convert_tokens_to_ids(token[2:])
|
117 |
+
tokens.append(t)
|
118 |
+
if t in self.seg_token_ids:
|
119 |
+
tokens.append(self.sep_id)
|
120 |
+
if tokens[-1] != self.sep_id:
|
121 |
+
tokens.append(self.sep_id)
|
122 |
+
|
123 |
+
if len(tokens) > max_seq_length:
|
124 |
+
tokens = tokens[:max_seq_length]
|
125 |
+
tokens[-1] = self.sep_id
|
126 |
+
tokens = torch.LongTensor(tokens)
|
127 |
+
full_stops = (tokens == self.sep_id).long()
|
128 |
+
assert (max_seq_length - tokens.shape[0]) >= 0, (tokens.size(), tokens[-1], max_seq_length)
|
129 |
+
|
130 |
+
source, target = tokens, tokens[1:].clone()
|
131 |
+
use_decoder = 1
|
132 |
+
# if torch.rand(1).item() < 0.5:
|
133 |
+
# use_decoder = 0
|
134 |
+
|
135 |
+
if self.permute_sentence_ratio > 0.0 and use_decoder == 1:
|
136 |
+
source = self.permute_sentences(source, full_stops, self.permute_sentence_ratio)
|
137 |
+
|
138 |
+
if self.mask_ratio > 0.0:
|
139 |
+
replace_length = 1 if use_decoder else -1
|
140 |
+
mask_ratio = self.mask_ratio * 2 if use_decoder else self.mask_ratio
|
141 |
+
source = self.add_whole_word_mask(source, mask_ratio, replace_length)
|
142 |
+
|
143 |
+
if self.insert_ratio > 0.0:
|
144 |
+
raise NotImplementedError
|
145 |
+
source = self.add_insertion_noise(source, self.insert_ratio)
|
146 |
+
|
147 |
+
if self.rotate_ratio > 0.0 and np.random.random() < self.rotate_ratio:
|
148 |
+
raise NotImplementedError
|
149 |
+
source = self.add_rolling_noise(source)
|
150 |
+
|
151 |
+
# there can additional changes to make:
|
152 |
+
if self.item_transform_func is not None:
|
153 |
+
source, target = self.item_transform_func(source, target)
|
154 |
+
|
155 |
+
assert (source >= 0).all()
|
156 |
+
# assert (source[1:-1] >= 1).all()
|
157 |
+
assert (source <= self.vocab_size).all()
|
158 |
+
assert source[0] == self.cls_id
|
159 |
+
assert source[-1] == self.sep_id
|
160 |
+
|
161 |
+
# tokenizer = get_tokenizer()
|
162 |
+
# print(' '.join(tokenizer.tokenizer.convert_ids_to_tokens(source)))
|
163 |
+
# print(tokenizer.detokenize(target))
|
164 |
+
# print(tokenizer.detokenize(source))
|
165 |
+
# print()
|
166 |
+
|
167 |
+
prev_output_tokens = torch.zeros_like(target)
|
168 |
+
prev_output_tokens[0] = self.sep_id # match the preprocessing in fairseq
|
169 |
+
prev_output_tokens[1:] = target[:-1]
|
170 |
+
|
171 |
+
# src_padding_length = max_seq_length - source.shape[0]
|
172 |
+
# tgt_padding_length = max_seq_length - target.shape[0]
|
173 |
+
# assert src_padding_length >= 0, (source.size(), source[-1], max_seq_length)
|
174 |
+
# assert tgt_padding_length >= 0, (target.size(), target[-1], max_seq_length)
|
175 |
+
source_ = torch.full((max_seq_length,), self.pad_id, dtype=torch.long)
|
176 |
+
source_[:source.shape[0]] = source
|
177 |
+
target_ = torch.full((max_seq_length,), -100, dtype=torch.long)
|
178 |
+
# decoder not need bos in the front
|
179 |
+
target_[:target.shape[0]] = target
|
180 |
+
prev_output_tokens_ = torch.full((max_seq_length,), self.pad_id, dtype=torch.long)
|
181 |
+
prev_output_tokens_[:prev_output_tokens.shape[0]] = prev_output_tokens
|
182 |
+
|
183 |
+
return {
|
184 |
+
"input_ids": source_,
|
185 |
+
"labels": target_,
|
186 |
+
# "decoder_input_ids": prev_output_tokens_,
|
187 |
+
"attention_mask": (source_ != self.pad_id).long()
|
188 |
+
}
|
189 |
+
|
190 |
+
def permute_sentences(self, source, full_stops, p=1.0):
|
191 |
+
# Tokens that are full stops, where the previous token is not
|
192 |
+
sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero(as_tuple=False) + 2
|
193 |
+
result = source.clone()
|
194 |
+
|
195 |
+
num_sentences = sentence_ends.size(0)
|
196 |
+
num_to_permute = math.ceil((num_sentences * 2 * p) / 2.0)
|
197 |
+
substitutions = torch.randperm(num_sentences)[:num_to_permute]
|
198 |
+
ordering = torch.arange(0, num_sentences)
|
199 |
+
ordering[substitutions] = substitutions[torch.randperm(num_to_permute)]
|
200 |
+
|
201 |
+
# Ignore <bos> at start
|
202 |
+
index = 1
|
203 |
+
for i in ordering:
|
204 |
+
sentence = source[(sentence_ends[i - 1] if i > 0 else 1): sentence_ends[i]]
|
205 |
+
result[index: index + sentence.size(0)] = sentence
|
206 |
+
index += sentence.size(0)
|
207 |
+
return result
|
208 |
+
|
209 |
+
def word_starts_en(self, source):
|
210 |
+
if self.mask_whole_word is not None:
|
211 |
+
is_word_start = self.mask_whole_word.gather(0, source)
|
212 |
+
else:
|
213 |
+
is_word_start = torch.ones(source.size())
|
214 |
+
is_word_start[0] = 0
|
215 |
+
is_word_start[-1] = 0
|
216 |
+
return is_word_start
|
217 |
+
|
218 |
+
def word_starts(self, source):
|
219 |
+
if self.mask_whole_word is None:
|
220 |
+
is_word_start = torch.ones(source.size())
|
221 |
+
is_word_start[0] = 0
|
222 |
+
is_word_start[-1] = 0
|
223 |
+
return is_word_start
|
224 |
+
raw_tokens = [self.vocab_id_to_token_dict[i] for i in source.tolist()]
|
225 |
+
words = [raw_tokens[0]] + \
|
226 |
+
self.zh_tokenizer(''.join(raw_tokens[1:-1]), HMM=True) + [raw_tokens[-1]]
|
227 |
+
|
228 |
+
def _is_chinese_char(c):
|
229 |
+
"""Checks whether CP is the #codepoint of a CJK character."""
|
230 |
+
# This defines a "chinese character" as anything in the CJK Unicode block:
|
231 |
+
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
232 |
+
#
|
233 |
+
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
|
234 |
+
# despite its name. The modern Korean Hangul alphabet is a different block,
|
235 |
+
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
236 |
+
# space-separated words, so they are not treated specially and handled
|
237 |
+
# like the all of the other languages.
|
238 |
+
if len(c) > 1:
|
239 |
+
return all([_is_chinese_char(c_i) for c_i in c])
|
240 |
+
cp = ord(c)
|
241 |
+
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
|
242 |
+
(cp >= 0x3400 and cp <= 0x4DBF) or #
|
243 |
+
(cp >= 0x20000 and cp <= 0x2A6DF) or #
|
244 |
+
(cp >= 0x2A700 and cp <= 0x2B73F) or #
|
245 |
+
(cp >= 0x2B740 and cp <= 0x2B81F) or #
|
246 |
+
(cp >= 0x2B820 and cp <= 0x2CEAF) or
|
247 |
+
(cp >= 0xF900 and cp <= 0xFAFF) or #
|
248 |
+
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
|
249 |
+
return True
|
250 |
+
|
251 |
+
return False
|
252 |
+
|
253 |
+
def align_linear(atokens, btokens):
|
254 |
+
a2c = []
|
255 |
+
c2b = []
|
256 |
+
a2b = []
|
257 |
+
length = 0
|
258 |
+
for tok in atokens:
|
259 |
+
a2c.append([length + i for i in range(len(tok))])
|
260 |
+
length += len(tok)
|
261 |
+
for i, tok in enumerate(btokens):
|
262 |
+
c2b.extend([i for _ in range(len(tok))])
|
263 |
+
|
264 |
+
for i, amap in enumerate(a2c):
|
265 |
+
bmap = [c2b[ci] for ci in amap]
|
266 |
+
a2b.append(list(set(bmap)))
|
267 |
+
return a2b
|
268 |
+
|
269 |
+
raw_to_word_align = align_linear(raw_tokens, words)
|
270 |
+
is_word_start = torch.zeros(source.size())
|
271 |
+
word_starts = []
|
272 |
+
skip_cur_word = True
|
273 |
+
for i in range(1, len(raw_to_word_align)):
|
274 |
+
if raw_to_word_align[i-1] == raw_to_word_align[i]:
|
275 |
+
# not a word start, as they align to the same word
|
276 |
+
if not skip_cur_word and not _is_chinese_char(raw_tokens[i]):
|
277 |
+
word_starts.pop(-1)
|
278 |
+
skip_cur_word = True
|
279 |
+
continue
|
280 |
+
else:
|
281 |
+
is_word_start[i] = 1
|
282 |
+
if _is_chinese_char(raw_tokens[i]):
|
283 |
+
word_starts.append(i)
|
284 |
+
skip_cur_word = False
|
285 |
+
is_word_start[0] = 0
|
286 |
+
is_word_start[-1] = 0
|
287 |
+
word_starts = torch.tensor(word_starts).long().view(-1, 1)
|
288 |
+
return is_word_start, word_starts
|
289 |
+
|
290 |
+
def add_whole_word_mask(self, source, p, replace_length=1):
|
291 |
+
is_word_start, word_starts = self.word_starts(source)
|
292 |
+
num_to_mask_word = int(math.ceil(word_starts.size(0) * p))
|
293 |
+
num_to_mask_char = int(math.ceil(word_starts.size(0) * p * 0.1))
|
294 |
+
num_to_mask = num_to_mask_word + num_to_mask_char
|
295 |
+
if num_to_mask > word_starts.size(0):
|
296 |
+
word_starts = is_word_start.nonzero(as_tuple=False)
|
297 |
+
num_inserts = 0
|
298 |
+
if num_to_mask == 0:
|
299 |
+
return source
|
300 |
+
|
301 |
+
if self.mask_span_distribution is not None:
|
302 |
+
lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))
|
303 |
+
|
304 |
+
# Make sure we have enough to mask
|
305 |
+
cum_length = torch.cumsum(lengths, 0)
|
306 |
+
while cum_length[-1] < num_to_mask:
|
307 |
+
lengths = torch.cat(
|
308 |
+
[
|
309 |
+
lengths,
|
310 |
+
self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
|
311 |
+
],
|
312 |
+
dim=0,
|
313 |
+
)
|
314 |
+
cum_length = torch.cumsum(lengths, 0)
|
315 |
+
|
316 |
+
# Trim to masking budget
|
317 |
+
i = 0
|
318 |
+
while cum_length[i] < num_to_mask:
|
319 |
+
i += 1
|
320 |
+
lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
|
321 |
+
num_to_mask = i + 1
|
322 |
+
lengths = lengths[:num_to_mask]
|
323 |
+
|
324 |
+
# Handle 0-length mask (inserts) separately
|
325 |
+
lengths = lengths[lengths > 0]
|
326 |
+
num_inserts = num_to_mask - lengths.size(0)
|
327 |
+
num_to_mask -= num_inserts
|
328 |
+
if num_to_mask == 0:
|
329 |
+
return self.add_insertion_noise(source, num_inserts / source.size(0))
|
330 |
+
|
331 |
+
assert (lengths > 0).all()
|
332 |
+
else:
|
333 |
+
lengths = torch.ones((num_to_mask,)).long()
|
334 |
+
assert is_word_start[-1] == 0
|
335 |
+
indices = word_starts[
|
336 |
+
torch.randperm(word_starts.size(0))[:num_to_mask]
|
337 |
+
].squeeze(1)
|
338 |
+
mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio
|
339 |
+
source_length = source.size(0)
|
340 |
+
assert source_length - 1 not in indices
|
341 |
+
to_keep = torch.ones(source_length, dtype=torch.bool)
|
342 |
+
is_word_start[
|
343 |
+
-1
|
344 |
+
] = 255 # acts as a long length, so spans don't go over the end of doc
|
345 |
+
if replace_length == 0:
|
346 |
+
to_keep[indices] = 0
|
347 |
+
else:
|
348 |
+
# keep index, but replace it with [MASK]
|
349 |
+
# print(source.size(), word_starts.size(), indices.size(), mask_random.size())
|
350 |
+
source[indices] = self.mask_id
|
351 |
+
source[indices[mask_random]] = torch.randint(
|
352 |
+
1, self.vocab_size, size=(mask_random.sum(),)
|
353 |
+
)
|
354 |
+
# sorted_indices = torch.sort(indices)[0]
|
355 |
+
# continue_mask_pos = ((sorted_indices + 1)[:-1] == sorted_indices[1:])
|
356 |
+
# continue_mask_indices = sorted_indices[1:][continue_mask_pos]
|
357 |
+
# to_keep[continue_mask_indices] = 0
|
358 |
+
|
359 |
+
# for char indices, we already masked, the following loop handles word mask
|
360 |
+
indices = indices[:num_to_mask_word]
|
361 |
+
mask_random = mask_random[:num_to_mask_word]
|
362 |
+
if self.mask_span_distribution is not None:
|
363 |
+
assert len(lengths.size()) == 1
|
364 |
+
assert lengths.size() == indices.size()
|
365 |
+
lengths -= 1
|
366 |
+
while indices.size(0) > 0:
|
367 |
+
assert lengths.size() == indices.size()
|
368 |
+
lengths -= is_word_start[indices + 1].long()
|
369 |
+
uncompleted = lengths >= 0
|
370 |
+
indices = indices[uncompleted] + 1
|
371 |
+
mask_random = mask_random[uncompleted]
|
372 |
+
lengths = lengths[uncompleted]
|
373 |
+
if replace_length != -1:
|
374 |
+
# delete token
|
375 |
+
to_keep[indices] = 0
|
376 |
+
else:
|
377 |
+
# keep index, but replace it with [MASK]
|
378 |
+
source[indices] = self.mask_id
|
379 |
+
source[indices[mask_random]] = torch.randint(
|
380 |
+
1, self.vocab_size, size=(mask_random.sum(),)
|
381 |
+
)
|
382 |
+
else:
|
383 |
+
# A bit faster when all lengths are 1
|
384 |
+
while indices.size(0) > 0:
|
385 |
+
uncompleted = is_word_start[indices + 1] == 0
|
386 |
+
indices = indices[uncompleted] + 1
|
387 |
+
mask_random = mask_random[uncompleted]
|
388 |
+
if replace_length != -1:
|
389 |
+
# delete token
|
390 |
+
to_keep[indices] = 0
|
391 |
+
else:
|
392 |
+
# keep index, but replace it with [MASK]
|
393 |
+
source[indices] = self.mask_id
|
394 |
+
source[indices[mask_random]] = torch.randint(
|
395 |
+
1, self.vocab_size, size=(mask_random.sum(),)
|
396 |
+
)
|
397 |
+
|
398 |
+
assert source_length - 1 not in indices
|
399 |
+
|
400 |
+
source = source[to_keep]
|
401 |
+
|
402 |
+
if num_inserts > 0:
|
403 |
+
source = self.add_insertion_noise(source, num_inserts / source.size(0))
|
404 |
+
|
405 |
+
return source
|
406 |
+
|
407 |
+
def add_permuted_noise(self, tokens, p):
|
408 |
+
num_words = len(tokens)
|
409 |
+
num_to_permute = math.ceil(((num_words * 2) * p) / 2.0)
|
410 |
+
substitutions = torch.randperm(num_words - 2)[:num_to_permute] + 1
|
411 |
+
tokens[substitutions] = tokens[substitutions[torch.randperm(num_to_permute)]]
|
412 |
+
return tokens
|
413 |
+
|
414 |
+
def add_rolling_noise(self, tokens):
|
415 |
+
offset = np.random.randint(1, max(1, tokens.size(-1) - 1) + 1)
|
416 |
+
tokens = torch.cat(
|
417 |
+
(tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]),
|
418 |
+
dim=0,
|
419 |
+
)
|
420 |
+
return tokens
|
421 |
+
|
422 |
+
def add_insertion_noise(self, tokens, p):
|
423 |
+
if p == 0.0:
|
424 |
+
return tokens
|
425 |
+
|
426 |
+
num_tokens = len(tokens)
|
427 |
+
n = int(math.ceil(num_tokens * p))
|
428 |
+
|
429 |
+
noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
|
430 |
+
noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
|
431 |
+
noise_mask[noise_indices] = 1
|
432 |
+
result = torch.LongTensor(n + len(tokens)).fill_(-1)
|
433 |
+
|
434 |
+
num_random = int(math.ceil(n * self.random_ratio))
|
435 |
+
result[noise_indices[num_random:]] = self.mask_id
|
436 |
+
result[noise_indices[:num_random]] = torch.randint(
|
437 |
+
low=1, high=self.vocab_size, size=(num_random,)
|
438 |
+
)
|
439 |
+
|
440 |
+
result[~noise_mask] = tokens
|
441 |
+
|
442 |
+
assert (result >= 0).all()
|
443 |
+
return result
|
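The span-length distribution that BartDataset.__init__ builds inside the disabled `if False:` block can be reproduced on its own; the sketch below mirrors that block and is only illustrative (the dataset as committed runs with mask_span_distribution=None).

import math
import torch

# Poisson(lambda=3) truncated where the tail probability becomes negligible,
# used by BART-style text infilling to sample how many tokens a mask span covers.
_lambda = 3
lambda_to_the_k, k_factorial = 1, 1
e_to_the_minus_lambda = math.exp(-_lambda)
ps = []
for k in range(0, 128):
    ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
    lambda_to_the_k *= _lambda
    k_factorial *= k + 1
    if ps[-1] < 0.0000001:
        break
mask_span_distribution = torch.distributions.Categorical(torch.FloatTensor(ps))
print(mask_span_distribution.sample(sample_shape=(5,)))  # e.g. tensor([2, 3, 4, 1, 3])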
fengshen/data/megatron_dataloader/bert_dataset.py
ADDED
@@ -0,0 +1,196 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""BERT Style dataset."""
|
17 |
+
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
|
22 |
+
from fengshen.data.megatron_dataloader.dataset_utils import (
|
23 |
+
get_samples_mapping,
|
24 |
+
get_a_and_b_segments,
|
25 |
+
create_masked_lm_predictions,
|
26 |
+
create_tokens_and_tokentypes,
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
class BertDataset(torch.utils.data.Dataset):
|
31 |
+
|
32 |
+
def __init__(self, name, indexed_dataset, data_prefix,
|
33 |
+
num_epochs, max_num_samples, masked_lm_prob,
|
34 |
+
max_seq_length, short_seq_prob, seed, binary_head, tokenizer, masking_style):
|
35 |
+
# Params to store.
|
36 |
+
self.name = name
|
37 |
+
self.seed = seed
|
38 |
+
self.masked_lm_prob = masked_lm_prob
|
39 |
+
self.max_seq_length = max_seq_length
|
40 |
+
self.short_seq_prob = short_seq_prob
|
41 |
+
self.binary_head = binary_head
|
42 |
+
self.masking_style = masking_style
|
43 |
+
|
44 |
+
# Dataset.
|
45 |
+
self.indexed_dataset = indexed_dataset
|
46 |
+
|
47 |
+
# Build the samples mapping.
|
48 |
+
self.samples_mapping = get_samples_mapping(self.indexed_dataset,
|
49 |
+
data_prefix,
|
50 |
+
num_epochs,
|
51 |
+
max_num_samples,
|
52 |
+
# account for added tokens
|
53 |
+
self.max_seq_length - 3,
|
54 |
+
short_seq_prob,
|
55 |
+
self.seed,
|
56 |
+
self.name,
|
57 |
+
self.binary_head)
|
58 |
+
inv_vocab = {v: k for k, v in tokenizer.vocab.items()}
|
59 |
+
self.vocab_id_list = list(inv_vocab.keys())
|
60 |
+
self.vocab_id_to_token_dict = inv_vocab
|
61 |
+
self.cls_id = tokenizer.cls_token_id
|
62 |
+
self.sep_id = tokenizer.sep_token_id
|
63 |
+
self.mask_id = tokenizer.mask_token_id
|
64 |
+
self.pad_id = tokenizer.pad_token_id
|
65 |
+
self.tokenizer = tokenizer
|
66 |
+
|
67 |
+
def __len__(self):
|
68 |
+
return self.samples_mapping.shape[0]
|
69 |
+
|
70 |
+
def __getitem__(self, idx):
|
71 |
+
start_idx, end_idx, seq_length = self.samples_mapping[idx]
|
72 |
+
sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]
|
73 |
+
# Note that this rng state should be numpy and not python since
|
74 |
+
# python randint is inclusive whereas the numpy one is exclusive.
|
75 |
+
# We % 2**32 since numpy requres the seed to be between 0 and 2**32 - 1
|
76 |
+
np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32))
|
77 |
+
return build_training_sample(sample, seq_length,
|
78 |
+
self.max_seq_length, # needed for padding
|
79 |
+
self.vocab_id_list,
|
80 |
+
self.vocab_id_to_token_dict,
|
81 |
+
self.cls_id, self.sep_id,
|
82 |
+
self.mask_id, self.pad_id,
|
83 |
+
self.masked_lm_prob, np_rng,
|
84 |
+
self.binary_head,
|
85 |
+
tokenizer=self.tokenizer,
|
86 |
+
masking_style=self.masking_style)
|
87 |
+
|
88 |
+
|
89 |
+
def build_training_sample(sample,
|
90 |
+
target_seq_length, max_seq_length,
|
91 |
+
vocab_id_list, vocab_id_to_token_dict,
|
92 |
+
cls_id, sep_id, mask_id, pad_id,
|
93 |
+
masked_lm_prob, np_rng, binary_head,
|
94 |
+
tokenizer,
|
95 |
+
masking_style='bert'):
|
96 |
+
"""Biuld training sample.
|
97 |
+
|
98 |
+
Arguments:
|
99 |
+
sample: A list of sentences in which each sentence is a list token ids.
|
100 |
+
target_seq_length: Desired sequence length.
|
101 |
+
max_seq_length: Maximum length of the sequence. All values are padded to
|
102 |
+
this length.
|
103 |
+
vocab_id_list: List of vocabulary ids. Used to pick a random id.
|
104 |
+
vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
|
105 |
+
cls_id: Start of example id.
|
106 |
+
sep_id: Separator id.
|
107 |
+
mask_id: Mask token id.
|
108 |
+
pad_id: Padding token id.
|
109 |
+
masked_lm_prob: Probability to mask tokens.
|
110 |
+
np_rng: Random number genenrator. Note that this rng state should be
|
111 |
+
numpy and not python since python randint is inclusive for
|
112 |
+
the opper bound whereas the numpy one is exclusive.
|
113 |
+
"""
|
114 |
+
|
115 |
+
if binary_head:
|
116 |
+
# We assume that we have at least two sentences in the sample
|
117 |
+
assert len(sample) > 1
|
118 |
+
assert target_seq_length <= max_seq_length
|
119 |
+
|
120 |
+
# Divide sample into two segments (A and B).
|
121 |
+
if binary_head:
|
122 |
+
tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample,
|
123 |
+
np_rng)
|
124 |
+
else:
|
125 |
+
tokens_a = []
|
126 |
+
for j in range(len(sample)):
|
127 |
+
tokens_a.extend(sample[j])
|
128 |
+
tokens_b = []
|
129 |
+
is_next_random = False
|
130 |
+
|
131 |
+
if len(tokens_a) >= max_seq_length-3:
|
132 |
+
tokens_a = tokens_a[:max_seq_length-3]
|
133 |
+
|
134 |
+
# Truncate to `target_sequence_length`.
|
135 |
+
max_num_tokens = target_seq_length
|
136 |
+
''''
|
137 |
+
truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a),
|
138 |
+
len(tokens_b), max_num_tokens, np_rng)
|
139 |
+
'''
|
140 |
+
|
141 |
+
# Build tokens and toketypes.
|
142 |
+
tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b,
|
143 |
+
cls_id, sep_id)
|
144 |
+
# Masking.
|
145 |
+
max_predictions_per_seq = masked_lm_prob * max_num_tokens
|
146 |
+
(tokens, masked_positions, masked_labels, _, _) = create_masked_lm_predictions(
|
147 |
+
tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
|
148 |
+
cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng,
|
149 |
+
tokenizer=tokenizer,
|
150 |
+
masking_style=masking_style)
|
151 |
+
|
152 |
+
# Padding.
|
153 |
+
tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
|
154 |
+
= pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
|
155 |
+
masked_labels, pad_id, max_seq_length)
|
156 |
+
|
157 |
+
train_sample = {
|
158 |
+
'input_ids': tokens_np,
|
159 |
+
'token_type_ids': tokentypes_np,
|
160 |
+
'labels': labels_np,
|
161 |
+
'next_sentence_label': int(is_next_random),
|
162 |
+
'attention_mask': padding_mask_np}
|
163 |
+
return train_sample
|
164 |
+
|
165 |
+
|
166 |
+
def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
|
167 |
+
masked_labels, pad_id, max_seq_length):
|
168 |
+
"""Pad sequences and convert them to numpy."""
|
169 |
+
|
170 |
+
# Some checks.
|
171 |
+
num_tokens = len(tokens)
|
172 |
+
padding_length = max_seq_length - num_tokens
|
173 |
+
assert padding_length >= 0
|
174 |
+
assert len(tokentypes) == num_tokens
|
175 |
+
assert len(masked_positions) == len(masked_labels)
|
176 |
+
|
177 |
+
# Tokens and token types.
|
178 |
+
filler = [pad_id] * padding_length
|
179 |
+
tokens_np = np.array(tokens + filler, dtype=np.int64)
|
180 |
+
tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
|
181 |
+
|
182 |
+
# Padding mask.
|
183 |
+
padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
|
184 |
+
dtype=np.int64)
|
185 |
+
|
186 |
+
# Lables and loss mask.
|
187 |
+
labels = [-100] * max_seq_length
|
188 |
+
loss_mask = [0] * max_seq_length
|
189 |
+
for i in range(len(masked_positions)):
|
190 |
+
assert masked_positions[i] < num_tokens
|
191 |
+
labels[masked_positions[i]] = masked_labels[i]
|
192 |
+
loss_mask[masked_positions[i]] = 1
|
193 |
+
labels_np = np.array(labels, dtype=np.int64)
|
194 |
+
loss_mask_np = np.array(loss_mask, dtype=np.int64)
|
195 |
+
|
196 |
+
return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
|
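A toy call to pad_and_convert_to_numpy (illustrative only, with made-up ids; 103 stands in for [MASK]) showing how the labels and loss mask are laid out, assuming the import path below.

from fengshen.data.megatron_dataloader.bert_dataset import pad_and_convert_to_numpy

tokens = [101, 103, 8, 103, 102]   # 103 stands in for [MASK]
tokentypes = [0, 0, 0, 0, 0]
masked_positions = [1, 3]          # which positions were masked
masked_labels = [7, 9]             # the original ids at those positions
tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np = \
    pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                             masked_labels, pad_id=0, max_seq_length=8)
print(tokens_np)        # [101 103   8 103 102   0   0   0]
print(padding_mask_np)  # [1 1 1 1 1 0 0 0]
print(labels_np)        # [-100    7 -100    9 -100 -100 -100 -100]
print(loss_mask_np)     # [0 1 0 1 0 0 0 0]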
fengshen/data/megatron_dataloader/blendable_dataset.py
ADDED
@@ -0,0 +1,64 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Blendable dataset."""
|
17 |
+
|
18 |
+
import time
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
import torch
|
22 |
+
|
23 |
+
from fengshen.data.megatron_dataloader.utils import print_rank_0
|
24 |
+
|
25 |
+
|
26 |
+
class BlendableDataset(torch.utils.data.Dataset):
|
27 |
+
|
28 |
+
def __init__(self, datasets, weights):
|
29 |
+
|
30 |
+
self.datasets = datasets
|
31 |
+
num_datasets = len(datasets)
|
32 |
+
assert num_datasets == len(weights)
|
33 |
+
|
34 |
+
self.size = 0
|
35 |
+
for dataset in self.datasets:
|
36 |
+
self.size += len(dataset)
|
37 |
+
|
38 |
+
# Normalize weights.
|
39 |
+
weights = np.array(weights, dtype=np.float64)
|
40 |
+
sum_weights = np.sum(weights)
|
41 |
+
assert sum_weights > 0.0
|
42 |
+
weights /= sum_weights
|
43 |
+
|
44 |
+
# Build indices.
|
45 |
+
start_time = time.time()
|
46 |
+
assert num_datasets < 255
|
47 |
+
self.dataset_index = np.zeros(self.size, dtype=np.uint8)
|
48 |
+
self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)
|
49 |
+
|
50 |
+
from fengshen.data.megatron_dataloader import helpers
|
51 |
+
helpers.build_blending_indices(self.dataset_index,
|
52 |
+
self.dataset_sample_index,
|
53 |
+
weights, num_datasets, self.size,
|
54 |
+
torch.distributed.get_rank() == 0)
|
55 |
+
print_rank_0('> elapsed time for building blendable dataset indices: '
|
56 |
+
'{:.2f} (sec)'.format(time.time() - start_time))
|
57 |
+
|
58 |
+
def __len__(self):
|
59 |
+
return self.size
|
60 |
+
|
61 |
+
def __getitem__(self, idx):
|
62 |
+
dataset_idx = self.dataset_index[idx]
|
63 |
+
sample_idx = self.dataset_sample_index[idx]
|
64 |
+
return self.datasets[dataset_idx][sample_idx]
|
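The index construction delegated to helpers.build_blending_indices can be illustrated in pure numpy: at every step the dataset whose realized sample count lags its target weight the most is picked. A minimal sketch under that assumption, with two toy weights:

import numpy as np

def blend_indices(weights, size):
    # Greedy blending: always pick the dataset with the largest sampling error,
    # mirroring what helpers.build_blending_indices computes in C++.
    weights = np.asarray(weights, dtype=np.float64)
    weights = weights / weights.sum()
    current = np.zeros(len(weights), dtype=np.int64)
    dataset_index = np.zeros(size, dtype=np.uint8)
    dataset_sample_index = np.zeros(size, dtype=np.int64)
    for i in range(size):
        errors = weights * max(i, 1) - current
        d = int(np.argmax(errors))
        dataset_index[i] = d
        dataset_sample_index[i] = current[d]
        current[d] += 1
    return dataset_index, dataset_sample_index

idx, sample_idx = blend_indices([0.75, 0.25], size=8)
print(idx)         # [0 1 0 0 0 1 0 0]  -> a 3:1 mix of dataset 0 and dataset 1
print(sample_idx)  # [0 0 1 2 3 1 4 5]  -> running per-dataset sample counters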
fengshen/data/megatron_dataloader/dataset_utils.py
ADDED
@@ -0,0 +1,788 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors, and NVIDIA.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
|
17 |
+
# Most of the code here has been copied from:
|
18 |
+
# https://github.com/google-research/albert/blob/master/create_pretraining_data.py
|
19 |
+
# with some modifications.
|
20 |
+
|
21 |
+
import math
|
22 |
+
import time
|
23 |
+
import collections
|
24 |
+
|
25 |
+
import numpy as np
|
26 |
+
import re
|
27 |
+
|
28 |
+
from fengshen.data.megatron_dataloader.utils import (
|
29 |
+
print_rank_0
|
30 |
+
)
|
31 |
+
from fengshen.data.megatron_dataloader.blendable_dataset import BlendableDataset
|
32 |
+
from fengshen.data.megatron_dataloader.indexed_dataset import make_dataset as make_indexed_dataset
|
33 |
+
|
34 |
+
DSET_TYPE_BERT = 'standard_bert'
|
35 |
+
DSET_TYPE_ICT = 'ict'
|
36 |
+
DSET_TYPE_T5 = 't5'
|
37 |
+
DSET_TYPE_BERT_CN_WWM = 'bert_cn_wwm'
|
38 |
+
DSET_TYPE_BART = 'bart'
|
39 |
+
DSET_TYPE_COCOLM = 'coco_lm'
|
40 |
+
|
41 |
+
DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT,
|
42 |
+
DSET_TYPE_T5, DSET_TYPE_BERT_CN_WWM,
|
43 |
+
DSET_TYPE_BART, DSET_TYPE_COCOLM]
|
44 |
+
|
45 |
+
|
46 |
+
def get_datasets_weights_and_num_samples(data_prefix,
|
47 |
+
train_valid_test_num_samples):
|
48 |
+
|
49 |
+
# The data prefix should be in the format of:
|
50 |
+
# weight-1, data-prefix-1, weight-2, data-prefix-2, ..
|
51 |
+
assert len(data_prefix) % 2 == 0
|
52 |
+
num_datasets = len(data_prefix) // 2
|
53 |
+
weights = [0] * num_datasets
|
54 |
+
prefixes = [0] * num_datasets
|
55 |
+
for i in range(num_datasets):
|
56 |
+
weights[i] = float(data_prefix[2 * i])
|
57 |
+
prefixes[i] = (data_prefix[2 * i + 1]).strip()
|
58 |
+
# Normalize weights
|
59 |
+
weight_sum = 0.0
|
60 |
+
for weight in weights:
|
61 |
+
weight_sum += weight
|
62 |
+
assert weight_sum > 0.0
|
63 |
+
weights = [weight / weight_sum for weight in weights]
|
64 |
+
|
65 |
+
# Add 0.5% (the 1.005 factor) so in case the blending dataset does
|
66 |
+
# not uniformly distribute the number of samples, we still have
|
67 |
+
# samples left to feed to the network.
|
68 |
+
datasets_train_valid_test_num_samples = []
|
69 |
+
for weight in weights:
|
70 |
+
datasets_train_valid_test_num_samples.append(
|
71 |
+
[int(math.ceil(val * weight * 1.005))
|
72 |
+
for val in train_valid_test_num_samples])
|
73 |
+
|
74 |
+
return prefixes, weights, datasets_train_valid_test_num_samples
|
75 |
+
|
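As a quick illustration of the expected alternating weight/prefix format and of the 1.005 head-room factor, a hypothetical call (the prefixes are made up):

# Hypothetical prefixes; the weights 3 and 1 are normalized to 0.75 / 0.25.
prefixes, weights, per_dataset_samples = get_datasets_weights_and_num_samples(
    data_prefix=['3', 'wudao_text_document', '1', 'wiki_zh_text_document'],
    train_valid_test_num_samples=[10000, 500, 100])
print(prefixes)             # ['wudao_text_document', 'wiki_zh_text_document']
print(weights)              # [0.75, 0.25]
print(per_dataset_samples)  # [[7538, 377, 76], [2513, 126, 26]]  (1.005 head-room, ceil)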
76 |
+
|
77 |
+
def compile_helper():
|
78 |
+
"""Compile helper function ar runtime. Make sure this
|
79 |
+
is invoked on a single process."""
|
80 |
+
import os
|
81 |
+
import subprocess
|
82 |
+
path = os.path.abspath(os.path.dirname(__file__))
|
83 |
+
ret = subprocess.run(['make', '-C', path])
|
84 |
+
if ret.returncode != 0:
|
85 |
+
print("Making C++ dataset helpers module failed, exiting.")
|
86 |
+
import sys
|
87 |
+
sys.exit(1)
|
88 |
+
|
89 |
+
|
90 |
+
def get_a_and_b_segments(sample, np_rng):
|
91 |
+
"""Divide sample into a and b segments."""
|
92 |
+
|
93 |
+
# Number of sentences in the sample.
|
94 |
+
n_sentences = len(sample)
|
95 |
+
# Make sure we always have two sentences.
|
96 |
+
assert n_sentences > 1, 'make sure each sample has at least two sentences.'
|
97 |
+
|
98 |
+
# First part:
|
99 |
+
# `a_end` is how many sentences go into the `A`.
|
100 |
+
a_end = 1
|
101 |
+
if n_sentences >= 3:
|
102 |
+
# Note that randint in numpy is exclusive of the upper bound.
|
103 |
+
a_end = np_rng.randint(1, n_sentences)
|
104 |
+
tokens_a = []
|
105 |
+
for j in range(a_end):
|
106 |
+
tokens_a.extend(sample[j])
|
107 |
+
|
108 |
+
# Second part:
|
109 |
+
tokens_b = []
|
110 |
+
for j in range(a_end, n_sentences):
|
111 |
+
tokens_b.extend(sample[j])
|
112 |
+
|
113 |
+
# Random next:
|
114 |
+
is_next_random = False
|
115 |
+
if np_rng.random() < 0.5:
|
116 |
+
is_next_random = True
|
117 |
+
tokens_a, tokens_b = tokens_b, tokens_a
|
118 |
+
|
119 |
+
return tokens_a, tokens_b, is_next_random
|
120 |
+
|
121 |
+
|
122 |
+
def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
|
123 |
+
"""Truncates a pair of sequences to a maximum sequence length."""
|
124 |
+
# print(len_a, len_b, max_num_tokens)
|
125 |
+
assert len_a > 0
|
126 |
+
if len_a + len_b <= max_num_tokens:
|
127 |
+
return False
|
128 |
+
while len_a + len_b > max_num_tokens:
|
129 |
+
if len_a > len_b:
|
130 |
+
len_a -= 1
|
131 |
+
tokens = tokens_a
|
132 |
+
else:
|
133 |
+
len_b -= 1
|
134 |
+
tokens = tokens_b
|
135 |
+
if np_rng.random() < 0.5:
|
136 |
+
del tokens[0]
|
137 |
+
else:
|
138 |
+
tokens.pop()
|
139 |
+
return True
|
140 |
+
|
141 |
+
|
142 |
+
def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
|
143 |
+
"""Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""
|
144 |
+
|
145 |
+
tokens = []
|
146 |
+
tokentypes = []
|
147 |
+
# [CLS].
|
148 |
+
tokens.append(cls_id)
|
149 |
+
tokentypes.append(0)
|
150 |
+
# Segment A.
|
151 |
+
for token in tokens_a:
|
152 |
+
tokens.append(token)
|
153 |
+
tokentypes.append(0)
|
154 |
+
# [SEP].
|
155 |
+
tokens.append(sep_id)
|
156 |
+
tokentypes.append(0)
|
157 |
+
# Segment B.
|
158 |
+
for token in tokens_b:
|
159 |
+
tokens.append(token)
|
160 |
+
tokentypes.append(1)
|
161 |
+
if tokens_b:
|
162 |
+
# [SEP].
|
163 |
+
tokens.append(sep_id)
|
164 |
+
tokentypes.append(1)
|
165 |
+
|
166 |
+
return tokens, tokentypes
|
167 |
+
|
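The resulting packing is the standard BERT layout: [CLS] A [SEP] B [SEP], with token type 0 up to and including the first [SEP] and type 1 for segment B and its trailing [SEP]. A tiny sketch with made-up ids:

# Hypothetical ids: cls_id=101, sep_id=102, segment A = [11, 12], segment B = [21].
tokens, tokentypes = create_tokens_and_tokentypes([11, 12], [21], cls_id=101, sep_id=102)
print(tokens)      # [101, 11, 12, 102, 21, 102]
print(tokentypes)  # [0,   0,  0,  0,   1,  1]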
168 |
+
|
169 |
+
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
|
170 |
+
["index", "label"])
|
171 |
+
|
172 |
+
|
173 |
+
def is_start_piece(piece):
|
174 |
+
"""Check if the current word piece is the starting piece (BERT)."""
|
175 |
+
# When a word has been split into
|
176 |
+
# WordPieces, the first token does not have any marker and any subsequent
|
177 |
+
# tokens are prefixed with ##. So whenever we see the ## token, we
|
178 |
+
# append it to the previous set of word indexes.
|
179 |
+
return not piece.startswith("##")
|
180 |
+
|
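For example, a plain piece starts a word while a ## piece continues one:

print(is_start_piece("play"))   # True  -- begins a new word
print(is_start_piece("##ing"))  # False -- continuation piece, grouped with the previous word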
181 |
+
|
182 |
+
def create_masked_lm_predictions(tokens,
|
183 |
+
vocab_id_list, vocab_id_to_token_dict,
|
184 |
+
masked_lm_prob,
|
185 |
+
cls_id, sep_id, mask_id,
|
186 |
+
max_predictions_per_seq,
|
187 |
+
np_rng,
|
188 |
+
tokenizer,
|
189 |
+
max_ngrams=3,
|
190 |
+
do_whole_word_mask=True,
|
191 |
+
favor_longer_ngram=False,
|
192 |
+
do_permutation=False,
|
193 |
+
geometric_dist=False,
|
194 |
+
masking_style="bert",
|
195 |
+
zh_tokenizer=None):
|
196 |
+
"""Creates the predictions for the masked LM objective.
|
197 |
+
Note: Tokens here are vocab ids and not text tokens."""
|
198 |
+
|
199 |
+
cand_indexes = []
|
200 |
+
# Note(mingdachen): We create a list for recording if the piece is
|
201 |
+
# the starting piece of current token, where 1 means true, so that
|
202 |
+
# on-the-fly whole word masking is possible.
|
203 |
+
token_boundary = [0] * len(tokens)
|
204 |
+
|
205 |
+
# If no Chinese tokenizer is given, fall back to the plain ## word-piece rule
|
206 |
+
if zh_tokenizer is None:
|
207 |
+
for (i, token) in enumerate(tokens):
|
208 |
+
if token == cls_id or token == sep_id:
|
209 |
+
token_boundary[i] = 1
|
210 |
+
continue
|
211 |
+
# Whole Word Masking means that if we mask all of the wordpieces
|
212 |
+
# corresponding to an original word.
|
213 |
+
#
|
214 |
+
# Note that Whole Word Masking does *not* change the training code
|
215 |
+
# at all -- we still predict each WordPiece independently, softmaxed
|
216 |
+
# over the entire vocabulary.
|
217 |
+
if (do_whole_word_mask and len(cand_indexes) >= 1 and
|
218 |
+
not is_start_piece(vocab_id_to_token_dict[token])):
|
219 |
+
cand_indexes[-1].append(i)
|
220 |
+
else:
|
221 |
+
cand_indexes.append([i])
|
222 |
+
if is_start_piece(vocab_id_to_token_dict[token]):
|
223 |
+
token_boundary[i] = 1
|
224 |
+
else:
|
225 |
+
# If a Chinese tokenizer is given, segment the text first and then decide the word boundaries
|
226 |
+
# Recover the raw text with CLS/SEP removed
|
227 |
+
raw_tokens = []
|
228 |
+
for t in tokens:
|
229 |
+
if t != cls_id and t != sep_id:
|
230 |
+
raw_tokens.append(t)
|
231 |
+
raw_tokens = [vocab_id_to_token_dict[i] for i in raw_tokens]
|
232 |
+
# Segment the text, then record for each starting character the length of the longest word beginning with it
|
233 |
+
word_list = set(zh_tokenizer(''.join(raw_tokens), HMM=True))
|
234 |
+
word_length_dict = {}
|
235 |
+
for w in word_list:
|
236 |
+
if len(w) < 1:
|
237 |
+
continue
|
238 |
+
if w[0] not in word_length_dict:
|
239 |
+
word_length_dict[w[0]] = len(w)
|
240 |
+
elif word_length_dict[w[0]] < len(w):
|
241 |
+
word_length_dict[w[0]] = len(w)
|
242 |
+
i = 0
|
243 |
+
# Look tokens up against the segmented word list
|
244 |
+
while i < len(tokens):
|
245 |
+
token_id = tokens[i]
|
246 |
+
token = vocab_id_to_token_dict[token_id]
|
247 |
+
if len(token) == 0 or token_id == cls_id or token_id == sep_id:
|
248 |
+
token_boundary[i] = 1
|
249 |
+
i += 1
|
250 |
+
continue
|
251 |
+
word_max_length = 1
|
252 |
+
if token[0] in word_length_dict:
|
253 |
+
word_max_length = word_length_dict[token[0]]
|
254 |
+
j = 0
|
255 |
+
word = ''
|
256 |
+
word_end = i+1
|
257 |
+
# Backward compatible with the old ## style: if the following pieces start with ##, merge them into the preceding piece as one word
|
258 |
+
old_style = False
|
259 |
+
while word_end < len(tokens) and vocab_id_to_token_dict[tokens[word_end]].startswith('##'):
|
260 |
+
old_style = True
|
261 |
+
word_end += 1
|
262 |
+
if not old_style:
|
263 |
+
while j < word_max_length and i+j < len(tokens):
|
264 |
+
cur_token = tokens[i+j]
|
265 |
+
word += vocab_id_to_token_dict[cur_token]
|
266 |
+
j += 1
|
267 |
+
if word in word_list:
|
268 |
+
word_end = i+j
|
269 |
+
cand_indexes.append([p for p in range(i, word_end)])
|
270 |
+
token_boundary[i] = 1
|
271 |
+
i = word_end
|
272 |
+
|
273 |
+
output_tokens = list(tokens)
|
274 |
+
# added by ganruyi
|
275 |
+
if masking_style == 'bert-cn-wwm':
|
276 |
+
# if the token is Chinese,
|
277 |
+
# then try to remove "##" which is added previously
|
278 |
+
new_token_ids = []
|
279 |
+
for token_id in output_tokens:
|
280 |
+
token = tokenizer.convert_ids_to_tokens([token_id])[0]
|
281 |
+
if len(re.findall('##[\u4E00-\u9FA5]', token)) > 0:
|
282 |
+
token = token[2:]
|
283 |
+
new_token_id = tokenizer.convert_tokens_to_ids([token])[
|
284 |
+
0]
|
285 |
+
new_token_ids.append(new_token_id)
|
286 |
+
output_tokens = new_token_ids
|
287 |
+
|
288 |
+
masked_lm_positions = []
|
289 |
+
masked_lm_labels = []
|
290 |
+
|
291 |
+
if masked_lm_prob == 0:
|
292 |
+
return (output_tokens, masked_lm_positions,
|
293 |
+
masked_lm_labels, token_boundary)
|
294 |
+
|
295 |
+
num_to_predict = min(max_predictions_per_seq,
|
296 |
+
max(1, int(round(len(tokens) * masked_lm_prob))))
|
297 |
+
|
298 |
+
ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
|
299 |
+
if not geometric_dist:
|
300 |
+
# Note(mingdachen):
|
301 |
+
# By default, we set the probabilities to favor shorter ngram sequences.
|
302 |
+
pvals = 1. / np.arange(1, max_ngrams + 1)
|
303 |
+
pvals /= pvals.sum(keepdims=True)
|
304 |
+
if favor_longer_ngram:
|
305 |
+
pvals = pvals[::-1]
|
306 |
+
# Build the ngram indexes: for each word, record the candidate ngrams starting at it
|
307 |
+
ngram_indexes = []
|
308 |
+
for idx in range(len(cand_indexes)):
|
309 |
+
ngram_index = []
|
310 |
+
for n in ngrams:
|
311 |
+
ngram_index.append(cand_indexes[idx:idx + n])
|
312 |
+
ngram_indexes.append(ngram_index)
|
313 |
+
|
314 |
+
np_rng.shuffle(ngram_indexes)
|
315 |
+
|
316 |
+
(masked_lms, masked_spans) = ([], [])
|
317 |
+
covered_indexes = set()
|
318 |
+
for cand_index_set in ngram_indexes:
|
319 |
+
if len(masked_lms) >= num_to_predict:
|
320 |
+
break
|
321 |
+
if not cand_index_set:
|
322 |
+
continue
|
323 |
+
# Note(mingdachen):
|
324 |
+
# Skip current piece if they are covered in lm masking or previous ngrams.
|
325 |
+
for index_set in cand_index_set[0]:
|
326 |
+
for index in index_set:
|
327 |
+
if index in covered_indexes:
|
328 |
+
continue
|
329 |
+
|
330 |
+
if not geometric_dist:
|
331 |
+
n = np_rng.choice(ngrams[:len(cand_index_set)],
|
332 |
+
p=pvals[:len(cand_index_set)] /
|
333 |
+
pvals[:len(cand_index_set)].sum(keepdims=True))
|
334 |
+
else:
|
335 |
+
# Sampling "n" from the geometric distribution and clipping it to
|
336 |
+
# the max_ngrams. Using p=0.2 default from the SpanBERT paper
|
337 |
+
# https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1)
|
338 |
+
n = min(np_rng.geometric(0.2), max_ngrams)
|
339 |
+
|
340 |
+
index_set = sum(cand_index_set[n - 1], [])
|
341 |
+
n -= 1
|
342 |
+
# Note(mingdachen):
|
343 |
+
# Repeatedly looking for a candidate that does not exceed the
|
344 |
+
# maximum number of predictions by trying shorter ngrams.
|
345 |
+
while len(masked_lms) + len(index_set) > num_to_predict:
|
346 |
+
if n == 0:
|
347 |
+
break
|
348 |
+
index_set = sum(cand_index_set[n - 1], [])
|
349 |
+
n -= 1
|
350 |
+
# If adding a whole-word mask would exceed the maximum number of
|
351 |
+
# predictions, then just skip this candidate.
|
352 |
+
if len(masked_lms) + len(index_set) > num_to_predict:
|
353 |
+
continue
|
354 |
+
is_any_index_covered = False
|
355 |
+
for index in index_set:
|
356 |
+
if index in covered_indexes:
|
357 |
+
is_any_index_covered = True
|
358 |
+
break
|
359 |
+
if is_any_index_covered:
|
360 |
+
continue
|
361 |
+
for index in index_set:
|
362 |
+
covered_indexes.add(index)
|
363 |
+
masked_token = None
|
364 |
+
if masking_style == "bert":
|
365 |
+
# 80% of the time, replace with [MASK]
|
366 |
+
if np_rng.random() < 0.8:
|
367 |
+
masked_token = mask_id
|
368 |
+
else:
|
369 |
+
# 10% of the time, keep original
|
370 |
+
if np_rng.random() < 0.5:
|
371 |
+
masked_token = tokens[index]
|
372 |
+
# 10% of the time, replace with random word
|
373 |
+
else:
|
374 |
+
masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]
|
375 |
+
elif masking_style == 'bert-cn-wwm':
|
376 |
+
# 80% of the time, replace with [MASK]
|
377 |
+
if np_rng.random() < 0.8:
|
378 |
+
masked_token = mask_id
|
379 |
+
else:
|
380 |
+
# 10% of the time, keep original
|
381 |
+
if np_rng.random() < 0.5:
|
382 |
+
# For Chinese whole-word masking, strip the ## prefix from the token
|
383 |
+
token_id = tokens[index]
|
384 |
+
token = tokenizer.convert_ids_to_tokens([token_id])[
|
385 |
+
0]
|
386 |
+
if len(re.findall('##[\u4E00-\u9FA5]', token)) > 0:
|
387 |
+
token = token[2:]
|
388 |
+
new_token_id = tokenizer.convert_tokens_to_ids([token])[
|
389 |
+
0]
|
390 |
+
masked_token = new_token_id
|
391 |
+
# 10% of the time, replace with random word
|
392 |
+
else:
|
393 |
+
masked_token = vocab_id_list[np_rng.randint(
|
394 |
+
0, len(vocab_id_list))]
|
395 |
+
elif masking_style == "t5":
|
396 |
+
masked_token = mask_id
|
397 |
+
else:
|
398 |
+
raise ValueError("invalid value of masking style")
|
399 |
+
|
400 |
+
output_tokens[index] = masked_token
|
401 |
+
masked_lms.append(MaskedLmInstance(
|
402 |
+
index=index, label=tokens[index]))
|
403 |
+
|
404 |
+
masked_spans.append(MaskedLmInstance(
|
405 |
+
index=index_set,
|
406 |
+
label=[tokens[index] for index in index_set]))
|
407 |
+
|
408 |
+
assert len(masked_lms) <= num_to_predict
|
409 |
+
np_rng.shuffle(ngram_indexes)
|
410 |
+
|
411 |
+
select_indexes = set()
|
412 |
+
if do_permutation:
|
413 |
+
for cand_index_set in ngram_indexes:
|
414 |
+
if len(select_indexes) >= num_to_predict:
|
415 |
+
break
|
416 |
+
if not cand_index_set:
|
417 |
+
continue
|
418 |
+
# Note(mingdachen):
|
419 |
+
# Skip current piece if they are covered in lm masking or previous ngrams.
|
420 |
+
for index_set in cand_index_set[0]:
|
421 |
+
for index in index_set:
|
422 |
+
if index in covered_indexes or index in select_indexes:
|
423 |
+
continue
|
424 |
+
|
425 |
+
n = np.random.choice(ngrams[:len(cand_index_set)],
|
426 |
+
p=pvals[:len(cand_index_set)] /
|
427 |
+
pvals[:len(cand_index_set)].sum(keepdims=True))
|
428 |
+
index_set = sum(cand_index_set[n - 1], [])
|
429 |
+
n -= 1
|
430 |
+
|
431 |
+
while len(select_indexes) + len(index_set) > num_to_predict:
|
432 |
+
if n == 0:
|
433 |
+
break
|
434 |
+
index_set = sum(cand_index_set[n - 1], [])
|
435 |
+
n -= 1
|
436 |
+
# If adding a whole-word mask would exceed the maximum number of
|
437 |
+
# predictions, then just skip this candidate.
|
438 |
+
if len(select_indexes) + len(index_set) > num_to_predict:
|
439 |
+
continue
|
440 |
+
is_any_index_covered = False
|
441 |
+
for index in index_set:
|
442 |
+
if index in covered_indexes or index in select_indexes:
|
443 |
+
is_any_index_covered = True
|
444 |
+
break
|
445 |
+
if is_any_index_covered:
|
446 |
+
continue
|
447 |
+
for index in index_set:
|
448 |
+
select_indexes.add(index)
|
449 |
+
assert len(select_indexes) <= num_to_predict
|
450 |
+
|
451 |
+
select_indexes = sorted(select_indexes)
|
452 |
+
permute_indexes = list(select_indexes)
|
453 |
+
np_rng.shuffle(permute_indexes)
|
454 |
+
orig_token = list(output_tokens)
|
455 |
+
|
456 |
+
for src_i, tgt_i in zip(select_indexes, permute_indexes):
|
457 |
+
output_tokens[src_i] = orig_token[tgt_i]
|
458 |
+
masked_lms.append(MaskedLmInstance(
|
459 |
+
index=src_i, label=orig_token[src_i]))
|
460 |
+
|
461 |
+
masked_lms = sorted(masked_lms, key=lambda x: x.index)
|
462 |
+
# Sort the spans by the index of the first span
|
463 |
+
masked_spans = sorted(masked_spans, key=lambda x: x.index[0])
|
464 |
+
|
465 |
+
for p in masked_lms:
|
466 |
+
masked_lm_positions.append(p.index)
|
467 |
+
masked_lm_labels.append(p.label)
|
468 |
+
return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary, masked_spans)
|
469 |
+
|
470 |
+
|
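The replacement rule in the "bert" branch above is the usual 80/10/10 scheme. A standalone sketch of just that decision (the vocabulary ids and seed are hypothetical):

import numpy as np

def choose_masked_token(original_id, mask_id, vocab_id_list, np_rng):
    # 80%: replace with [MASK]; 10%: keep the original id; 10%: random vocab id.
    if np_rng.random() < 0.8:
        return mask_id
    if np_rng.random() < 0.5:
        return original_id
    return vocab_id_list[np_rng.randint(0, len(vocab_id_list))]

np_rng = np.random.RandomState(1234)
vocab_id_list = list(range(5, 100))   # hypothetical vocabulary ids
picks = [choose_masked_token(42, mask_id=103, vocab_id_list=vocab_id_list, np_rng=np_rng)
         for _ in range(10)]
print(picks)  # mostly 103, occasionally 42 or a random id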
471 |
+
def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
|
472 |
+
masked_labels, pad_id, max_seq_length):
|
473 |
+
"""Pad sequences and convert them to numpy."""
|
474 |
+
|
475 |
+
# Some checks.
|
476 |
+
num_tokens = len(tokens)
|
477 |
+
padding_length = max_seq_length - num_tokens
|
478 |
+
assert padding_length >= 0
|
479 |
+
assert len(tokentypes) == num_tokens
|
480 |
+
assert len(masked_positions) == len(masked_labels)
|
481 |
+
|
482 |
+
# Tokens and token types.
|
483 |
+
filler = [pad_id] * padding_length
|
484 |
+
tokens_np = np.array(tokens + filler, dtype=np.int64)
|
485 |
+
tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
|
486 |
+
|
487 |
+
# Padding mask.
|
488 |
+
padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
|
489 |
+
dtype=np.int64)
|
490 |
+
|
491 |
+
# Labels and loss mask.
|
492 |
+
labels = [-1] * max_seq_length
|
493 |
+
loss_mask = [0] * max_seq_length
|
494 |
+
for i in range(len(masked_positions)):
|
495 |
+
assert masked_positions[i] < num_tokens
|
496 |
+
labels[masked_positions[i]] = masked_labels[i]
|
497 |
+
loss_mask[masked_positions[i]] = 1
|
498 |
+
labels_np = np.array(labels, dtype=np.int64)
|
499 |
+
loss_mask_np = np.array(loss_mask, dtype=np.int64)
|
500 |
+
|
501 |
+
return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
|
502 |
+
|
503 |
+
|
504 |
+
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
|
505 |
+
train_valid_test_num_samples,
|
506 |
+
max_seq_length,
|
507 |
+
masked_lm_prob, short_seq_prob, seed,
|
508 |
+
tokenizer,
|
509 |
+
skip_warmup, binary_head=False,
|
510 |
+
max_seq_length_dec=None,
|
511 |
+
dataset_type='standard_bert',
|
512 |
+
zh_tokenizer=None,
|
513 |
+
span=None):
|
514 |
+
|
515 |
+
if len(data_prefix) == 1:
|
516 |
+
return _build_train_valid_test_datasets(data_prefix[0],
|
517 |
+
data_impl, splits_string,
|
518 |
+
train_valid_test_num_samples,
|
519 |
+
max_seq_length, masked_lm_prob,
|
520 |
+
short_seq_prob, seed,
|
521 |
+
skip_warmup,
|
522 |
+
binary_head,
|
523 |
+
max_seq_length_dec,
|
524 |
+
tokenizer,
|
525 |
+
dataset_type=dataset_type,
|
526 |
+
zh_tokenizer=zh_tokenizer,
|
527 |
+
span=span)
|
528 |
+
# Blending dataset.
|
529 |
+
# Parse the values.
|
530 |
+
output = get_datasets_weights_and_num_samples(data_prefix,
|
531 |
+
train_valid_test_num_samples)
|
532 |
+
prefixes, weights, datasets_train_valid_test_num_samples = output
|
533 |
+
|
534 |
+
# Build individual datasets.
|
535 |
+
train_datasets = []
|
536 |
+
valid_datasets = []
|
537 |
+
test_datasets = []
|
538 |
+
for i in range(len(prefixes)):
|
539 |
+
train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
|
540 |
+
prefixes[i], data_impl, splits_string,
|
541 |
+
datasets_train_valid_test_num_samples[i],
|
542 |
+
max_seq_length, masked_lm_prob, short_seq_prob,
|
543 |
+
seed, skip_warmup, binary_head, max_seq_length_dec,
|
544 |
+
tokenizer, dataset_type=dataset_type, zh_tokenizer=zh_tokenizer)
|
545 |
+
if train_ds:
|
546 |
+
train_datasets.append(train_ds)
|
547 |
+
if valid_ds:
|
548 |
+
valid_datasets.append(valid_ds)
|
549 |
+
if test_ds:
|
550 |
+
test_datasets.append(test_ds)
|
551 |
+
|
552 |
+
# Blend.
|
553 |
+
blending_train_dataset = None
|
554 |
+
if train_datasets:
|
555 |
+
blending_train_dataset = BlendableDataset(train_datasets, weights)
|
556 |
+
blending_valid_dataset = None
|
557 |
+
if valid_datasets:
|
558 |
+
blending_valid_dataset = BlendableDataset(valid_datasets, weights)
|
559 |
+
blending_test_dataset = None
|
560 |
+
if test_datasets:
|
561 |
+
blending_test_dataset = BlendableDataset(test_datasets, weights)
|
562 |
+
|
563 |
+
return (blending_train_dataset, blending_valid_dataset,
|
564 |
+
blending_test_dataset)
|
565 |
+
|
566 |
+
|
567 |
+
def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
|
568 |
+
train_valid_test_num_samples,
|
569 |
+
max_seq_length,
|
570 |
+
masked_lm_prob, short_seq_prob, seed,
|
571 |
+
skip_warmup, binary_head,
|
572 |
+
max_seq_length_dec,
|
573 |
+
tokenizer,
|
574 |
+
dataset_type='standard_bert',
|
575 |
+
zh_tokenizer=None,
|
576 |
+
span=None):
|
577 |
+
|
578 |
+
if dataset_type not in DSET_TYPES:
|
579 |
+
raise ValueError("Invalid dataset_type: ", dataset_type)
|
580 |
+
|
581 |
+
# Indexed dataset.
|
582 |
+
indexed_dataset = get_indexed_dataset_(data_prefix,
|
583 |
+
data_impl,
|
584 |
+
skip_warmup)
|
585 |
+
|
586 |
+
# Get start and end indices of train/valid/test into doc-idx
|
587 |
+
# Note that doc-idx is designed to be num-docs + 1 so we can
|
588 |
+
# easily iterate over it.
|
589 |
+
total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
|
590 |
+
splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
|
591 |
+
|
592 |
+
# Print stats about the splits.
|
593 |
+
print_rank_0(' > dataset split:')
|
594 |
+
|
595 |
+
def print_split_stats(name, index):
|
596 |
+
print_rank_0(' {}:'.format(name))
|
597 |
+
print_rank_0(' document indices in [{}, {}) total of {} '
|
598 |
+
'documents'.format(splits[index], splits[index + 1],
|
599 |
+
splits[index + 1] - splits[index]))
|
600 |
+
start_index = indexed_dataset.doc_idx[splits[index]]
|
601 |
+
end_index = indexed_dataset.doc_idx[splits[index + 1]]
|
602 |
+
print_rank_0(' sentence indices in [{}, {}) total of {} '
|
603 |
+
'sentences'.format(start_index, end_index,
|
604 |
+
end_index - start_index))
|
605 |
+
print_split_stats('train', 0)
|
606 |
+
print_split_stats('validation', 1)
|
607 |
+
print_split_stats('test', 2)
|
608 |
+
|
609 |
+
def build_dataset(index, name):
|
610 |
+
from fengshen.data.megatron_dataloader.bert_dataset import BertDataset
|
611 |
+
from fengshen.data.megatron_dataloader.bart_dataset import BartDataset
|
612 |
+
from fengshen.data.megatron_dataloader.cocolm_dataset import COCOLMDataset
|
613 |
+
dataset = None
|
614 |
+
if splits[index + 1] > splits[index]:
|
615 |
+
# Get the pointer to the original doc-idx so we can set it later.
|
616 |
+
doc_idx_ptr = indexed_dataset.get_doc_idx()
|
617 |
+
# Slice the doc-idx
|
618 |
+
start_index = splits[index]
|
619 |
+
# Add +1 so we can index into the dataset to get the upper bound.
|
620 |
+
end_index = splits[index + 1] + 1
|
621 |
+
# New doc_idx view.
|
622 |
+
indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
|
623 |
+
# Build the dataset accordingly.
|
624 |
+
kwargs = dict(
|
625 |
+
name=name,
|
626 |
+
data_prefix=data_prefix,
|
627 |
+
num_epochs=None,
|
628 |
+
max_num_samples=train_valid_test_num_samples[index],
|
629 |
+
max_seq_length=max_seq_length,
|
630 |
+
seed=seed,
|
631 |
+
)
|
632 |
+
|
633 |
+
if dataset_type == DSET_TYPE_BERT or dataset_type == DSET_TYPE_BERT_CN_WWM:
|
634 |
+
dataset = BertDataset(
|
635 |
+
indexed_dataset=indexed_dataset,
|
636 |
+
masked_lm_prob=masked_lm_prob,
|
637 |
+
short_seq_prob=short_seq_prob,
|
638 |
+
binary_head=binary_head,
|
639 |
+
# Extra argument to distinguish bert from bert-cn-wwm
|
640 |
+
tokenizer=tokenizer,
|
641 |
+
masking_style='bert' if dataset_type == DSET_TYPE_BERT else 'bert-cn-wwm',
|
642 |
+
**kwargs
|
643 |
+
)
|
644 |
+
elif dataset_type == DSET_TYPE_BART:
|
645 |
+
dataset = BartDataset(
|
646 |
+
indexed_dataset=indexed_dataset,
|
647 |
+
masked_lm_prob=masked_lm_prob,
|
648 |
+
short_seq_prob=short_seq_prob,
|
649 |
+
tokenizer=tokenizer,
|
650 |
+
zh_tokenizer=zh_tokenizer,
|
651 |
+
**kwargs
|
652 |
+
)
|
653 |
+
elif dataset_type == DSET_TYPE_COCOLM:
|
654 |
+
dataset = COCOLMDataset(
|
655 |
+
indexed_dataset=indexed_dataset,
|
656 |
+
masked_lm_prob=masked_lm_prob,
|
657 |
+
short_seq_prob=short_seq_prob,
|
658 |
+
tokenizer=tokenizer,
|
659 |
+
masking_style='bert',
|
660 |
+
span=span,
|
661 |
+
**kwargs
|
662 |
+
)
|
663 |
+
else:
|
664 |
+
raise NotImplementedError(
|
665 |
+
"Dataset type not fully implemented.")
|
666 |
+
|
667 |
+
# Set the original pointer so dataset remains the main dataset.
|
668 |
+
indexed_dataset.set_doc_idx(doc_idx_ptr)
|
669 |
+
# Checks.
|
670 |
+
assert indexed_dataset.doc_idx[0] == 0
|
671 |
+
assert indexed_dataset.doc_idx.shape[0] == \
|
672 |
+
(total_num_of_documents + 1)
|
673 |
+
return dataset
|
674 |
+
|
675 |
+
train_dataset = build_dataset(0, 'train')
|
676 |
+
valid_dataset = build_dataset(1, 'valid')
|
677 |
+
test_dataset = build_dataset(2, 'test')
|
678 |
+
|
679 |
+
return (train_dataset, valid_dataset, test_dataset)
|
680 |
+
|
681 |
+
|
682 |
+
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
|
683 |
+
|
684 |
+
print_rank_0(' > building dataset index ...')
|
685 |
+
|
686 |
+
start_time = time.time()
|
687 |
+
indexed_dataset = make_indexed_dataset(data_prefix,
|
688 |
+
data_impl,
|
689 |
+
skip_warmup)
|
690 |
+
assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
|
691 |
+
print_rank_0(' > finished creating indexed dataset in {:4f} '
|
692 |
+
'seconds'.format(time.time() - start_time))
|
693 |
+
|
694 |
+
print_rank_0(' > indexed dataset stats:')
|
695 |
+
print_rank_0(' number of documents: {}'.format(
|
696 |
+
indexed_dataset.doc_idx.shape[0] - 1))
|
697 |
+
print_rank_0(' number of sentences: {}'.format(
|
698 |
+
indexed_dataset.sizes.shape[0]))
|
699 |
+
|
700 |
+
return indexed_dataset
|
701 |
+
|
702 |
+
|
703 |
+
def get_train_valid_test_split_(splits_string, size):
|
704 |
+
""" Get dataset splits from comma or '/' separated string list."""
|
705 |
+
|
706 |
+
splits = []
|
707 |
+
if splits_string.find(',') != -1:
|
708 |
+
splits = [float(s) for s in splits_string.split(',')]
|
709 |
+
elif splits_string.find('/') != -1:
|
710 |
+
splits = [float(s) for s in splits_string.split('/')]
|
711 |
+
else:
|
712 |
+
splits = [float(splits_string)]
|
713 |
+
while len(splits) < 3:
|
714 |
+
splits.append(0.)
|
715 |
+
splits = splits[:3]
|
716 |
+
splits_sum = sum(splits)
|
717 |
+
assert splits_sum > 0.0
|
718 |
+
splits = [split / splits_sum for split in splits]
|
719 |
+
splits_index = [0]
|
720 |
+
for index, split in enumerate(splits):
|
721 |
+
splits_index.append(splits_index[index] +
|
722 |
+
int(round(split * float(size))))
|
723 |
+
diff = splits_index[-1] - size
|
724 |
+
for index in range(1, len(splits_index)):
|
725 |
+
splits_index[index] -= diff
|
726 |
+
assert len(splits_index) == 4
|
727 |
+
assert splits_index[-1] == size
|
728 |
+
return splits_index
|
729 |
+
|
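For example, the Megatron-style split string '949,50,1' applied to 10000 documents yields cumulative document boundaries rather than ratios (assuming the function is in scope):

splits_index = get_train_valid_test_split_('949,50,1', size=10000)
print(splits_index)  # [0, 9490, 9990, 10000] -> train/valid/test document ranges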
730 |
+
|
731 |
+
def get_samples_mapping(indexed_dataset,
|
732 |
+
data_prefix,
|
733 |
+
num_epochs,
|
734 |
+
max_num_samples,
|
735 |
+
max_seq_length,
|
736 |
+
short_seq_prob,
|
737 |
+
seed,
|
738 |
+
name,
|
739 |
+
binary_head):
|
740 |
+
"""Get a list that maps a sample index to a starting
|
741 |
+
sentence index, end sentence index, and length"""
|
742 |
+
|
743 |
+
if not num_epochs:
|
744 |
+
if not max_num_samples:
|
745 |
+
raise ValueError("Need to specify either max_num_samples "
|
746 |
+
"or num_epochs")
|
747 |
+
num_epochs = np.iinfo(np.int32).max - 1
|
748 |
+
if not max_num_samples:
|
749 |
+
max_num_samples = np.iinfo(np.int64).max - 1
|
750 |
+
|
751 |
+
# Filename of the index mapping
|
752 |
+
indexmap_filename = data_prefix
|
753 |
+
indexmap_filename += '_{}_indexmap'.format(name)
|
754 |
+
if num_epochs != (np.iinfo(np.int32).max - 1):
|
755 |
+
indexmap_filename += '_{}ep'.format(num_epochs)
|
756 |
+
if max_num_samples != (np.iinfo(np.int64).max - 1):
|
757 |
+
indexmap_filename += '_{}mns'.format(max_num_samples)
|
758 |
+
indexmap_filename += '_{}msl'.format(max_seq_length)
|
759 |
+
indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
|
760 |
+
indexmap_filename += '_{}s'.format(seed)
|
761 |
+
indexmap_filename += '.npy'
|
762 |
+
|
763 |
+
# This should be a barrier but nccl barrier assumes
|
764 |
+
# device_index=rank which is not the case for model
|
765 |
+
# parallel case
|
766 |
+
# ganruyi comment
|
767 |
+
# counts = torch.cuda.LongTensor([1])
|
768 |
+
# torch.distributed.all_reduce(
|
769 |
+
# counts, group=mpu.get_data_parallel_group())
|
770 |
+
# torch.distributed.all_reduce(
|
771 |
+
# counts, group=mpu.get_pipeline_model_parallel_group())
|
772 |
+
# assert counts[0].item() == (
|
773 |
+
# torch.distributed.get_world_size() //
|
774 |
+
# torch.distributed.get_world_size(
|
775 |
+
# group=mpu.get_tensor_model_parallel_group()))
|
776 |
+
|
777 |
+
# Load indexed dataset.
|
778 |
+
print_rank_0(' > loading indexed mapping from {}'.format(
|
779 |
+
indexmap_filename))
|
780 |
+
start_time = time.time()
|
781 |
+
samples_mapping = np.load(
|
782 |
+
indexmap_filename, allow_pickle=True, mmap_mode='r')
|
783 |
+
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
|
784 |
+
time.time() - start_time))
|
785 |
+
print_rank_0(' total number of samples: {}'.format(
|
786 |
+
samples_mapping.shape[0]))
|
787 |
+
|
788 |
+
return samples_mapping
|
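The cache filename encodes the sampling parameters, so a precomputed mapping is only reused when all of them match. A sketch of the name produced for hypothetical arguments (num_epochs left unset, so the ep tag is skipped):

# Hypothetical arguments: data_prefix='wudao_text_document', name='train',
# max_num_samples=10000, max_seq_length=512, short_seq_prob=0.1, seed=1234.
fname = ('wudao_text_document' + '_train_indexmap' + '_10000mns'
         + '_512msl' + '_{:0.2f}ssp'.format(0.1) + '_1234s' + '.npy')
print(fname)  # wudao_text_document_train_indexmap_10000mns_512msl_0.10ssp_1234s.npy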
fengshen/data/megatron_dataloader/helpers.cpp
ADDED
@@ -0,0 +1,794 @@
1 |
+
/*
|
2 |
+
coding=utf-8
|
3 |
+
Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
|
5 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
you may not use this file except in compliance with the License.
|
7 |
+
You may obtain a copy of the License at
|
8 |
+
|
9 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
|
11 |
+
Unless required by applicable law or agreed to in writing, software
|
12 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
See the License for the specific language governing permissions and
|
15 |
+
limitations under the License.
|
16 |
+
*/
|
17 |
+
|
18 |
+
/* Helper methods for fast index mapping builds */
|
19 |
+
|
20 |
+
#include <algorithm>
|
21 |
+
#include <iostream>
|
22 |
+
#include <limits>
|
23 |
+
#include <math.h>
|
24 |
+
#include <stdexcept>
|
25 |
+
#include <pybind11/pybind11.h>
|
26 |
+
#include <pybind11/numpy.h>
|
27 |
+
#include <random>
|
28 |
+
|
29 |
+
namespace py = pybind11;
|
30 |
+
using namespace std;
|
31 |
+
|
32 |
+
const int32_t LONG_SENTENCE_LEN = 512;
|
33 |
+
|
34 |
+
void build_blending_indices(py::array_t<uint8_t> &dataset_index,
|
35 |
+
py::array_t<int64_t> &dataset_sample_index,
|
36 |
+
const py::array_t<double> &weights,
|
37 |
+
const int32_t num_datasets,
|
38 |
+
const int64_t size, const bool verbose)
|
39 |
+
{
|
40 |
+
/* Given multiple datasets and a weighting array, build samples
|
41 |
+
such that it follows those weights.*/
|
42 |
+
|
43 |
+
if (verbose)
|
44 |
+
{
|
45 |
+
std::cout << "> building indices for blendable datasets ..." << std::endl;
|
46 |
+
}
|
47 |
+
|
48 |
+
// Get the pointer access without the checks.
|
49 |
+
auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
|
50 |
+
auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
|
51 |
+
auto weights_ptr = weights.unchecked<1>();
|
52 |
+
|
53 |
+
// Initialize buffer for number of samples used for each dataset.
|
54 |
+
int64_t current_samples[num_datasets];
|
55 |
+
for (int64_t i = 0; i < num_datasets; ++i)
|
56 |
+
{
|
57 |
+
current_samples[i] = 0;
|
58 |
+
}
|
59 |
+
|
60 |
+
// For each sample:
|
61 |
+
for (int64_t sample_idx = 0; sample_idx < size; ++sample_idx)
|
62 |
+
{
|
63 |
+
|
64 |
+
// Determine where the max error in sampling is happening.
|
65 |
+
auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
|
66 |
+
int64_t max_error_index = 0;
|
67 |
+
double max_error = weights_ptr[0] * sample_idx_double -
|
68 |
+
static_cast<double>(current_samples[0]);
|
69 |
+
for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx)
|
70 |
+
{
|
71 |
+
double error = weights_ptr[dataset_idx] * sample_idx_double -
|
72 |
+
static_cast<double>(current_samples[dataset_idx]);
|
73 |
+
if (error > max_error)
|
74 |
+
{
|
75 |
+
max_error = error;
|
76 |
+
max_error_index = dataset_idx;
|
77 |
+
}
|
78 |
+
}
|
79 |
+
|
80 |
+
// Populate the indices.
|
81 |
+
dataset_index_ptr[sample_idx] = static_cast<uint8_t>(max_error_index);
|
82 |
+
dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];
|
83 |
+
|
84 |
+
// Update the total samples.
|
85 |
+
current_samples[max_error_index] += 1;
|
86 |
+
}
|
87 |
+
|
88 |
+
// print info
|
89 |
+
if (verbose)
|
90 |
+
{
|
91 |
+
std::cout << " > sample ratios:" << std::endl;
|
92 |
+
for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx)
|
93 |
+
{
|
94 |
+
auto ratio = static_cast<double>(current_samples[dataset_idx]) /
|
95 |
+
static_cast<double>(size);
|
96 |
+
std::cout << " dataset " << dataset_idx << ", input: " << weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;
|
97 |
+
}
|
98 |
+
}
|
99 |
+
}
|
100 |
+
|
101 |
+
py::array build_sample_idx(const py::array_t<int32_t> &sizes_,
|
102 |
+
const py::array_t<int32_t> &doc_idx_,
|
103 |
+
const int32_t seq_length,
|
104 |
+
const int32_t num_epochs,
|
105 |
+
const int64_t tokens_per_epoch)
|
106 |
+
{
|
107 |
+
/* Sample index (sample_idx) is used for gpt2 like dataset for which
|
108 |
+
the documents are flattened and the samples are built based on this
|
109 |
+
1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2]
|
110 |
+
where [..., 0] contains the index into `doc_idx` and [..., 1] is the
|
111 |
+
starting offset in that document.*/
|
112 |
+
|
113 |
+
// Consistency checks.
|
114 |
+
assert(seq_length > 1);
|
115 |
+
assert(num_epochs > 0);
|
116 |
+
assert(tokens_per_epoch > 1);
|
117 |
+
|
118 |
+
// Remove bound checks.
|
119 |
+
auto sizes = sizes_.unchecked<1>();
|
120 |
+
auto doc_idx = doc_idx_.unchecked<1>();
|
121 |
+
|
122 |
+
// Mapping and its length (1D).
|
123 |
+
int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length;
|
124 |
+
int32_t *sample_idx = new int32_t[2 * (num_samples + 1)];
|
125 |
+
|
126 |
+
cout << " using:" << endl
|
127 |
+
<< std::flush;
|
128 |
+
cout << " number of documents: " << doc_idx_.shape(0) / num_epochs << endl
|
129 |
+
<< std::flush;
|
130 |
+
cout << " number of epochs: " << num_epochs << endl
|
131 |
+
<< std::flush;
|
132 |
+
cout << " sequence length: " << seq_length << endl
|
133 |
+
<< std::flush;
|
134 |
+
cout << " total number of samples: " << num_samples << endl
|
135 |
+
<< std::flush;
|
136 |
+
|
137 |
+
// Index into sample_idx.
|
138 |
+
int64_t sample_index = 0;
|
139 |
+
// Index into doc_idx.
|
140 |
+
int64_t doc_idx_index = 0;
|
141 |
+
// Beginning offset for each document.
|
142 |
+
int32_t doc_offset = 0;
|
143 |
+
// Start with first document and no offset.
|
144 |
+
sample_idx[2 * sample_index] = doc_idx_index;
|
145 |
+
sample_idx[2 * sample_index + 1] = doc_offset;
|
146 |
+
++sample_index;
|
147 |
+
|
148 |
+
while (sample_index <= num_samples)
|
149 |
+
{
|
150 |
+
// Start with a fresh sequence.
|
151 |
+
int32_t remaining_seq_length = seq_length + 1;
|
152 |
+
while (remaining_seq_length != 0)
|
153 |
+
{
|
154 |
+
// Get the document length.
|
155 |
+
auto doc_id = doc_idx[doc_idx_index];
|
156 |
+
auto doc_length = sizes[doc_id] - doc_offset;
|
157 |
+
// And add it to the current sequence.
|
158 |
+
remaining_seq_length -= doc_length;
|
159 |
+
// If we have more than a full sequence, adjust offset and set
|
160 |
+
// remaining length to zero so we return from the while loop.
|
161 |
+
// Note that -1 here is for the same reason we have -1 in
|
162 |
+
// `_num_epochs` calculations.
|
163 |
+
if (remaining_seq_length <= 0)
|
164 |
+
{
|
165 |
+
doc_offset += (remaining_seq_length + doc_length - 1);
|
166 |
+
remaining_seq_length = 0;
|
167 |
+
}
|
168 |
+
else
|
169 |
+
{
|
170 |
+
// Otherwise, start from the beginning of the next document.
|
171 |
+
++doc_idx_index;
|
172 |
+
doc_offset = 0;
|
173 |
+
}
|
174 |
+
}
|
175 |
+
// Record the sequence.
|
176 |
+
sample_idx[2 * sample_index] = doc_idx_index;
|
177 |
+
sample_idx[2 * sample_index + 1] = doc_offset;
|
178 |
+
++sample_index;
|
179 |
+
}
|
180 |
+
|
181 |
+
// Method to deallocate memory.
|
182 |
+
py::capsule free_when_done(sample_idx, [](void *mem_)
|
183 |
+
{
|
184 |
+
int32_t *mem = reinterpret_cast<int32_t *>(mem_);
|
185 |
+
delete[] mem;
|
186 |
+
});
|
187 |
+
|
188 |
+
// Return the numpy array.
|
189 |
+
const auto byte_size = sizeof(int32_t);
|
190 |
+
return py::array(std::vector<int64_t>{num_samples + 1, 2}, // shape
|
191 |
+
{2 * byte_size, byte_size}, // C-style contiguous strides
|
192 |
+
sample_idx, // the data pointer
|
193 |
+
free_when_done); // numpy array references
|
194 |
+
}
|
195 |
+
|
196 |
+
inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
|
197 |
+
const int32_t max_length,
|
198 |
+
std::mt19937 &rand32_gen)
|
199 |
+
{
|
200 |
+
/* Training sample length. */
|
201 |
+
if (short_seq_ratio == 0)
|
202 |
+
{
|
203 |
+
return max_length;
|
204 |
+
}
|
205 |
+
const auto random_number = rand32_gen();
|
206 |
+
if ((random_number % short_seq_ratio) == 0)
|
207 |
+
{
|
208 |
+
return 2 + random_number % (max_length - 1);
|
209 |
+
}
|
210 |
+
return max_length;
|
211 |
+
}
|
212 |
+
|
213 |
+
template <typename DocIdx>
|
214 |
+
py::array build_mapping_impl(const py::array_t<int64_t> &docs_,
|
215 |
+
const py::array_t<int32_t> &sizes_,
|
216 |
+
const int32_t num_epochs,
|
217 |
+
const uint64_t max_num_samples,
|
218 |
+
const int32_t max_seq_length,
|
219 |
+
const double short_seq_prob,
|
220 |
+
const int32_t seed,
|
221 |
+
const bool verbose,
|
222 |
+
const int32_t min_num_sent)
|
223 |
+
{
|
224 |
+
/* Build a mapping of (start-index, end-index, sequence-length) where
|
225 |
+
start and end index are the indices of the sentences in the sample
|
226 |
+
and sequence-length is the target sequence length.
|
227 |
+
*/
|
228 |
+
|
229 |
+
// Consistency checks.
|
230 |
+
assert(num_epochs > 0);
|
231 |
+
assert(max_seq_length > 1);
|
232 |
+
assert(short_seq_prob >= 0.0);
|
233 |
+
assert(short_seq_prob <= 1.0);
|
234 |
+
assert(seed > 0);
|
235 |
+
|
236 |
+
// Remove bound checks.
|
237 |
+
auto docs = docs_.unchecked<1>();
|
238 |
+
auto sizes = sizes_.unchecked<1>();
|
239 |
+
|
240 |
+
// For efficiency, convert probability to ratio. Note: rand() generates int.
|
241 |
+
int32_t short_seq_ratio = 0;
|
242 |
+
if (short_seq_prob > 0)
|
243 |
+
{
|
244 |
+
short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
|
245 |
+
}
|
246 |
+
|
247 |
+
if (verbose)
|
248 |
+
{
|
249 |
+
const auto sent_start_index = docs[0];
|
250 |
+
const auto sent_end_index = docs[docs_.shape(0) - 1];
|
251 |
+
const auto num_sentences = sent_end_index - sent_start_index;
|
252 |
+
cout << " using:" << endl
|
253 |
+
<< std::flush;
|
254 |
+
cout << " number of documents: " << docs_.shape(0) - 1 << endl
|
255 |
+
<< std::flush;
|
256 |
+
cout << " sentences range: [" << sent_start_index << ", " << sent_end_index << ")" << endl
|
257 |
+
<< std::flush;
|
258 |
+
cout << " total number of sentences: " << num_sentences << endl
|
259 |
+
<< std::flush;
|
260 |
+
cout << " number of epochs: " << num_epochs << endl
|
261 |
+
<< std::flush;
|
262 |
+
cout << " maximum number of samples: " << max_num_samples << endl
|
263 |
+
<< std::flush;
|
264 |
+
cout << " maximum sequence length: " << max_seq_length << endl
|
265 |
+
<< std::flush;
|
266 |
+
cout << " short sequence probability: " << short_seq_prob << endl
|
267 |
+
<< std::flush;
|
268 |
+
cout << " short sequence ration (1/prob): " << short_seq_ratio << endl
|
269 |
+
<< std::flush;
|
270 |
+
cout << " seed: " << seed << endl
|
271 |
+
<< std::flush;
|
272 |
+
}
|
273 |
+
|
274 |
+
// Mapping and its length (1D).
|
275 |
+
int64_t num_samples = -1;
|
276 |
+
DocIdx *maps = NULL;
|
277 |
+
|
278 |
+
// Perform two iterations, in the first iteration get the size
|
279 |
+
// and allocate memory and in the second iteration populate the map.
|
280 |
+
bool second = false;
|
281 |
+
for (int32_t iteration = 0; iteration < 2; ++iteration)
|
282 |
+
{
|
283 |
+
|
284 |
+
// Set the seed so both iterations produce the same results.
|
285 |
+
std::mt19937 rand32_gen(seed);
|
286 |
+
|
287 |
+
// Set the flag on second iteration.
|
288 |
+
second = (iteration == 1);
|
289 |
+
|
290 |
+
// Counters:
|
291 |
+
uint64_t empty_docs = 0;
|
292 |
+
uint64_t one_sent_docs = 0;
|
293 |
+
uint64_t long_sent_docs = 0;
|
294 |
+
|
295 |
+
// Current map index.
|
296 |
+
uint64_t map_index = 0;
|
297 |
+
|
298 |
+
// For each epoch:
|
299 |
+
for (int32_t epoch = 0; epoch < num_epochs; ++epoch)
|
300 |
+
{
|
301 |
+
if (map_index >= max_num_samples)
|
302 |
+
{
|
303 |
+
if (verbose && (!second))
|
304 |
+
{
|
305 |
+
cout << " reached " << max_num_samples << " samples after "
|
306 |
+
<< epoch << " epochs ..." << endl
|
307 |
+
<< std::flush;
|
308 |
+
}
|
309 |
+
break;
|
310 |
+
}
|
311 |
+
// For each document:
|
312 |
+
for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc)
|
313 |
+
{
|
314 |
+
|
315 |
+
// Document sentences are in [sent_index_first, sent_index_last)
|
316 |
+
const auto sent_index_first = docs[doc];
|
317 |
+
const auto sent_index_last = docs[doc + 1];
|
318 |
+
|
319 |
+
// At the beginning of the document the previous index is the
|
320 |
+
// start index.
|
321 |
+
auto prev_start_index = sent_index_first;
|
322 |
+
|
323 |
+
// Remaining documents.
|
324 |
+
auto num_remain_sent = sent_index_last - sent_index_first;
|
325 |
+
|
326 |
+
// Some bookkeeping
|
327 |
+
if ((epoch == 0) && (!second))
|
328 |
+
{
|
329 |
+
if (num_remain_sent == 0)
|
330 |
+
{
|
331 |
+
++empty_docs;
|
332 |
+
}
|
333 |
+
if (num_remain_sent == 1)
|
334 |
+
{
|
335 |
+
++one_sent_docs;
|
336 |
+
}
|
337 |
+
}
|
338 |
+
|
339 |
+
// Detect documents with long sentences.
|
340 |
+
bool contains_long_sentence = false;
|
341 |
+
if (num_remain_sent > 1)
|
342 |
+
{
|
343 |
+
for (auto sent_index = sent_index_first;
|
344 |
+
sent_index < sent_index_last; ++sent_index)
|
345 |
+
{
|
346 |
+
if (sizes[sent_index] > LONG_SENTENCE_LEN)
|
347 |
+
{
|
348 |
+
if ((epoch == 0) && (!second))
|
349 |
+
{
|
350 |
+
++long_sent_docs;
|
351 |
+
}
|
352 |
+
contains_long_sentence = true;
|
353 |
+
break;
|
354 |
+
}
|
355 |
+
}
|
356 |
+
}
|
357 |
+
|
358 |
+
// If we have at least the required minimum number of sentences.
|
359 |
+
if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence))
|
360 |
+
{
|
361 |
+
|
362 |
+
// Set values.
|
363 |
+
auto seq_len = int32_t{0};
|
364 |
+
auto num_sent = int32_t{0};
|
365 |
+
auto target_seq_len = get_target_sample_len(short_seq_ratio,
|
366 |
+
max_seq_length,
|
367 |
+
rand32_gen);
|
368 |
+
|
369 |
+
// Loop through sentences.
|
370 |
+
for (auto sent_index = sent_index_first;
|
371 |
+
sent_index < sent_index_last; ++sent_index)
|
372 |
+
{
|
373 |
+
|
374 |
+
// Add the size and number of sentences.
|
375 |
+
seq_len += sizes[sent_index];
|
376 |
+
++num_sent;
|
377 |
+
--num_remain_sent;
|
378 |
+
|
379 |
+
// If we have reached the target length.
|
380 |
+
// and if not only one sentence is left in the document.
|
381 |
+
// and if we have at least two sentences.
|
382 |
+
// and if we have reached end of the document.
|
383 |
+
if (((seq_len >= target_seq_len) &&
|
384 |
+
(num_remain_sent > 1) &&
|
385 |
+
(num_sent >= min_num_sent)) ||
|
386 |
+
(num_remain_sent == 0))
|
387 |
+
{
|
388 |
+
|
389 |
+
// Check for overflow.
|
390 |
+
if ((3 * map_index + 2) >
|
391 |
+
std::numeric_limits<int64_t>::max())
|
392 |
+
{
|
393 |
+
cout << "number of samples exceeded maximum "
|
394 |
+
<< "allowed by type int64: "
|
395 |
+
<< std::numeric_limits<int64_t>::max()
|
396 |
+
<< endl;
|
397 |
+
throw std::overflow_error("Number of samples");
|
398 |
+
}
|
399 |
+
|
400 |
+
// Populate the map.
|
401 |
+
if (second)
|
402 |
+
{
|
403 |
+
const auto map_index_0 = 3 * map_index;
|
404 |
+
maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
|
405 |
+
maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
|
406 |
+
maps[map_index_0 + 2] = static_cast<DocIdx>(target_seq_len);
|
407 |
+
}
|
408 |
+
|
409 |
+
// Update indices / counters.
|
410 |
+
++map_index;
|
411 |
+
prev_start_index = sent_index + 1;
|
412 |
+
target_seq_len = get_target_sample_len(short_seq_ratio,
|
413 |
+
max_seq_length,
|
414 |
+
rand32_gen);
|
415 |
+
seq_len = 0;
|
416 |
+
num_sent = 0;
|
417 |
+
}
|
418 |
+
|
419 |
+
} // for (auto sent_index=sent_index_first; ...
|
420 |
+
} // if (num_remain_sent > 1) {
|
421 |
+
} // for (int doc=0; doc < num_docs; ++doc) {
|
422 |
+
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
|
423 |
+
|
424 |
+
if (!second)
|
425 |
+
{
|
426 |
+
if (verbose)
|
427 |
+
{
|
428 |
+
cout << " number of empty documents: " << empty_docs << endl
|
429 |
+
<< std::flush;
|
430 |
+
cout << " number of documents with one sentence: " << one_sent_docs << endl
|
431 |
+
<< std::flush;
|
432 |
+
cout << " number of documents with long sentences: " << long_sent_docs << endl
|
433 |
+
<< std::flush;
|
434 |
+
cout << " will create mapping for " << map_index << " samples" << endl
|
435 |
+
<< std::flush;
|
436 |
+
}
|
437 |
+
assert(maps == NULL);
|
438 |
+
assert(num_samples < 0);
|
439 |
+
maps = new DocIdx[3 * map_index];
|
440 |
+
num_samples = static_cast<int64_t>(map_index);
|
441 |
+
}
|
442 |
+
|
443 |
+
} // for (int iteration=0; iteration < 2; ++iteration) {
|
444 |
+
|
445 |
+
// Shuffle.
|
446 |
+
// We need a 64 bit random number generator as we might have more
|
447 |
+
// than 2 billion samples.
|
448 |
+
std::mt19937_64 rand64_gen(seed + 1);
|
449 |
+
for (auto i = (num_samples - 1); i > 0; --i)
|
450 |
+
{
|
451 |
+
const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
|
452 |
+
const auto i0 = 3 * i;
|
453 |
+
const auto j0 = 3 * j;
|
454 |
+
// Swap values.
|
455 |
+
swap(maps[i0], maps[j0]);
|
456 |
+
swap(maps[i0 + 1], maps[j0 + 1]);
|
457 |
+
swap(maps[i0 + 2], maps[j0 + 2]);
|
458 |
+
}
|
459 |
+
|
460 |
+
// Method to deallocate memory.
|
461 |
+
py::capsule free_when_done(maps, [](void *mem_)
|
462 |
+
{
|
463 |
+
DocIdx *mem = reinterpret_cast<DocIdx *>(mem_);
|
464 |
+
delete[] mem;
|
465 |
+
});
|
466 |
+
|
467 |
+
// Return the numpy array.
|
468 |
+
const auto byte_size = sizeof(DocIdx);
|
469 |
+
return py::array(std::vector<int64_t>{num_samples, 3}, // shape
|
470 |
+
{3 * byte_size, byte_size}, // C-style contiguous strides
|
471 |
+
maps, // the data pointer
|
472 |
+
free_when_done); // numpy array references
|
473 |
+
}
|
474 |
+
|
475 |
+
py::array build_mapping(const py::array_t<int64_t> &docs_,
|
476 |
+
const py::array_t<int> &sizes_,
|
477 |
+
const int num_epochs,
|
478 |
+
const uint64_t max_num_samples,
|
479 |
+
const int max_seq_length,
|
480 |
+
const double short_seq_prob,
|
481 |
+
const int seed,
|
482 |
+
const bool verbose,
|
483 |
+
const int32_t min_num_sent)
|
484 |
+
{
|
485 |
+
|
486 |
+
if (sizes_.size() > std::numeric_limits<uint32_t>::max())
|
487 |
+
{
|
488 |
+
if (verbose)
|
489 |
+
{
|
490 |
+
cout << " using uint64 for data mapping..." << endl
|
491 |
+
<< std::flush;
|
492 |
+
}
|
493 |
+
return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
|
494 |
+
max_num_samples, max_seq_length,
|
495 |
+
short_seq_prob, seed, verbose,
|
496 |
+
min_num_sent);
|
497 |
+
}
|
498 |
+
else
|
499 |
+
{
|
500 |
+
if (verbose)
|
501 |
+
{
|
502 |
+
cout << " using uint32 for data mapping..." << endl
|
503 |
+
<< std::flush;
|
504 |
+
}
|
505 |
+
return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
|
506 |
+
max_num_samples, max_seq_length,
|
507 |
+
short_seq_prob, seed, verbose,
|
508 |
+
min_num_sent);
|
509 |
+
}
|
510 |
+
}
|
511 |
+
|
512 |
+
template <typename DocIdx>
|
513 |
+
py::array build_blocks_mapping_impl(const py::array_t<int64_t> &docs_,
|
514 |
+
const py::array_t<int32_t> &sizes_,
|
515 |
+
const py::array_t<int32_t> &titles_sizes_,
|
516 |
+
const int32_t num_epochs,
|
517 |
+
const uint64_t max_num_samples,
|
518 |
+
const int32_t max_seq_length,
|
519 |
+
const int32_t seed,
|
520 |
+
const bool verbose,
|
521 |
+
const bool use_one_sent_blocks)
|
522 |
+
{
|
523 |
+
/* Build a mapping of (start-index, end-index, sequence-length) where
|
524 |
+
start and end index are the indices of the sentences in the sample
|
525 |
+
and sequence-length is the target sequence length.
|
526 |
+
*/
|
527 |
+
|
528 |
+
// Consistency checks.
|
529 |
+
assert(num_epochs > 0);
|
530 |
+
assert(max_seq_length > 1);
|
531 |
+
assert(seed > 0);
|
532 |
+
|
533 |
+
// Remove bound checks.
|
534 |
+
auto docs = docs_.unchecked<1>();
|
535 |
+
auto sizes = sizes_.unchecked<1>();
|
536 |
+
auto titles_sizes = titles_sizes_.unchecked<1>();
|
537 |
+
|
538 |
+
if (verbose)
|
539 |
+
{
|
540 |
+
const auto sent_start_index = docs[0];
|
541 |
+
const auto sent_end_index = docs[docs_.shape(0) - 1];
|
542 |
+
const auto num_sentences = sent_end_index - sent_start_index;
|
543 |
+
cout << " using:" << endl
|
544 |
+
<< std::flush;
|
545 |
+
cout << " number of documents: " << docs_.shape(0) - 1 << endl
|
546 |
+
<< std::flush;
|
547 |
+
cout << " sentences range: [" << sent_start_index << ", " << sent_end_index << ")" << endl
|
548 |
+
<< std::flush;
|
549 |
+
cout << " total number of sentences: " << num_sentences << endl
|
550 |
+
<< std::flush;
|
551 |
+
cout << " number of epochs: " << num_epochs << endl
|
552 |
+
<< std::flush;
|
553 |
+
cout << " maximum number of samples: " << max_num_samples << endl
|
554 |
+
<< std::flush;
|
555 |
+
cout << " maximum sequence length: " << max_seq_length << endl
|
556 |
+
<< std::flush;
|
557 |
+
cout << " seed: " << seed << endl
|
558 |
+
<< std::flush;
|
559 |
+
}
|
560 |
+
|
561 |
+
// Mapping and its length (1D).
|
562 |
+
int64_t num_samples = -1;
|
563 |
+
DocIdx *maps = NULL;
|
564 |
+
|
565 |
+
// Acceptable number of sentences per block.
|
566 |
+
int min_num_sent = 2;
|
567 |
+
if (use_one_sent_blocks)
|
568 |
+
{
|
569 |
+
min_num_sent = 1;
|
570 |
+
}
|
571 |
+
|
572 |
+
// Perform two iterations, in the first iteration get the size
|
573 |
+
// and allocate memory and in the second iteration populate the map.
|
574 |
+
bool second = false;
|
575 |
+
for (int32_t iteration = 0; iteration < 2; ++iteration)
|
576 |
+
{
|
577 |
+
|
578 |
+
// Set the flag on second iteration.
|
579 |
+
second = (iteration == 1);
|
580 |
+
|
581 |
+
// Current map index.
|
582 |
+
uint64_t map_index = 0;
|
583 |
+
|
584 |
+
uint64_t empty_docs = 0;
|
585 |
+
uint64_t one_sent_docs = 0;
|
586 |
+
uint64_t long_sent_docs = 0;
|
587 |
+
// For each epoch:
|
588 |
+
for (int32_t epoch = 0; epoch < num_epochs; ++epoch)
|
589 |
+
{
|
590 |
+
// assign every block a unique id
|
591 |
+
int32_t block_id = 0;
|
592 |
+
|
593 |
+
if (map_index >= max_num_samples)
|
594 |
+
{
|
595 |
+
if (verbose && (!second))
|
596 |
+
{
|
597 |
+
cout << " reached " << max_num_samples << " samples after "
|
598 |
+
<< epoch << " epochs ..." << endl
|
599 |
+
<< std::flush;
|
600 |
+
}
|
601 |
+
break;
|
602 |
+
}
|
603 |
+
// For each document:
|
604 |
+
for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc)
|
605 |
+
{
|
606 |
+
|
607 |
+
// Document sentences are in [sent_index_first, sent_index_last)
|
608 |
+
const auto sent_index_first = docs[doc];
|
609 |
+
const auto sent_index_last = docs[doc + 1];
|
610 |
+
const auto target_seq_len = max_seq_length - titles_sizes[doc];
|
611 |
+
|
612 |
+
            // At the beginning of the document previous index is the
|
613 |
+
// start index.
|
614 |
+
auto prev_start_index = sent_index_first;
|
615 |
+
|
616 |
+
// Remaining documents.
|
617 |
+
auto num_remain_sent = sent_index_last - sent_index_first;
|
618 |
+
|
619 |
+
// Some bookkeeping
|
620 |
+
if ((epoch == 0) && (!second))
|
621 |
+
{
|
622 |
+
if (num_remain_sent == 0)
|
623 |
+
{
|
624 |
+
++empty_docs;
|
625 |
+
}
|
626 |
+
if (num_remain_sent == 1)
|
627 |
+
{
|
628 |
+
++one_sent_docs;
|
629 |
+
}
|
630 |
+
}
|
631 |
+
// Detect documents with long sentences.
|
632 |
+
bool contains_long_sentence = false;
|
633 |
+
if (num_remain_sent >= min_num_sent)
|
634 |
+
{
|
635 |
+
for (auto sent_index = sent_index_first;
|
636 |
+
sent_index < sent_index_last; ++sent_index)
|
637 |
+
{
|
638 |
+
if (sizes[sent_index] > LONG_SENTENCE_LEN)
|
639 |
+
{
|
640 |
+
if ((epoch == 0) && (!second))
|
641 |
+
{
|
642 |
+
++long_sent_docs;
|
643 |
+
}
|
644 |
+
contains_long_sentence = true;
|
645 |
+
break;
|
646 |
+
}
|
647 |
+
}
|
648 |
+
}
|
649 |
+
// If we have enough sentences and no long sentences.
|
650 |
+
if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence))
|
651 |
+
{
|
652 |
+
|
653 |
+
// Set values.
|
654 |
+
auto seq_len = int32_t{0};
|
655 |
+
auto num_sent = int32_t{0};
|
656 |
+
|
657 |
+
// Loop through sentences.
|
658 |
+
for (auto sent_index = sent_index_first;
|
659 |
+
sent_index < sent_index_last; ++sent_index)
|
660 |
+
{
|
661 |
+
|
662 |
+
// Add the size and number of sentences.
|
663 |
+
seq_len += sizes[sent_index];
|
664 |
+
++num_sent;
|
665 |
+
--num_remain_sent;
|
666 |
+
|
667 |
+
// If we have reached the target length.
|
668 |
+
// and there are an acceptable number of sentences left
|
669 |
+
// and if we have at least the minimum number of sentences.
|
670 |
+
// or if we have reached end of the document.
|
671 |
+
if (((seq_len >= target_seq_len) &&
|
672 |
+
(num_remain_sent >= min_num_sent) &&
|
673 |
+
(num_sent >= min_num_sent)) ||
|
674 |
+
(num_remain_sent == 0))
|
675 |
+
{
|
676 |
+
|
677 |
+
// Populate the map.
|
678 |
+
if (second)
|
679 |
+
{
|
680 |
+
const auto map_index_0 = 4 * map_index;
|
681 |
+
// Each sample has 4 items: the starting sentence index, ending sentence index,
|
682 |
+
// the index of the document from which the block comes (used for fetching titles)
|
683 |
+
// and the unique id of the block (used for creating block indexes)
|
684 |
+
|
685 |
+
maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
|
686 |
+
maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
|
687 |
+
maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
|
688 |
+
maps[map_index_0 + 3] = static_cast<DocIdx>(block_id);
|
689 |
+
}
|
690 |
+
|
691 |
+
// Update indices / counters.
|
692 |
+
++map_index;
|
693 |
+
++block_id;
|
694 |
+
prev_start_index = sent_index + 1;
|
695 |
+
seq_len = 0;
|
696 |
+
num_sent = 0;
|
697 |
+
}
|
698 |
+
} // for (auto sent_index=sent_index_first; ...
|
699 |
+
} // if (num_remain_sent > 1) {
|
700 |
+
} // for (int doc=0; doc < num_docs; ++doc) {
|
701 |
+
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
|
702 |
+
|
703 |
+
if (!second)
|
704 |
+
{
|
705 |
+
if (verbose)
|
706 |
+
{
|
707 |
+
cout << " number of empty documents: " << empty_docs << endl
|
708 |
+
<< std::flush;
|
709 |
+
cout << " number of documents with one sentence: " << one_sent_docs << endl
|
710 |
+
<< std::flush;
|
711 |
+
cout << " number of documents with long sentences: " << long_sent_docs << endl
|
712 |
+
<< std::flush;
|
713 |
+
cout << " will create mapping for " << map_index << " samples" << endl
|
714 |
+
<< std::flush;
|
715 |
+
}
|
716 |
+
assert(maps == NULL);
|
717 |
+
assert(num_samples < 0);
|
718 |
+
maps = new DocIdx[4 * map_index];
|
719 |
+
num_samples = static_cast<int64_t>(map_index);
|
720 |
+
}
|
721 |
+
|
722 |
+
} // for (int iteration=0; iteration < 2; ++iteration) {
|
723 |
+
|
724 |
+
// Shuffle.
|
725 |
+
// We need a 64 bit random number generator as we might have more
|
726 |
+
// than 2 billion samples.
|
727 |
+
std::mt19937_64 rand64_gen(seed + 1);
|
728 |
+
for (auto i = (num_samples - 1); i > 0; --i)
|
729 |
+
{
|
730 |
+
const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
|
731 |
+
const auto i0 = 4 * i;
|
732 |
+
const auto j0 = 4 * j;
|
733 |
+
// Swap values.
|
734 |
+
swap(maps[i0], maps[j0]);
|
735 |
+
swap(maps[i0 + 1], maps[j0 + 1]);
|
736 |
+
swap(maps[i0 + 2], maps[j0 + 2]);
|
737 |
+
swap(maps[i0 + 3], maps[j0 + 3]);
|
738 |
+
}
|
739 |
+
|
740 |
+
// Method to deallocate memory.
|
741 |
+
py::capsule free_when_done(maps, [](void *mem_)
|
742 |
+
{
|
743 |
+
DocIdx *mem = reinterpret_cast<DocIdx *>(mem_);
|
744 |
+
delete[] mem;
|
745 |
+
});
|
746 |
+
|
747 |
+
// Return the numpy array.
|
748 |
+
const auto byte_size = sizeof(DocIdx);
|
749 |
+
return py::array(std::vector<int64_t>{num_samples, 4}, // shape
|
750 |
+
{4 * byte_size, byte_size}, // C-style contiguous strides
|
751 |
+
maps, // the data pointer
|
752 |
+
free_when_done); // numpy array references
|
753 |
+
}
|
754 |
+
|
755 |
+
py::array build_blocks_mapping(const py::array_t<int64_t> &docs_,
|
756 |
+
const py::array_t<int> &sizes_,
|
757 |
+
const py::array_t<int> &titles_sizes_,
|
758 |
+
const int num_epochs,
|
759 |
+
const uint64_t max_num_samples,
|
760 |
+
const int max_seq_length,
|
761 |
+
const int seed,
|
762 |
+
const bool verbose,
|
763 |
+
const bool use_one_sent_blocks)
|
764 |
+
{
|
765 |
+
|
766 |
+
if (sizes_.size() > std::numeric_limits<uint32_t>::max())
|
767 |
+
{
|
768 |
+
if (verbose)
|
769 |
+
{
|
770 |
+
cout << " using uint64 for data mapping..." << endl
|
771 |
+
<< std::flush;
|
772 |
+
}
|
773 |
+
return build_blocks_mapping_impl<uint64_t>(docs_, sizes_, titles_sizes_,
|
774 |
+
num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
|
775 |
+
}
|
776 |
+
else
|
777 |
+
{
|
778 |
+
if (verbose)
|
779 |
+
{
|
780 |
+
cout << " using uint32 for data mapping..." << endl
|
781 |
+
<< std::flush;
|
782 |
+
}
|
783 |
+
return build_blocks_mapping_impl<uint32_t>(docs_, sizes_, titles_sizes_,
|
784 |
+
num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
|
785 |
+
}
|
786 |
+
}
|
787 |
+
|
788 |
+
PYBIND11_MODULE(helpers, m)
|
789 |
+
{
|
790 |
+
m.def("build_mapping", &build_mapping);
|
791 |
+
m.def("build_blocks_mapping", &build_blocks_mapping);
|
792 |
+
m.def("build_sample_idx", &build_sample_idx);
|
793 |
+
m.def("build_blending_indices", &build_blending_indices);
|
794 |
+
}
|
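Usage note: the pybind11 module above is meant to be imported from Python after compiling helpers.cpp (see the Makefile added in this directory). The sketch below is illustrative only; the import path, the toy docs/sizes arrays and the chosen parameter values are assumptions, not part of this diff.

# Illustrative sketch only: assumes helpers.cpp has been compiled into the
# `helpers` extension and uses toy data.
import numpy as np
from fengshen.data.megatron_dataloader import helpers

# docs[i] is the index of the first sentence of document i in the flat arrays.
docs = np.array([0, 3, 7], dtype=np.int64)                      # 2 documents
sizes = np.array([12, 30, 25, 8, 40, 22, 17], dtype=np.int32)   # tokens per sentence

samples = helpers.build_mapping(
    docs, sizes,
    3,      # num_epochs
    1000,   # max_num_samples
    128,    # max_seq_length
    0.1,    # short_seq_prob
    1234,   # seed
    True,   # verbose
    2)      # min_num_sent
# Each row of the returned array is (start sentence, end sentence, target length).
print(samples.shape)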
fengshen/data/megatron_dataloader/indexed_dataset.py
ADDED
@@ -0,0 +1,585 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
|
7 |
+
# copied from fairseq/fairseq/data/indexed_dataset.py
|
8 |
+
# Removed IndexedRawTextDataset since it relied on Fairseq dictionary
|
9 |
+
# other slight modifications to remove fairseq dependencies
|
10 |
+
# Added document index to index file and made it accessible.
|
11 |
+
# An empty sentence no longer separates documents.
|
12 |
+
|
13 |
+
from functools import lru_cache
|
14 |
+
import os
|
15 |
+
import shutil
|
16 |
+
import struct
|
17 |
+
from itertools import accumulate
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
from fengshen.data.megatron_dataloader.utils import print_rank_0
|
22 |
+
|
23 |
+
|
24 |
+
def __best_fitting_dtype(vocab_size=None):
|
25 |
+
if vocab_size is not None and vocab_size < 65500:
|
26 |
+
return np.uint16
|
27 |
+
else:
|
28 |
+
return np.int32
|
29 |
+
|
30 |
+
|
31 |
+
def get_available_dataset_impl():
|
32 |
+
return ['lazy', 'cached', 'mmap']
|
33 |
+
|
34 |
+
|
35 |
+
def infer_dataset_impl(path):
|
36 |
+
if IndexedDataset.exists(path):
|
37 |
+
with open(index_file_path(path), 'rb') as f:
|
38 |
+
magic = f.read(8)
|
39 |
+
if magic == IndexedDataset._HDR_MAGIC:
|
40 |
+
return 'cached'
|
41 |
+
elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
|
42 |
+
return 'mmap'
|
43 |
+
else:
|
44 |
+
return None
|
45 |
+
else:
|
46 |
+
print(f"Dataset does not exist: {path}")
|
47 |
+
print("Path should be a basename that both .idx and "
|
48 |
+
".bin can be appended to get full filenames.")
|
49 |
+
return None
|
50 |
+
|
51 |
+
|
52 |
+
def make_builder(out_file, impl, vocab_size=None):
|
53 |
+
if impl == 'mmap':
|
54 |
+
return MMapIndexedDatasetBuilder(out_file,
|
55 |
+
dtype=__best_fitting_dtype(vocab_size))
|
56 |
+
else:
|
57 |
+
return IndexedDatasetBuilder(out_file)
|
58 |
+
|
59 |
+
|
60 |
+
def make_dataset(path, impl, skip_warmup=False):
|
61 |
+
if not IndexedDataset.exists(path):
|
62 |
+
print(f"Dataset does not exist: {path}")
|
63 |
+
print("Path should be a basename that both .idx "
|
64 |
+
"and .bin can be appended to get full filenames.")
|
65 |
+
return None
|
66 |
+
if impl == 'infer':
|
67 |
+
impl = infer_dataset_impl(path)
|
68 |
+
if impl == 'lazy' and IndexedDataset.exists(path):
|
69 |
+
return IndexedDataset(path)
|
70 |
+
elif impl == 'cached' and IndexedDataset.exists(path):
|
71 |
+
return IndexedCachedDataset(path)
|
72 |
+
elif impl == 'mmap' and MMapIndexedDataset.exists(path):
|
73 |
+
return MMapIndexedDataset(path, skip_warmup)
|
74 |
+
print(f"Unknown dataset implementation: {impl}")
|
75 |
+
return None
|
76 |
+
|
77 |
+
|
78 |
+
def dataset_exists(path, impl):
|
79 |
+
if impl == 'mmap':
|
80 |
+
return MMapIndexedDataset.exists(path)
|
81 |
+
else:
|
82 |
+
return IndexedDataset.exists(path)
|
83 |
+
|
84 |
+
|
85 |
+
def read_longs(f, n):
|
86 |
+
a = np.empty(n, dtype=np.int64)
|
87 |
+
f.readinto(a)
|
88 |
+
return a
|
89 |
+
|
90 |
+
|
91 |
+
def write_longs(f, a):
|
92 |
+
f.write(np.array(a, dtype=np.int64))
|
93 |
+
|
94 |
+
|
95 |
+
dtypes = {
|
96 |
+
1: np.uint8,
|
97 |
+
2: np.int8,
|
98 |
+
3: np.int16,
|
99 |
+
4: np.int32,
|
100 |
+
5: np.int64,
|
101 |
+
6: np.float,
|
102 |
+
7: np.double,
|
103 |
+
8: np.uint16
|
104 |
+
}
|
105 |
+
|
106 |
+
|
107 |
+
def code(dtype):
|
108 |
+
for k in dtypes.keys():
|
109 |
+
if dtypes[k] == dtype:
|
110 |
+
return k
|
111 |
+
raise ValueError(dtype)
|
112 |
+
|
113 |
+
|
114 |
+
def index_file_path(prefix_path):
|
115 |
+
return prefix_path + '.idx'
|
116 |
+
|
117 |
+
|
118 |
+
def data_file_path(prefix_path):
|
119 |
+
return prefix_path + '.bin'
|
120 |
+
|
121 |
+
|
122 |
+
def create_doc_idx(sizes):
|
123 |
+
doc_idx = [0]
|
124 |
+
for i, s in enumerate(sizes):
|
125 |
+
if s == 0:
|
126 |
+
doc_idx.append(i + 1)
|
127 |
+
return doc_idx
|
128 |
+
|
129 |
+
|
130 |
+
class IndexedDataset(torch.utils.data.Dataset):
|
131 |
+
"""Loader for IndexedDataset"""
|
132 |
+
_HDR_MAGIC = b'TNTIDX\x00\x00'
|
133 |
+
|
134 |
+
def __init__(self, path):
|
135 |
+
super().__init__()
|
136 |
+
self.path = path
|
137 |
+
self.data_file = None
|
138 |
+
self.read_index(path)
|
139 |
+
|
140 |
+
def read_index(self, path):
|
141 |
+
with open(index_file_path(path), 'rb') as f:
|
142 |
+
magic = f.read(8)
|
143 |
+
assert magic == self._HDR_MAGIC, (
|
144 |
+
'Index file doesn\'t match expected format. '
|
145 |
+
'Make sure that --dataset-impl is configured properly.'
|
146 |
+
)
|
147 |
+
version = f.read(8)
|
148 |
+
assert struct.unpack('<Q', version) == (1,)
|
149 |
+
code, self.element_size = struct.unpack('<QQ', f.read(16))
|
150 |
+
self.dtype = dtypes[code]
|
151 |
+
self._len, self.s = struct.unpack('<QQ', f.read(16))
|
152 |
+
self.doc_count = struct.unpack('<Q', f.read(8))
|
153 |
+
self.dim_offsets = read_longs(f, self._len + 1)
|
154 |
+
self.data_offsets = read_longs(f, self._len + 1)
|
155 |
+
self.sizes = read_longs(f, self.s)
|
156 |
+
self.doc_idx = read_longs(f, self.doc_count)
|
157 |
+
|
158 |
+
def read_data(self, path):
|
159 |
+
self.data_file = open(data_file_path(path), 'rb', buffering=0)
|
160 |
+
|
161 |
+
def check_index(self, i):
|
162 |
+
if i < 0 or i >= self._len:
|
163 |
+
raise IndexError('index out of range')
|
164 |
+
|
165 |
+
def __del__(self):
|
166 |
+
if self.data_file:
|
167 |
+
self.data_file.close()
|
168 |
+
|
169 |
+
# @lru_cache(maxsize=8)
|
170 |
+
def __getitem__(self, idx):
|
171 |
+
if not self.data_file:
|
172 |
+
self.read_data(self.path)
|
173 |
+
if isinstance(idx, int):
|
174 |
+
i = idx
|
175 |
+
self.check_index(i)
|
176 |
+
tensor_size = self.sizes[
|
177 |
+
self.dim_offsets[i]:self.dim_offsets[i + 1]]
|
178 |
+
a = np.empty(tensor_size, dtype=self.dtype)
|
179 |
+
self.data_file.seek(self.data_offsets[i] * self.element_size)
|
180 |
+
self.data_file.readinto(a)
|
181 |
+
return a
|
182 |
+
elif isinstance(idx, slice):
|
183 |
+
start, stop, step = idx.indices(len(self))
|
184 |
+
if step != 1:
|
185 |
+
raise ValueError(
|
186 |
+
"Slices into indexed_dataset must be contiguous")
|
187 |
+
sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]]
|
188 |
+
size = sum(sizes)
|
189 |
+
a = np.empty(size, dtype=self.dtype)
|
190 |
+
self.data_file.seek(self.data_offsets[start] * self.element_size)
|
191 |
+
self.data_file.readinto(a)
|
192 |
+
offsets = list(accumulate(sizes))
|
193 |
+
sents = np.split(a, offsets[:-1])
|
194 |
+
return sents
|
195 |
+
|
196 |
+
def __len__(self):
|
197 |
+
return self._len
|
198 |
+
|
199 |
+
def num_tokens(self, index):
|
200 |
+
return self.sizes[index]
|
201 |
+
|
202 |
+
def size(self, index):
|
203 |
+
return self.sizes[index]
|
204 |
+
|
205 |
+
@staticmethod
|
206 |
+
def exists(path):
|
207 |
+
return (
|
208 |
+
os.path.exists(index_file_path(path)) and os.path.exists(
|
209 |
+
data_file_path(path))
|
210 |
+
)
|
211 |
+
|
212 |
+
@property
|
213 |
+
def supports_prefetch(self):
|
214 |
+
return False # avoid prefetching to save memory
|
215 |
+
|
216 |
+
|
217 |
+
class IndexedCachedDataset(IndexedDataset):
|
218 |
+
|
219 |
+
def __init__(self, path):
|
220 |
+
super().__init__(path)
|
221 |
+
self.cache = None
|
222 |
+
self.cache_index = {}
|
223 |
+
|
224 |
+
@property
|
225 |
+
def supports_prefetch(self):
|
226 |
+
return True
|
227 |
+
|
228 |
+
def prefetch(self, indices):
|
229 |
+
if all(i in self.cache_index for i in indices):
|
230 |
+
return
|
231 |
+
if not self.data_file:
|
232 |
+
self.read_data(self.path)
|
233 |
+
indices = sorted(set(indices))
|
234 |
+
total_size = 0
|
235 |
+
for i in indices:
|
236 |
+
total_size += self.data_offsets[i + 1] - self.data_offsets[i]
|
237 |
+
self.cache = np.empty(total_size, dtype=self.dtype)
|
238 |
+
ptx = 0
|
239 |
+
self.cache_index.clear()
|
240 |
+
for i in indices:
|
241 |
+
self.cache_index[i] = ptx
|
242 |
+
size = self.data_offsets[i + 1] - self.data_offsets[i]
|
243 |
+
a = self.cache[ptx: ptx + size]
|
244 |
+
self.data_file.seek(self.data_offsets[i] * self.element_size)
|
245 |
+
self.data_file.readinto(a)
|
246 |
+
ptx += size
|
247 |
+
if self.data_file:
|
248 |
+
# close and delete data file after prefetch so we can pickle
|
249 |
+
self.data_file.close()
|
250 |
+
self.data_file = None
|
251 |
+
|
252 |
+
# @lru_cache(maxsize=8)
|
253 |
+
def __getitem__(self, idx):
|
254 |
+
if isinstance(idx, int):
|
255 |
+
i = idx
|
256 |
+
self.check_index(i)
|
257 |
+
tensor_size = self.sizes[
|
258 |
+
self.dim_offsets[i]:self.dim_offsets[i + 1]]
|
259 |
+
a = np.empty(tensor_size, dtype=self.dtype)
|
260 |
+
ptx = self.cache_index[i]
|
261 |
+
np.copyto(a, self.cache[ptx: ptx + a.size])
|
262 |
+
return a
|
263 |
+
elif isinstance(idx, slice):
|
264 |
+
            # Hack just to make this work, can optimize later if necessary
|
265 |
+
sents = []
|
266 |
+
for i in range(*idx.indices(len(self))):
|
267 |
+
sents.append(self[i])
|
268 |
+
return sents
|
269 |
+
|
270 |
+
|
271 |
+
class IndexedDatasetBuilder(object):
|
272 |
+
element_sizes = {
|
273 |
+
np.uint8: 1,
|
274 |
+
np.int8: 1,
|
275 |
+
np.int16: 2,
|
276 |
+
np.int32: 4,
|
277 |
+
np.int64: 8,
|
278 |
+
np.float: 4,
|
279 |
+
np.double: 8
|
280 |
+
}
|
281 |
+
|
282 |
+
def __init__(self, out_file, dtype=np.int32):
|
283 |
+
self.out_file = open(out_file, 'wb')
|
284 |
+
self.dtype = dtype
|
285 |
+
self.data_offsets = [0]
|
286 |
+
self.dim_offsets = [0]
|
287 |
+
self.sizes = []
|
288 |
+
self.element_size = self.element_sizes[self.dtype]
|
289 |
+
self.doc_idx = [0]
|
290 |
+
|
291 |
+
def add_item(self, tensor):
|
292 |
+
bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype))
|
293 |
+
self.data_offsets.append(
|
294 |
+
self.data_offsets[-1] + bytes / self.element_size)
|
295 |
+
for s in tensor.size():
|
296 |
+
self.sizes.append(s)
|
297 |
+
self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))
|
298 |
+
|
299 |
+
def end_document(self):
|
300 |
+
self.doc_idx.append(len(self.sizes))
|
301 |
+
|
302 |
+
def merge_file_(self, another_file):
|
303 |
+
index = IndexedDataset(another_file)
|
304 |
+
assert index.dtype == self.dtype
|
305 |
+
|
306 |
+
begin = self.data_offsets[-1]
|
307 |
+
for offset in index.data_offsets[1:]:
|
308 |
+
self.data_offsets.append(begin + offset)
|
309 |
+
self.sizes.extend(index.sizes)
|
310 |
+
begin = self.dim_offsets[-1]
|
311 |
+
for dim_offset in index.dim_offsets[1:]:
|
312 |
+
self.dim_offsets.append(begin + dim_offset)
|
313 |
+
|
314 |
+
with open(data_file_path(another_file), 'rb') as f:
|
315 |
+
while True:
|
316 |
+
data = f.read(1024)
|
317 |
+
if data:
|
318 |
+
self.out_file.write(data)
|
319 |
+
else:
|
320 |
+
break
|
321 |
+
|
322 |
+
def finalize(self, index_file):
|
323 |
+
self.out_file.close()
|
324 |
+
index = open(index_file, 'wb')
|
325 |
+
index.write(b'TNTIDX\x00\x00')
|
326 |
+
index.write(struct.pack('<Q', 1))
|
327 |
+
index.write(struct.pack('<QQ', code(self.dtype), self.element_size))
|
328 |
+
index.write(struct.pack('<QQ', len(
|
329 |
+
self.data_offsets) - 1, len(self.sizes)))
|
330 |
+
index.write(struct.pack('<Q', len(self.doc_idx)))
|
331 |
+
write_longs(index, self.dim_offsets)
|
332 |
+
write_longs(index, self.data_offsets)
|
333 |
+
write_longs(index, self.sizes)
|
334 |
+
write_longs(index, self.doc_idx)
|
335 |
+
index.close()
|
336 |
+
|
337 |
+
|
338 |
+
def _warmup_mmap_file(path):
|
339 |
+
with open(path, 'rb') as stream:
|
340 |
+
while stream.read(100 * 1024 * 1024):
|
341 |
+
pass
|
342 |
+
|
343 |
+
|
344 |
+
class MMapIndexedDataset(torch.utils.data.Dataset):
|
345 |
+
class Index(object):
|
346 |
+
_HDR_MAGIC = b'MMIDIDX\x00\x00'
|
347 |
+
|
348 |
+
@classmethod
|
349 |
+
def writer(cls, path, dtype):
|
350 |
+
class _Writer(object):
|
351 |
+
def __enter__(self):
|
352 |
+
self._file = open(path, 'wb')
|
353 |
+
|
354 |
+
self._file.write(cls._HDR_MAGIC)
|
355 |
+
self._file.write(struct.pack('<Q', 1))
|
356 |
+
self._file.write(struct.pack('<B', code(dtype)))
|
357 |
+
|
358 |
+
return self
|
359 |
+
|
360 |
+
@staticmethod
|
361 |
+
def _get_pointers(sizes):
|
362 |
+
dtype_size = dtype().itemsize
|
363 |
+
address = 0
|
364 |
+
pointers = []
|
365 |
+
|
366 |
+
for size in sizes:
|
367 |
+
pointers.append(address)
|
368 |
+
address += size * dtype_size
|
369 |
+
|
370 |
+
return pointers
|
371 |
+
|
372 |
+
def write(self, sizes, doc_idx):
|
373 |
+
pointers = self._get_pointers(sizes)
|
374 |
+
|
375 |
+
self._file.write(struct.pack('<Q', len(sizes)))
|
376 |
+
self._file.write(struct.pack('<Q', len(doc_idx)))
|
377 |
+
|
378 |
+
sizes = np.array(sizes, dtype=np.int32)
|
379 |
+
self._file.write(sizes.tobytes(order='C'))
|
380 |
+
del sizes
|
381 |
+
|
382 |
+
pointers = np.array(pointers, dtype=np.int64)
|
383 |
+
self._file.write(pointers.tobytes(order='C'))
|
384 |
+
del pointers
|
385 |
+
|
386 |
+
doc_idx = np.array(doc_idx, dtype=np.int64)
|
387 |
+
self._file.write(doc_idx.tobytes(order='C'))
|
388 |
+
|
389 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
390 |
+
self._file.close()
|
391 |
+
|
392 |
+
return _Writer()
|
393 |
+
|
394 |
+
def __init__(self, path, skip_warmup=False):
|
395 |
+
with open(path, 'rb') as stream:
|
396 |
+
magic_test = stream.read(9)
|
397 |
+
assert self._HDR_MAGIC == magic_test, (
|
398 |
+
'Index file doesn\'t match expected format. '
|
399 |
+
'Make sure that --dataset-impl is configured properly.'
|
400 |
+
)
|
401 |
+
version = struct.unpack('<Q', stream.read(8))
|
402 |
+
assert (1,) == version
|
403 |
+
|
404 |
+
dtype_code, = struct.unpack('<B', stream.read(1))
|
405 |
+
self._dtype = dtypes[dtype_code]
|
406 |
+
self._dtype_size = self._dtype().itemsize
|
407 |
+
|
408 |
+
self._len = struct.unpack('<Q', stream.read(8))[0]
|
409 |
+
self._doc_count = struct.unpack('<Q', stream.read(8))[0]
|
410 |
+
offset = stream.tell()
|
411 |
+
|
412 |
+
if not skip_warmup:
|
413 |
+
print_rank_0(" warming up index mmap file...")
|
414 |
+
_warmup_mmap_file(path)
|
415 |
+
|
416 |
+
self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
|
417 |
+
self._bin_buffer = memoryview(self._bin_buffer_mmap)
|
418 |
+
print_rank_0(" reading sizes...")
|
419 |
+
self._sizes = np.frombuffer(
|
420 |
+
self._bin_buffer,
|
421 |
+
dtype=np.int32,
|
422 |
+
count=self._len,
|
423 |
+
offset=offset)
|
424 |
+
print_rank_0(" reading pointers...")
|
425 |
+
self._pointers = np.frombuffer(self._bin_buffer,
|
426 |
+
dtype=np.int64, count=self._len,
|
427 |
+
offset=offset + self._sizes.nbytes)
|
428 |
+
print_rank_0(" reading document index...")
|
429 |
+
self._doc_idx = np.frombuffer(
|
430 |
+
self._bin_buffer,
|
431 |
+
dtype=np.int64, count=self._doc_count,
|
432 |
+
offset=offset + self._sizes.nbytes + self._pointers.nbytes)
|
433 |
+
|
434 |
+
def __del__(self):
|
435 |
+
self._bin_buffer_mmap._mmap.close()
|
436 |
+
del self._bin_buffer_mmap
|
437 |
+
|
438 |
+
@property
|
439 |
+
def dtype(self):
|
440 |
+
return self._dtype
|
441 |
+
|
442 |
+
@property
|
443 |
+
def sizes(self):
|
444 |
+
return self._sizes
|
445 |
+
|
446 |
+
@property
|
447 |
+
def doc_idx(self):
|
448 |
+
return self._doc_idx
|
449 |
+
|
450 |
+
@lru_cache(maxsize=8)
|
451 |
+
def __getitem__(self, i):
|
452 |
+
return self._pointers[i], self._sizes[i]
|
453 |
+
|
454 |
+
def __len__(self):
|
455 |
+
return self._len
|
456 |
+
|
457 |
+
def __init__(self, path, skip_warmup=False):
|
458 |
+
super().__init__()
|
459 |
+
|
460 |
+
self._path = None
|
461 |
+
self._index = None
|
462 |
+
self._bin_buffer = None
|
463 |
+
|
464 |
+
self._do_init(path, skip_warmup)
|
465 |
+
|
466 |
+
def __getstate__(self):
|
467 |
+
return self._path
|
468 |
+
|
469 |
+
def __setstate__(self, state):
|
470 |
+
self._do_init(state)
|
471 |
+
|
472 |
+
def _do_init(self, path, skip_warmup):
|
473 |
+
self._path = path
|
474 |
+
self._index = self.Index(index_file_path(self._path), skip_warmup)
|
475 |
+
|
476 |
+
if not skip_warmup:
|
477 |
+
print_rank_0(" warming up data mmap file...")
|
478 |
+
_warmup_mmap_file(data_file_path(self._path))
|
479 |
+
print_rank_0(" creating numpy buffer of mmap...")
|
480 |
+
self._bin_buffer_mmap = np.memmap(
|
481 |
+
data_file_path(self._path), mode='r', order='C')
|
482 |
+
print_rank_0(" creating memory view of numpy buffer...")
|
483 |
+
self._bin_buffer = memoryview(self._bin_buffer_mmap)
|
484 |
+
|
485 |
+
def __del__(self):
|
486 |
+
self._bin_buffer_mmap._mmap.close()
|
487 |
+
del self._bin_buffer_mmap
|
488 |
+
del self._index
|
489 |
+
|
490 |
+
def __len__(self):
|
491 |
+
return len(self._index)
|
492 |
+
|
493 |
+
# @lru_cache(maxsize=8)
|
494 |
+
def __getitem__(self, idx):
|
495 |
+
if isinstance(idx, int):
|
496 |
+
ptr, size = self._index[idx]
|
497 |
+
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
|
498 |
+
count=size, offset=ptr)
|
499 |
+
return np_array
|
500 |
+
elif isinstance(idx, slice):
|
501 |
+
start, stop, step = idx.indices(len(self))
|
502 |
+
if step != 1:
|
503 |
+
raise ValueError(
|
504 |
+
"Slices into indexed_dataset must be contiguous")
|
505 |
+
ptr = self._index._pointers[start]
|
506 |
+
sizes = self._index._sizes[idx]
|
507 |
+
offsets = list(accumulate(sizes))
|
508 |
+
total_size = sum(sizes)
|
509 |
+
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
|
510 |
+
count=total_size, offset=ptr)
|
511 |
+
sents = np.split(np_array, offsets[:-1])
|
512 |
+
return sents
|
513 |
+
|
514 |
+
def get(self, idx, offset=0, length=None):
|
515 |
+
""" Retrieves a single item from the dataset with the option to only
|
516 |
+
return a portion of the item.
|
517 |
+
|
518 |
+
get(idx) is the same as [idx] but get() does not support slicing.
|
519 |
+
"""
|
520 |
+
ptr, size = self._index[idx]
|
521 |
+
if length is None:
|
522 |
+
length = size - offset
|
523 |
+
ptr += offset * np.dtype(self._index.dtype).itemsize
|
524 |
+
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
|
525 |
+
count=length, offset=ptr)
|
526 |
+
return np_array
|
527 |
+
|
528 |
+
@property
|
529 |
+
def sizes(self):
|
530 |
+
return self._index.sizes
|
531 |
+
|
532 |
+
@property
|
533 |
+
def doc_idx(self):
|
534 |
+
return self._index.doc_idx
|
535 |
+
|
536 |
+
def get_doc_idx(self):
|
537 |
+
return self._index._doc_idx
|
538 |
+
|
539 |
+
def set_doc_idx(self, doc_idx_):
|
540 |
+
self._index._doc_idx = doc_idx_
|
541 |
+
|
542 |
+
@property
|
543 |
+
def supports_prefetch(self):
|
544 |
+
return False
|
545 |
+
|
546 |
+
@staticmethod
|
547 |
+
def exists(path):
|
548 |
+
return (
|
549 |
+
os.path.exists(index_file_path(path)) and os.path.exists(
|
550 |
+
data_file_path(path))
|
551 |
+
)
|
552 |
+
|
553 |
+
|
554 |
+
class MMapIndexedDatasetBuilder(object):
|
555 |
+
def __init__(self, out_file, dtype=np.int64):
|
556 |
+
self._data_file = open(out_file, 'wb', buffering=5000000)
|
557 |
+
self._dtype = dtype
|
558 |
+
self._sizes = []
|
559 |
+
self._doc_idx = [0]
|
560 |
+
|
561 |
+
def add_item(self, tensor):
|
562 |
+
np_array = np.array(tensor.numpy(), dtype=self._dtype)
|
563 |
+
self._data_file.write(np_array.tobytes(order='C'))
|
564 |
+
self._sizes.append(np_array.size)
|
565 |
+
|
566 |
+
def end_document(self):
|
567 |
+
self._doc_idx.append(len(self._sizes))
|
568 |
+
|
569 |
+
def merge_file_(self, another_file):
|
570 |
+
# Concatenate index
|
571 |
+
index = MMapIndexedDataset.Index(index_file_path(another_file))
|
572 |
+
assert index.dtype == self._dtype
|
573 |
+
|
574 |
+
for size in index.sizes:
|
575 |
+
self._sizes.append(size)
|
576 |
+
|
577 |
+
# Concatenate data
|
578 |
+
with open(data_file_path(another_file), 'rb') as f:
|
579 |
+
shutil.copyfileobj(f, self._data_file)
|
580 |
+
|
581 |
+
def finalize(self, index_file):
|
582 |
+
self._data_file.close()
|
583 |
+
|
584 |
+
with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
|
585 |
+
index.write(self._sizes, self._doc_idx)
|
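A minimal round-trip sketch for the mmap implementation above; the `demo` path prefix and the toy token ids are placeholders, not files in this commit.

# Sketch only: writes demo.bin/demo.idx and memory-maps them back.
import torch
from fengshen.data.megatron_dataloader import indexed_dataset

builder = indexed_dataset.make_builder('demo.bin', impl='mmap', vocab_size=30000)
for ids in ([101, 2769, 102], [101, 872, 1962, 102]):
    builder.add_item(torch.IntTensor(ids))
    builder.end_document()        # mark a document boundary after each sample
builder.finalize('demo.idx')

ds = indexed_dataset.make_dataset('demo', impl='mmap', skip_warmup=True)
print(len(ds), ds[0])             # 2 sequences; ds[0] is a numpy array of token ids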
fengshen/data/megatron_dataloader/utils.py
ADDED
@@ -0,0 +1,24 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


def print_rank_0(message):
    """If distributed is initialized, print only on rank 0."""
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(message, flush=True)
    else:
        print(message, flush=True)
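print_rank_0 is the logging helper the loaders above rely on: with torch.distributed initialized only global rank 0 prints, otherwise it behaves like a plain print. A one-line example (the message text is arbitrary):

from fengshen.data.megatron_dataloader.utils import print_rank_0

# Before torch.distributed.init_process_group() this prints on every process;
# afterwards only global rank 0 emits the line.
print_rank_0(" > building memory-mapped dataset index ...")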
fengshen/data/mmap_dataloader/mmap_datamodule.py
ADDED
@@ -0,0 +1,68 @@
from typing import Optional
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from fengshen.data.mmap_index_dataset import MMapIndexDataset


class MMapDataModule(LightningDataModule):
    @staticmethod
    def add_data_specific_args(parent_args):
        parser = parent_args.add_argument_group('MMAP DataModule')
        parser.add_argument('--num_workers', default=8, type=int)
        parser.add_argument('--train_batchsize', default=32, type=int)
        parser.add_argument('--eval_batchsize', default=32, type=int)
        parser.add_argument('--test_batchsize', default=32, type=int)
        parser.add_argument('--train_datas', default=[
            './train_datas'
        ], type=str, nargs='+')
        parser.add_argument('--valid_datas', default=[
            './valid_datas'
        ], type=str, nargs='+')
        parser.add_argument('--test_datas', default=[
            './test_datas'],
            type=str, nargs='+')
        parser.add_argument('--input_tensor_name', default=['input_ids'], type=str, nargs='+')
        return parent_args

    def __init__(
        self,
        collate_fn,
        args,
        **kwargs,
    ):
        super().__init__()
        self.collate_fn = collate_fn
        self.train_dataset = MMapIndexDataset(args.train_datas, args.input_tensor_name)
        self.valid_dataset = MMapIndexDataset(args.valid_datas, args.input_tensor_name)
        self.test_dataset = MMapIndexDataset(args.test_datas, args.input_tensor_name)
        self.save_hyperparameters(args)

    def setup(self, stage: Optional[str] = None) -> None:
        return super().setup(stage)

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.hparams.train_batchsize,
            shuffle=True,
            num_workers=self.hparams.num_workers,
            collate_fn=self.collate_fn,
        )

    def val_dataloader(self):
        return DataLoader(
            self.valid_dataset,
            batch_size=self.hparams.eval_batchsize,
            shuffle=True,
            num_workers=self.hparams.num_workers,
            collate_fn=self.collate_fn,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.hparams.test_batchsize,
            shuffle=True,
            num_workers=self.hparams.num_workers,
            collate_fn=self.collate_fn,
        )
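A hypothetical way to wire the datamodule up; the path prefixes, the pass-through collate_fn and the CLI values are placeholders, and each prefix must point at existing `<prefix>_input_ids.npy`/`.bin` pairs as read by MMapIndexDataset (next file).

# Placeholder wiring; adapt the paths and collate_fn to real data.
import argparse
from fengshen.data.mmap_dataloader.mmap_datamodule import MMapDataModule

parser = argparse.ArgumentParser()
parser = MMapDataModule.add_data_specific_args(parser)
args = parser.parse_args([
    '--train_datas', '/path/to/train_shard',
    '--valid_datas', '/path/to/valid_shard',
    '--test_datas', '/path/to/test_shard',
    '--input_tensor_name', 'input_ids',
])

def collate_fn(batch):
    return batch  # pad/stack the per-sample dicts here

dm = MMapDataModule(collate_fn=collate_fn, args=args)
train_loader = dm.train_dataloader()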
fengshen/data/mmap_dataloader/mmap_index_dataset.py
ADDED
@@ -0,0 +1,53 @@
import numpy as np
import torch
from typing import List
from torch.utils.data import Dataset


class MMapIndexDataset(Dataset):
    # datapaths is the list of path prefixes of all memory-mapped files
    # input_tensor_name lists the input tensor names, e.g. ['input_ids'];
    # each name is stored in its own file pair
    def __init__(self, datapaths: List[str], input_tensor_name: List[str]):
        dict_idx_fp = {}
        dict_bin_fp = {}
        idx_len = []
        for tensor_name in input_tensor_name:
            idx_fp = []
            bin_fp = []
            len = 0
            for data_path in datapaths:
                idx_fp += [np.load(
                    data_path + '_' + tensor_name + '.npy', mmap_mode='r')]
                bin_fp += [np.memmap(
                    data_path + '_' + tensor_name + '.bin',
                    dtype='long',
                    mode='r')]
                len += idx_fp[-1].shape[0]
                idx_len += [idx_fp[-1].shape[0]]
            dict_idx_fp[tensor_name] = idx_fp
            dict_bin_fp[tensor_name] = bin_fp
            # the different tensors normally all have the same length
            self._len = len

        self._input_tensor_name = input_tensor_name
        self._dict_idx_fp = dict_idx_fp
        self._dict_bin_fp = dict_bin_fp
        self._idx_len = idx_len

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        sample = {}
        for i in range(len(self._idx_len)):
            if idx >= self._idx_len[i]:
                idx -= self._idx_len[i]
            else:
                break
        for tensor_name in self._input_tensor_name:
            sample[tensor_name] = torch.tensor(self._dict_bin_fp[tensor_name][i][
                self._dict_idx_fp[tensor_name][i][idx, 0]:
                self._dict_idx_fp[tensor_name][i][idx, 1]
            ], dtype=torch.long)
        # print(sample)
        return sample
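For every path prefix and tensor name, the dataset above expects a `<prefix>_<name>.npy` array of (start, end) offsets and a `<prefix>_<name>.bin` flat int64 token buffer. A sketch of producing such a pair with NumPy; the file names and token ids are invented for illustration.

# Sketch of the file-pair convention MMapIndexDataset reads.
import numpy as np

samples = [[101, 2769, 102], [101, 872, 1962, 8024, 102]]      # toy token ids
flat = np.array([t for s in samples for t in s], dtype=np.int64)
offsets, start = [], 0
for s in samples:
    offsets.append((start, start + len(s)))
    start += len(s)

flat.tofile('corpus_input_ids.bin')                    # flat token buffer
np.save('corpus_input_ids.npy', np.array(offsets))     # (start, end) per sample
# MMapIndexDataset(['corpus'], ['input_ids'])[1] then yields the second sample.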
fengshen/data/preprocess.py
ADDED
@@ -0,0 +1 @@
# coding=utf-8
fengshen/data/sequence_tagging_dataloader/sequence_tagging_collator.py
ADDED
@@ -0,0 +1,274 @@
1 |
+
from dataclasses import dataclass
|
2 |
+
from torch.utils.data._utils.collate import default_collate
|
3 |
+
|
4 |
+
import copy
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
@dataclass
|
9 |
+
class CollatorForLinear:
|
10 |
+
args = None
|
11 |
+
tokenizer = None
|
12 |
+
label2id = None
|
13 |
+
|
14 |
+
def __call__(self, samples):
|
15 |
+
cls_token = "[CLS]"
|
16 |
+
sep_token = "[SEP]"
|
17 |
+
pad_token = 0
|
18 |
+
special_tokens_count = 2
|
19 |
+
segment_id = 0
|
20 |
+
|
21 |
+
features=[]
|
22 |
+
|
23 |
+
for (ex_index, example) in enumerate(samples):
|
24 |
+
tokens = copy.deepcopy(example['text_a'])
|
25 |
+
|
26 |
+
label_ids = [self.label2id[x] for x in example['labels']]
|
27 |
+
|
28 |
+
if len(tokens) > self.args.max_seq_length - special_tokens_count:
|
29 |
+
tokens = tokens[: (self.args.max_seq_length - special_tokens_count)]
|
30 |
+
label_ids = label_ids[: (self.args.max_seq_length - special_tokens_count)]
|
31 |
+
|
32 |
+
tokens += [sep_token]
|
33 |
+
label_ids += [self.label2id["O"]]
|
34 |
+
segment_ids = [segment_id] * len(tokens)
|
35 |
+
|
36 |
+
tokens = [cls_token] + tokens
|
37 |
+
label_ids = [self.label2id["O"]] + label_ids
|
38 |
+
segment_ids = [segment_id] + segment_ids
|
39 |
+
|
40 |
+
input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
|
41 |
+
input_mask = [1] * len(input_ids)
|
42 |
+
input_len = len(label_ids)
|
43 |
+
padding_length = self.args.max_seq_length - len(input_ids)
|
44 |
+
|
45 |
+
input_ids += [pad_token] * padding_length
|
46 |
+
input_mask += [0] * padding_length
|
47 |
+
segment_ids += [segment_id] * padding_length
|
48 |
+
label_ids += [pad_token] * padding_length
|
49 |
+
|
50 |
+
assert len(input_ids) == self.args.max_seq_length
|
51 |
+
assert len(input_mask) == self.args.max_seq_length
|
52 |
+
assert len(segment_ids) == self.args.max_seq_length
|
53 |
+
assert len(label_ids) == self.args.max_seq_length
|
54 |
+
|
55 |
+
features.append({
|
56 |
+
'input_ids':torch.tensor(input_ids),
|
57 |
+
'attention_mask':torch.tensor(input_mask),
|
58 |
+
'input_len':torch.tensor(input_len),
|
59 |
+
'token_type_ids':torch.tensor(segment_ids),
|
60 |
+
'labels':torch.tensor(label_ids),
|
61 |
+
})
|
62 |
+
|
63 |
+
return default_collate(features)
|
64 |
+
|
65 |
+
@dataclass
|
66 |
+
class CollatorForCrf:
|
67 |
+
args = None
|
68 |
+
tokenizer = None
|
69 |
+
label2id = None
|
70 |
+
|
71 |
+
def __call__(self, samples):
|
72 |
+
features = []
|
73 |
+
cls_token = "[CLS]"
|
74 |
+
sep_token = "[SEP]"
|
75 |
+
pad_token = 0
|
76 |
+
special_tokens_count = 2
|
77 |
+
segment_id = 0
|
78 |
+
|
79 |
+
for (ex_index, example) in enumerate(samples):
|
80 |
+
tokens = copy.deepcopy(example['text_a'])
|
81 |
+
|
82 |
+
label_ids = [self.label2id[x] for x in example['labels']]
|
83 |
+
|
84 |
+
if len(tokens) > self.args.max_seq_length - special_tokens_count:
|
85 |
+
tokens = tokens[: (self.args.max_seq_length - special_tokens_count)]
|
86 |
+
label_ids = label_ids[: (self.args.max_seq_length - special_tokens_count)]
|
87 |
+
|
88 |
+
tokens += [sep_token]
|
89 |
+
label_ids += [self.label2id["O"]]
|
90 |
+
segment_ids = [segment_id] * len(tokens)
|
91 |
+
|
92 |
+
tokens = [cls_token] + tokens
|
93 |
+
label_ids = [self.label2id["O"]] + label_ids
|
94 |
+
segment_ids = [segment_id] + segment_ids
|
95 |
+
|
96 |
+
input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
|
97 |
+
input_mask = [1] * len(input_ids)
|
98 |
+
input_len = len(label_ids)
|
99 |
+
padding_length = self.args.max_seq_length - len(input_ids)
|
100 |
+
|
101 |
+
input_ids += [pad_token] * padding_length
|
102 |
+
input_mask += [0] * padding_length
|
103 |
+
segment_ids += [segment_id] * padding_length
|
104 |
+
label_ids += [pad_token] * padding_length
|
105 |
+
|
106 |
+
assert len(input_ids) == self.args.max_seq_length
|
107 |
+
assert len(input_mask) == self.args.max_seq_length
|
108 |
+
assert len(segment_ids) == self.args.max_seq_length
|
109 |
+
assert len(label_ids) == self.args.max_seq_length
|
110 |
+
|
111 |
+
features.append({
|
112 |
+
'input_ids':torch.tensor(input_ids),
|
113 |
+
'attention_mask':torch.tensor(input_mask),
|
114 |
+
'input_len':torch.tensor(input_len),
|
115 |
+
'token_type_ids':torch.tensor(segment_ids),
|
116 |
+
'labels':torch.tensor(label_ids),
|
117 |
+
})
|
118 |
+
|
119 |
+
return default_collate(features)
|
120 |
+
|
121 |
+
|
122 |
+
@dataclass
|
123 |
+
class CollatorForSpan:
|
124 |
+
args = None
|
125 |
+
tokenizer = None
|
126 |
+
label2id = None
|
127 |
+
|
128 |
+
def __call__(self, samples):
|
129 |
+
|
130 |
+
features = []
|
131 |
+
cls_token = "[CLS]"
|
132 |
+
sep_token = "[SEP]"
|
133 |
+
pad_token = 0
|
134 |
+
special_tokens_count = 2
|
135 |
+
max_entities_count = 100
|
136 |
+
segment_id = 0
|
137 |
+
|
138 |
+
for (ex_index, example) in enumerate(samples):
|
139 |
+
subjects = copy.deepcopy(example['subject'])
|
140 |
+
tokens = copy.deepcopy(example['text_a'])
|
141 |
+
start_ids = [0] * len(tokens)
|
142 |
+
end_ids = [0] * len(tokens)
|
143 |
+
subject_ids = []
|
144 |
+
for subject in subjects:
|
145 |
+
label = subject[0]
|
146 |
+
start = subject[1]
|
147 |
+
end = subject[2]
|
148 |
+
start_ids[start] = self.label2id[label]
|
149 |
+
end_ids[end] = self.label2id[label]
|
150 |
+
subject_ids.append([self.label2id[label], start, end])
|
151 |
+
|
152 |
+
subject_ids+=[[-1,-1,-1]]*(max_entities_count-len(subject_ids))
|
153 |
+
|
154 |
+
if len(tokens) > self.args.max_seq_length - special_tokens_count:
|
155 |
+
tokens = tokens[: (self.args.max_seq_length - special_tokens_count)]
|
156 |
+
start_ids = start_ids[: (self.args.max_seq_length - special_tokens_count)]
|
157 |
+
end_ids = end_ids[: (self.args.max_seq_length - special_tokens_count)]
|
158 |
+
|
159 |
+
tokens += [sep_token]
|
160 |
+
start_ids += [0]
|
161 |
+
end_ids += [0]
|
162 |
+
segment_ids = [segment_id] * len(tokens)
|
163 |
+
|
164 |
+
tokens = [cls_token] + tokens
|
165 |
+
start_ids = [0] + start_ids
|
166 |
+
end_ids = [0] + end_ids
|
167 |
+
segment_ids = [segment_id] + segment_ids
|
168 |
+
|
169 |
+
input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
|
170 |
+
input_mask = [1] * len(input_ids)
|
171 |
+
input_len = len(input_ids)
|
172 |
+
padding_length = self.args.max_seq_length - len(input_ids)
|
173 |
+
|
174 |
+
input_ids += [pad_token] * padding_length
|
175 |
+
input_mask += [0] * padding_length
|
176 |
+
segment_ids += [segment_id] * padding_length
|
177 |
+
start_ids += [0] * padding_length
|
178 |
+
end_ids += [0] * padding_length
|
179 |
+
|
180 |
+
assert len(input_ids) == self.args.max_seq_length
|
181 |
+
assert len(input_mask) == self.args.max_seq_length
|
182 |
+
assert len(segment_ids) == self.args.max_seq_length
|
183 |
+
assert len(start_ids) == self.args.max_seq_length
|
184 |
+
assert len(end_ids) == self.args.max_seq_length
|
185 |
+
|
186 |
+
features.append({
|
187 |
+
'input_ids': torch.tensor(np.array(input_ids)),
|
188 |
+
'attention_mask': torch.tensor(np.array(input_mask)),
|
189 |
+
'token_type_ids': torch.tensor(np.array(segment_ids)),
|
190 |
+
'start_positions': torch.tensor(np.array(start_ids)),
|
191 |
+
'end_positions': torch.tensor(np.array(end_ids)),
|
192 |
+
"subjects": torch.tensor(np.array(subject_ids)),
|
193 |
+
'input_len': torch.tensor(np.array(input_len)),
|
194 |
+
})
|
195 |
+
|
196 |
+
return default_collate(features)
|
197 |
+
|
198 |
+
|
199 |
+
@dataclass
|
200 |
+
class CollatorForBiaffine:
|
201 |
+
args = None
|
202 |
+
tokenizer = None
|
203 |
+
label2id = None
|
204 |
+
|
205 |
+
|
206 |
+
def __call__(self, samples):
|
207 |
+
|
208 |
+
features = []
|
209 |
+
cls_token = "[CLS]"
|
210 |
+
sep_token = "[SEP]"
|
211 |
+
pad_token = 0
|
212 |
+
special_tokens_count = 2
|
213 |
+
segment_id = 0
|
214 |
+
|
215 |
+
for (ex_index, example) in enumerate(samples):
|
216 |
+
subjects = copy.deepcopy(example['subject'])
|
217 |
+
tokens = copy.deepcopy(example['text_a'])
|
218 |
+
|
219 |
+
span_labels = np.zeros((self.args.max_seq_length,self.args.max_seq_length))
|
220 |
+
span_labels[:] = self.label2id["O"]
|
221 |
+
|
222 |
+
for subject in subjects:
|
223 |
+
label = subject[0]
|
224 |
+
start = subject[1]
|
225 |
+
end = subject[2]
|
226 |
+
if start < self.args.max_seq_length - special_tokens_count and end < self.args.max_seq_length - special_tokens_count:
|
227 |
+
span_labels[start + 1, end + 1] = self.label2id[label]
|
228 |
+
|
229 |
+
if len(tokens) > self.args.max_seq_length - special_tokens_count:
|
230 |
+
tokens = tokens[: (self.args.max_seq_length - special_tokens_count)]
|
231 |
+
|
232 |
+
tokens += [sep_token]
|
233 |
+
span_labels[len(tokens), :] = self.label2id["O"]
|
234 |
+
span_labels[:, len(tokens)] = self.label2id["O"]
|
235 |
+
segment_ids = [segment_id] * len(tokens)
|
236 |
+
|
237 |
+
tokens = [cls_token] + tokens
|
238 |
+
span_labels[0, :] = self.label2id["O"]
|
239 |
+
span_labels[:, 0] = self.label2id["O"]
|
240 |
+
segment_ids = [segment_id] + segment_ids
|
241 |
+
|
242 |
+
input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
|
243 |
+
input_mask = [0] * len(input_ids)
|
244 |
+
span_mask = np.ones(span_labels.shape)
|
245 |
+
input_len = len(input_ids)
|
246 |
+
|
247 |
+
padding_length = self.args.max_seq_length - len(input_ids)
|
248 |
+
|
249 |
+
input_ids += [pad_token] * padding_length
|
250 |
+
input_mask += [0] * padding_length
|
251 |
+
segment_ids += [segment_id] * padding_length
|
252 |
+
span_labels[input_len:, :] = 0
|
253 |
+
span_labels[:, input_len:] = 0
|
254 |
+
span_mask[input_len:, :] = 0
|
255 |
+
span_mask[:, input_len:] = 0
|
256 |
+
span_mask=np.triu(span_mask,0)
|
257 |
+
span_mask=np.tril(span_mask,10)
|
258 |
+
|
259 |
+
assert len(input_ids) == self.args.max_seq_length
|
260 |
+
assert len(input_mask) == self.args.max_seq_length
|
261 |
+
assert len(segment_ids) == self.args.max_seq_length
|
262 |
+
assert len(span_labels) == self.args.max_seq_length
|
263 |
+
assert len(span_labels[0]) == self.args.max_seq_length
|
264 |
+
|
265 |
+
features.append({
|
266 |
+
'input_ids': torch.tensor(np.array(input_ids)),
|
267 |
+
'attention_mask': torch.tensor(np.array(input_mask)),
|
268 |
+
'token_type_ids': torch.tensor(np.array(segment_ids)),
|
269 |
+
'span_labels': torch.tensor(np.array(span_labels)),
|
270 |
+
'span_mask': torch.tensor(np.array(span_mask)),
|
271 |
+
'input_len': torch.tensor(np.array(input_len)),
|
272 |
+
})
|
273 |
+
|
274 |
+
return default_collate(features)
|
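The collators above are plain callables that expect args, tokenizer and label2id to be attached before use. A hypothetical setup for CollatorForCrf; the model name, label set, sample sentence and max_seq_length are placeholders.

# Placeholder wiring for CollatorForCrf.
from argparse import Namespace
from torch.utils.data import DataLoader
from transformers import BertTokenizer
from fengshen.data.sequence_tagging_dataloader.sequence_tagging_collator import CollatorForCrf

collator = CollatorForCrf()
collator.args = Namespace(max_seq_length=64)
collator.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
collator.label2id = {'O': 0, 'B-PER': 1, 'I-PER': 2}

samples = [{'text_a': list('王小明在北京'),
            'labels': ['B-PER', 'I-PER', 'I-PER', 'O', 'O', 'O']}]
batch = collator(samples)   # dict of padded tensors: input_ids, attention_mask, labels, ...
loader = DataLoader(samples, batch_size=1, collate_fn=collator)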
fengshen/data/sequence_tagging_dataloader/sequence_tagging_datasets.py
ADDED
@@ -0,0 +1,116 @@
1 |
+
from torch.utils.data import Dataset
|
2 |
+
from fengshen.metric.utils_ner import get_entities
|
3 |
+
|
4 |
+
import os
|
5 |
+
|
6 |
+
def get_datasets(args):
|
7 |
+
processor = DataProcessor(args.data_dir, args.decode_type)
|
8 |
+
|
9 |
+
train_data = TaskDataset(processor=processor, mode="train")
|
10 |
+
valid_data = TaskDataset(processor=processor, mode="dev")
|
11 |
+
test_data = TaskDataset(processor=processor, mode="dev")
|
12 |
+
|
13 |
+
return {"train":train_data,"validation":valid_data,"test":test_data}
|
14 |
+
|
15 |
+
# def get_labels(decode_type):
|
16 |
+
# with open("/cognitive_comp/lujunyu/data_zh/NER_Aligned/weibo/labels.txt") as f:
|
17 |
+
# label_list = ["[PAD]", "[START]", "[END]"]
|
18 |
+
|
19 |
+
# if decode_type=="crf" or decode_type=="linear":
|
20 |
+
# for line in f.readlines():
|
21 |
+
# label_list.append(line.strip())
|
22 |
+
# elif decode_type=="biaffine" or decode_type=="span":
|
23 |
+
# for line in f.readlines():
|
24 |
+
# tag = line.strip().split("-")
|
25 |
+
# if len(tag) == 1 and tag[0] not in label_list:
|
26 |
+
# label_list.append(tag[0])
|
27 |
+
# elif tag[1] not in label_list:
|
28 |
+
# label_list.append(tag[1])
|
29 |
+
|
30 |
+
# label2id={label:id for id,label in enumerate(label_list)}
|
31 |
+
# id2label={id:label for id,label in enumerate(label_list)}
|
32 |
+
# return label2id, id2label
|
33 |
+
|
34 |
+
class DataProcessor(object):
|
35 |
+
def __init__(self, data_dir, decode_type) -> None:
|
36 |
+
super().__init__()
|
37 |
+
self.data_dir = data_dir
|
38 |
+
self.decode_type = decode_type
|
39 |
+
|
40 |
+
def get_examples(self, mode):
|
41 |
+
return self._create_examples(self._read_text(os.path.join(self.data_dir, mode + ".all.bmes")), mode)
|
42 |
+
|
43 |
+
@staticmethod
|
44 |
+
def get_labels(args):
|
45 |
+
with open(os.path.join(args.data_dir, "labels.txt")) as f:
|
46 |
+
label_list = ["[PAD]", "[START]", "[END]"]
|
47 |
+
|
48 |
+
if args.decode_type=="crf" or args.decode_type=="linear":
|
49 |
+
for line in f.readlines():
|
50 |
+
label_list.append(line.strip())
|
51 |
+
elif args.decode_type=="biaffine" or args.decode_type=="span":
|
52 |
+
for line in f.readlines():
|
53 |
+
tag = line.strip().split("-")
|
54 |
+
if len(tag) == 1 and tag[0] not in label_list:
|
55 |
+
label_list.append(tag[0])
|
56 |
+
elif tag[1] not in label_list:
|
57 |
+
label_list.append(tag[1])
|
58 |
+
|
59 |
+
label2id = {label: i for i, label in enumerate(label_list)}
|
60 |
+
id2label={id:label for id,label in enumerate(label_list)}
|
61 |
+
return label2id,id2label
|
62 |
+
|
63 |
+
def _create_examples(self, lines, set_type):
|
64 |
+
examples = []
|
65 |
+
for (i, line) in enumerate(lines):
|
66 |
+
guid = "%s-%s" % (set_type, i)
|
67 |
+
text_a = line['words']
|
68 |
+
labels = []
|
69 |
+
for x in line['labels']:
|
70 |
+
if 'M-' in x:
|
71 |
+
labels.append(x.replace('M-', 'I-'))
|
72 |
+
else:
|
73 |
+
labels.append(x)
|
74 |
+
subject = get_entities(labels, id2label=None, markup='bioes')
|
75 |
+
examples.append({'guid':guid, 'text_a':text_a, 'labels':labels, 'subject':subject})
|
76 |
+
return examples
|
77 |
+
|
78 |
+
@classmethod
|
79 |
+
def _read_text(self, input_file):
|
80 |
+
lines = []
|
81 |
+
with open(input_file, 'r') as f:
|
82 |
+
words = []
|
83 |
+
labels = []
|
84 |
+
for line in f:
|
85 |
+
if line.startswith("-DOCSTART-") or line == "" or line == "\n":
|
86 |
+
if words:
|
87 |
+
lines.append({"words": words, "labels": labels})
|
88 |
+
words = []
|
89 |
+
labels = []
|
90 |
+
else:
|
91 |
+
splits = line.split()
|
92 |
+
words.append(splits[0])
|
93 |
+
if len(splits) > 1:
|
94 |
+
labels.append(splits[-1].replace("\n", ""))
|
95 |
+
else:
|
96 |
+
# Examples could have no label for mode = "test"
|
97 |
+
labels.append("O")
|
98 |
+
if words:
|
99 |
+
lines.append({"words": words, "labels": labels})
|
100 |
+
return lines
|
101 |
+
|
102 |
+
|
103 |
+
class TaskDataset(Dataset):
|
104 |
+
def __init__(self, processor, mode='train'):
|
105 |
+
super().__init__()
|
106 |
+
self.data = self.load_data(processor, mode)
|
107 |
+
|
108 |
+
def __len__(self):
|
109 |
+
return len(self.data)
|
110 |
+
|
111 |
+
def __getitem__(self, index):
|
112 |
+
return self.data[index]
|
113 |
+
|
114 |
+
def load_data(self, processor, mode):
|
115 |
+
examples = processor.get_examples(mode)
|
116 |
+
return examples
|
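A hypothetical call into the dataset helpers above; the directory is assumed to contain train.all.bmes, dev.all.bmes and a labels.txt, none of which ship with this commit.

# Placeholder usage of DataProcessor / get_datasets.
from argparse import Namespace
from fengshen.data.sequence_tagging_dataloader.sequence_tagging_datasets import (
    DataProcessor, get_datasets)

args = Namespace(data_dir='/path/to/ner_corpus', decode_type='crf')
label2id, id2label = DataProcessor.get_labels(args)
data = get_datasets(args)
print(len(data['train']), data['train'][0]['labels'][:5])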
fengshen/data/t5_dataloader/t5_datasets.py
ADDED
@@ -0,0 +1,562 @@
1 |
+
# coding=utf8
|
2 |
+
import json
|
3 |
+
from torch.utils.data import Dataset, DataLoader
|
4 |
+
from tqdm import tqdm
|
5 |
+
from transformers import BertTokenizer, MT5Config, MT5Tokenizer, BatchEncoding
|
6 |
+
import torch
|
7 |
+
import pytorch_lightning as pl
|
8 |
+
import numpy as np
|
9 |
+
from itertools import chain
|
10 |
+
import sys
|
11 |
+
sys.path.append('../../')
|
12 |
+
|
13 |
+
|
14 |
+
def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length):
|
15 |
+
"""This function is copy of `random_spans_helper <https://github.com/google-research/
|
16 |
+
text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2466>`__ .
|
17 |
+
Training parameters to avoid padding with random_spans_noise_mask.
|
18 |
+
When training a model with random_spans_noise_mask, we would like to set the other
|
19 |
+
training hyperparameters in a way that avoids padding.
|
20 |
+
This function helps us compute these hyperparameters.
|
21 |
+
We assume that each noise span in the input is replaced by extra_tokens_per_span_inputs sentinel tokens,
|
22 |
+
and each non-noise span in the targets is replaced by extra_tokens_per_span_targets sentinel tokens.
|
23 |
+
This function tells us the required number of tokens in the raw example (for split_tokens())
|
24 |
+
as well as the length of the encoded targets. Note that this function assumes
|
25 |
+
the inputs and targets will have EOS appended and includes that in the reported length.
|
26 |
+
Args:
|
27 |
+
inputs_length: an integer - desired length of the tokenized inputs sequence
|
28 |
+
noise_density: a float
|
29 |
+
mean_noise_span_length: a float
|
30 |
+
Returns:
|
31 |
+
tokens_length: length of original text in tokens
|
32 |
+
targets_length: an integer - length in tokens of encoded targets sequence
|
33 |
+
"""
|
34 |
+
|
35 |
+
def _tokens_length_to_inputs_length_targets_length(tokens_length):
|
36 |
+
num_noise_tokens = int(round(tokens_length * noise_density))
|
37 |
+
num_nonnoise_tokens = tokens_length - num_noise_tokens
|
38 |
+
num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
|
39 |
+
# inputs contain all nonnoise tokens, sentinels for all noise spans
|
40 |
+
# and one EOS token.
|
41 |
+
_input_length = num_nonnoise_tokens + num_noise_spans + 1
|
42 |
+
_output_length = num_noise_tokens + num_noise_spans + 1
|
43 |
+
return _input_length, _output_length
|
44 |
+
|
45 |
+
tokens_length = inputs_length
|
46 |
+
|
47 |
+
while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length:
|
48 |
+
tokens_length += 1
|
49 |
+
|
50 |
+
inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(
|
51 |
+
tokens_length)
|
52 |
+
|
53 |
+
# minor hack to get the targets length to be equal to inputs length
|
54 |
+
# which is more likely to have been set to a nice round number.
|
55 |
+
if noise_density == 0.5 and targets_length > inputs_length:
|
56 |
+
tokens_length -= 1
|
57 |
+
targets_length -= 1
|
58 |
+
return tokens_length, targets_length
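A minimal usage sketch of the helper above, assuming this repository checkout is on `PYTHONPATH` so the module imports as `fengshen.data.t5_dataloader.t5_datasets`; the concrete lengths are printed rather than asserted.

```python
# Minimal sketch, assuming the fengshen package is importable from this repo checkout.
from fengshen.data.t5_dataloader.t5_datasets import compute_input_and_target_lengths

# With the defaults used later in this file (noise_density=0.15,
# mean_noise_span_length=3), a 512-token model input requires a longer raw
# chunk, because span corruption shortens the inputs again.
expanded_inputs_length, targets_length = compute_input_and_target_lengths(
    inputs_length=512,
    noise_density=0.15,
    mean_noise_span_length=3,
)
print(expanded_inputs_length, targets_length)
```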
|
59 |
+
|
60 |
+
|
61 |
+
class UnsuperviseT5Dataset(Dataset):
|
62 |
+
'''
|
63 |
+
Dataset used for T5 unsupervised pretraining.
|
64 |
+
load_data_type = 0: load raw data from data path and save tokenized data, call function load_data
|
65 |
+
load_data_type = 1: load tokenized data from path, call function load_tokenized_data
|
66 |
+
load_data_type = 2: load tokenized data from in-memory data, call function load_tokenized_memory_data
|
67 |
+
'''
|
68 |
+
|
69 |
+
def __init__(self, data_path, args, load_data_type=0, data=None):
|
70 |
+
super().__init__()
|
71 |
+
|
72 |
+
if args.tokenizer_type == 't5_tokenizer':
|
73 |
+
if args.new_vocab_path is not None:
|
74 |
+
self.tokenizer = MT5Tokenizer.from_pretrained(args.new_vocab_path)
|
75 |
+
else:
|
76 |
+
self.tokenizer = MT5Tokenizer.from_pretrained(args.pretrained_model_path)
|
77 |
+
else:
|
78 |
+
self.tokenizer = BertTokenizer.from_pretrained(args.pretrained_model_path)
|
79 |
+
self.noise_density = 0.15
|
80 |
+
self.mean_noise_span_length = 3
|
81 |
+
self.text_column_name = args.text_column_name
|
82 |
+
self.dataset_num_workers = args.dataset_num_workers
|
83 |
+
self.max_seq_length = args.max_seq_length
|
84 |
+
self.remove_columns = args.remove_columns
|
85 |
+
# whether to load pre-tokenized data
|
86 |
+
self.load_data_type = load_data_type
|
87 |
+
|
88 |
+
if self.load_data_type == 0:
|
89 |
+
# T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
|
90 |
+
# To ensure that the input length is `max_seq_length`, we need to increase the maximum length
|
91 |
+
# according to `mlm_probability` and `mean_noise_span_length`.
|
92 |
+
# We can also define the label length accordingly.
|
93 |
+
self.expanded_inputs_length, self.targets_length = compute_input_and_target_lengths(
|
94 |
+
inputs_length=self.max_seq_length,
|
95 |
+
noise_density=self.noise_density,
|
96 |
+
mean_noise_span_length=self.mean_noise_span_length,
|
97 |
+
)
|
98 |
+
print('self.expanded_inputs_length, self.targets_length:{},{}'.format(
|
99 |
+
self.expanded_inputs_length, self.targets_length))
|
100 |
+
self.data = self.load_data(data_path)
|
101 |
+
elif self.load_data_type == 1:
|
102 |
+
self.data = self.load_tokenized_data(data_path)
|
103 |
+
else:
|
104 |
+
assert data is not None
|
105 |
+
self.data = self.load_tokenized_memory_data(data)
|
106 |
+
|
107 |
+
def __len__(self):
|
108 |
+
return len(self.data)
|
109 |
+
|
110 |
+
def __getitem__(self, index):
|
111 |
+
return self.data[index]
|
112 |
+
|
113 |
+
def load_data(self, data_path):
|
114 |
+
# TODO: large data process
|
115 |
+
from data.fs_datasets import load_dataset
|
116 |
+
samples = load_dataset(
|
117 |
+
# samples = datasets.load_from_disk(data_path)['train']
|
118 |
+
data_path, num_proc=self.dataset_num_workers)['train']
|
119 |
+
# print(samples)
|
120 |
+
tokenized_datasets = samples.map(
|
121 |
+
self.tokenize_function,
|
122 |
+
batched=True,
|
123 |
+
num_proc=self.dataset_num_workers,
|
124 |
+
# load_from_cache_file=not data_args.overwrite_cache,
|
125 |
+
).map(
|
126 |
+
batched=True,
|
127 |
+
num_proc=self.dataset_num_workers,
|
128 |
+
remove_columns=self.remove_columns)
|
129 |
+
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
|
130 |
+
# remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
|
131 |
+
# might be slower to preprocess.
|
132 |
+
#
|
133 |
+
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
|
134 |
+
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
|
135 |
+
tokenized_datasets = tokenized_datasets.map(
|
136 |
+
self.group_texts,
|
137 |
+
batched=True,
|
138 |
+
num_proc=self.dataset_num_workers,
|
139 |
+
# load_from_cache_file=not data_args.overwrite_cache,
|
140 |
+
)
|
141 |
+
return tokenized_datasets
|
142 |
+
'''
|
143 |
+
The function load tokenized data saved from load_data function.
|
144 |
+
'''
|
145 |
+
|
146 |
+
def load_tokenized_data(self, data_path):
|
147 |
+
from data.fs_datasets import load_dataset
|
148 |
+
samples = load_dataset(data_path)['train']
|
149 |
+
return samples
|
150 |
+
|
151 |
+
def load_tokenized_memory_data(self, data):
|
152 |
+
return data
|
153 |
+
|
154 |
+
# Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
|
155 |
+
# Since we make sure that all sequences are of the same length, no attention_mask is needed.
|
156 |
+
def tokenize_function(self, examples):
|
157 |
+
# add_special_tokens=False here, to avoid EOS tokens appearing in the middle of a sentence
|
158 |
+
return self.tokenizer(examples[self.text_column_name],
|
159 |
+
add_special_tokens=False,
|
160 |
+
return_attention_mask=False)
|
161 |
+
|
162 |
+
# Main data processing function that will concatenate all texts from our dataset
|
163 |
+
# and generate chunks of expanded_inputs_length.
|
164 |
+
def group_texts(self, examples):
|
165 |
+
# Concatenate all texts.
|
166 |
+
concatenated_examples = {
|
167 |
+
k: list(chain(*examples[k])) for k in examples.keys()}
|
168 |
+
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
169 |
+
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
|
170 |
+
# customize this part to your needs.
|
171 |
+
if total_length >= self.expanded_inputs_length:
|
172 |
+
total_length = (
|
173 |
+
total_length // self.expanded_inputs_length) * self.expanded_inputs_length
|
174 |
+
# Split by chunks of max_len.
|
175 |
+
result = {
|
176 |
+
k: [t[i: i + self.expanded_inputs_length]
|
177 |
+
for i in range(0, total_length, self.expanded_inputs_length)]
|
178 |
+
for k, t in concatenated_examples.items()
|
179 |
+
}
|
180 |
+
return result
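The chunking done by `group_texts` can be reproduced in isolation; the toy below uses a hypothetical chunk length of 4 in place of the computed `expanded_inputs_length`.

```python
from itertools import chain

# Toy illustration of group_texts: concatenate, drop the remainder, split into chunks.
examples = {"input_ids": [[1, 2, 3], [4, 5], [6, 7, 8, 9, 10]]}
chunk = 4  # hypothetical stand-in for expanded_inputs_length
flat = {k: list(chain(*v)) for k, v in examples.items()}
total = (len(flat["input_ids"]) // chunk) * chunk
result = {k: [t[i:i + chunk] for i in range(0, total, chunk)] for k, t in flat.items()}
print(result)  # {'input_ids': [[1, 2, 3, 4], [5, 6, 7, 8]]}
```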
|
181 |
+
|
182 |
+
|
183 |
+
class UnsuperviseT5DataModel(pl.LightningDataModule):
|
184 |
+
@staticmethod
|
185 |
+
def add_data_specific_args(parent_args):
|
186 |
+
parser = parent_args.add_argument_group('UnsuperviseT5DataModel')
|
187 |
+
parser.add_argument('--dataset_num_workers', default=8, type=int)
|
188 |
+
parser.add_argument('--dataloader_num_workers', default=4, type=int)
|
189 |
+
parser.add_argument(
|
190 |
+
'--train_data_path', default='wudao_180g_mt5_tokenized', type=str)
|
191 |
+
parser.add_argument('--train_batchsize', default=2, type=int)
|
192 |
+
parser.add_argument('--valid_batchsize', default=2, type=int)
|
193 |
+
parser.add_argument('--train_split_size', default=None, type=float)
|
194 |
+
parser.add_argument('--tokenizer_type', default='t5_tokenizer', choices=['t5_tokenizer', 'bert_tokenizer'])
|
195 |
+
parser.add_argument('--text_column_name', default='text')
|
196 |
+
parser.add_argument('--remove_columns', nargs='+', default=[])
|
197 |
+
return parent_args
|
198 |
+
|
199 |
+
def __init__(self, args):
|
200 |
+
super().__init__()
|
201 |
+
self.save_hyperparameters(args)
|
202 |
+
if args.train_split_size is not None:
|
203 |
+
from data.fs_datasets import load_dataset
|
204 |
+
data_splits = load_dataset(args.train_data_path, num_proc=args.dataset_num_workers)
|
205 |
+
train_split = data_splits['train']
|
206 |
+
test_split = data_splits['test']
|
207 |
+
print('train:', train_split, '\ntest_data:', test_split)
|
208 |
+
self.train_dataset = UnsuperviseT5Dataset('', args, load_data_type=2, data=train_split)
|
209 |
+
self.test_dataset = UnsuperviseT5Dataset('', args, load_data_type=2, data=test_split)
|
210 |
+
else:
|
211 |
+
self.train_data = UnsuperviseT5Dataset(args.train_data_path, args, load_data_type=1)
|
212 |
+
|
213 |
+
self.config = MT5Config.from_pretrained(args.pretrained_model_path)
|
214 |
+
self.noise_density = 0.15
|
215 |
+
self.mean_noise_span_length = 3
|
216 |
+
self.pad_token_id = self.config.pad_token_id
|
217 |
+
self.decoder_start_token_id = self.config.decoder_start_token_id
|
218 |
+
self.eos_token_id = self.config.eos_token_id
|
219 |
+
self.vocab_size = self.config.vocab_size
|
220 |
+
self.max_seq_length = args.max_seq_length
|
221 |
+
# The old sentencepiece model being loaded already contains the extra_ids, but T5Tokenizer adds another 100 extra_ids on top of the spm, so extra_ids=0 must be specified
|
222 |
+
if args.tokenizer_type == 't5_tokenizer' and args.new_vocab_path is not None:
|
223 |
+
self.tokenizer = MT5Tokenizer.from_pretrained(args.new_vocab_path, extra_ids=0)
|
224 |
+
# When loading mt5 for the first time, vocab_size needs to be updated to the new_vocab_size obtained after extracting the Chinese and English tokens
|
225 |
+
self.vocab_size = len(self.tokenizer)
|
226 |
+
|
227 |
+
# T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
|
228 |
+
# To ensure that the input length is `max_seq_length`, we need to increase the maximum length
|
229 |
+
# according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly.
|
230 |
+
self.expanded_inputs_length, self.targets_length = compute_input_and_target_lengths(
|
231 |
+
inputs_length=self.max_seq_length,
|
232 |
+
noise_density=self.noise_density,
|
233 |
+
mean_noise_span_length=self.mean_noise_span_length,
|
234 |
+
)
|
235 |
+
|
236 |
+
def train_dataloader(self):
|
237 |
+
from fengshen.data.universal_datamodule.universal_sampler import PretrainingSampler
|
238 |
+
from fengshen.data.universal_datamodule.universal_datamodule import get_consume_samples
|
239 |
+
# Use a custom sampler so that resumed training picks up the data at the correct position
|
240 |
+
consumed_samples = get_consume_samples(self)
|
241 |
+
batch_sampler = PretrainingSampler(
|
242 |
+
total_samples=len(self.train_dataset),
|
243 |
+
consumed_samples=consumed_samples,
|
244 |
+
micro_batch_size=self.hparams.train_batchsize,
|
245 |
+
data_parallel_rank=self.trainer.global_rank,
|
246 |
+
data_parallel_size=self.trainer.world_size,
|
247 |
+
)
|
248 |
+
return DataLoader(
|
249 |
+
self.train_dataset,
|
250 |
+
batch_sampler=batch_sampler,
|
251 |
+
pin_memory=True,
|
252 |
+
num_workers=self.hparams.dataloader_num_workers,
|
253 |
+
collate_fn=self.collate_fn,
|
254 |
+
)
|
255 |
+
|
256 |
+
def val_dataloader(self):
|
257 |
+
sampler = torch.utils.data.distributed.DistributedSampler(
|
258 |
+
self.test_dataset, shuffle=False)
|
259 |
+
return DataLoader(
|
260 |
+
self.test_dataset,
|
261 |
+
sampler=sampler,
|
262 |
+
shuffle=False,
|
263 |
+
batch_size=self.hparams.valid_batchsize,
|
264 |
+
pin_memory=True,
|
265 |
+
num_workers=self.hparams.dataloader_num_workers,
|
266 |
+
collate_fn=self.collate_fn,
|
267 |
+
)
|
268 |
+
|
269 |
+
def predict_dataloader(self):
|
270 |
+
sampler = torch.utils.data.distributed.DistributedSampler(
|
271 |
+
self.test_dataset, shuffle=False)
|
272 |
+
return DataLoader(
|
273 |
+
self.test_dataset,
|
274 |
+
sampler=sampler,
|
275 |
+
shuffle=False,
|
276 |
+
batch_size=self.hparams.valid_batchsize,
|
277 |
+
pin_memory=True,
|
278 |
+
num_workers=self.hparams.dataloader_num_workers,
|
279 |
+
collate_fn=self.collate_fn,
|
280 |
+
)
|
281 |
+
|
282 |
+
def collate_fn(self, examples):
|
283 |
+
# convert list to dict and tensorize input
|
284 |
+
batch = BatchEncoding(
|
285 |
+
{k: np.array([examples[i][k] for i in range(len(examples))])
|
286 |
+
for k, v in examples[0].items()}
|
287 |
+
)
|
288 |
+
|
289 |
+
input_ids = np.array(batch['input_ids'])
|
290 |
+
batch_size, expanded_input_length = input_ids.shape
|
291 |
+
mask_indices = np.asarray([self.random_spans_noise_mask(
|
292 |
+
expanded_input_length) for i in range(batch_size)])
|
293 |
+
labels_mask = ~mask_indices
|
294 |
+
|
295 |
+
input_ids_sentinel = self.create_sentinel_ids(
|
296 |
+
mask_indices.astype(np.int8))
|
297 |
+
labels_sentinel = self.create_sentinel_ids(labels_mask.astype(np.int8))
|
298 |
+
|
299 |
+
batch["input_ids"] = self.filter_input_ids(
|
300 |
+
input_ids, input_ids_sentinel)
|
301 |
+
batch["labels"] = self.filter_input_ids(input_ids, labels_sentinel)
|
302 |
+
|
303 |
+
if batch["input_ids"].shape[-1] != self.max_seq_length:
|
304 |
+
raise ValueError(
|
305 |
+
f"`input_ids` are incorrectly preprocessed. `input_ids` length is \
|
306 |
+
{batch['input_ids'].shape[-1]}, but should be {self.max_seq_length}."
|
307 |
+
)
|
308 |
+
|
309 |
+
if batch["labels"].shape[-1] != self.targets_length:
|
310 |
+
raise ValueError(
|
311 |
+
f"`labels` are incorrectly preprocessed. `labels` length is \
|
312 |
+
{batch['labels'].shape[-1]}, but should be {self.targets_length}."
|
313 |
+
)
|
314 |
+
|
315 |
+
batch["decoder_input_ids"] = self.shift_tokens_right(
|
316 |
+
batch["labels"], self.pad_token_id, self.decoder_start_token_id
|
317 |
+
)
|
318 |
+
|
319 |
+
for k, v in batch.items():
|
320 |
+
batch[k] = torch.tensor(v)
|
321 |
+
# print(k, batch[k], self.tokenizer.batch_decode(batch[k]), '\n', flush=True)
|
322 |
+
return batch
|
323 |
+
|
324 |
+
def create_sentinel_ids(self, mask_indices):
|
325 |
+
"""
|
326 |
+
Sentinel ids creation given the indices that should be masked.
|
327 |
+
The start indices of each mask are replaced by the sentinel ids in increasing
|
328 |
+
order. Consecutive mask indices to be deleted are replaced with `-1`.
|
329 |
+
"""
|
330 |
+
start_indices = mask_indices - \
|
331 |
+
np.roll(mask_indices, 1, axis=-1) * mask_indices
|
332 |
+
start_indices[:, 0] = mask_indices[:, 0]
|
333 |
+
|
334 |
+
sentinel_ids = np.where(start_indices != 0, np.cumsum(
|
335 |
+
start_indices, axis=-1), start_indices)
|
336 |
+
sentinel_ids = np.where(
|
337 |
+
sentinel_ids != 0, (self.vocab_size - sentinel_ids), 0)
|
338 |
+
sentinel_ids -= mask_indices - start_indices
|
339 |
+
|
340 |
+
return sentinel_ids
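To see what `create_sentinel_ids` produces, the standalone numpy sketch below replays the same arithmetic on a tiny hand-written mask; the vocab size is a hypothetical stand-in.

```python
import numpy as np

vocab_size = 32100                                 # hypothetical vocab size
mask = np.array([[0, 1, 1, 0, 1]], dtype=np.int8)  # 1 = token belongs to a noise span

start = mask - np.roll(mask, 1, axis=-1) * mask
start[:, 0] = mask[:, 0]
sentinel = np.where(start != 0, np.cumsum(start, axis=-1), start)
sentinel = np.where(sentinel != 0, vocab_size - sentinel, 0)
sentinel -= mask - start
print(sentinel)  # [[0 32099 -1 0 32098]]: sentinel id at each span start, -1 marks tokens to drop
```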
|
341 |
+
|
342 |
+
def filter_input_ids(self, input_ids, sentinel_ids):
|
343 |
+
"""
|
344 |
+
Puts sentinel mask on `input_ids` and fuse consecutive mask tokens into a single mask token by deleting.
|
345 |
+
This will reduce the sequence length from `expanded_inputs_length` to `input_length`.
|
346 |
+
"""
|
347 |
+
batch_size = input_ids.shape[0]
|
348 |
+
|
349 |
+
input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)
|
350 |
+
# input_ids tokens and sentinel tokens are >= 0, tokens < 0 are
|
351 |
+
# masked tokens coming after sentinel tokens and should be removed
|
352 |
+
input_ids = input_ids_full[input_ids_full >=
|
353 |
+
0].reshape((batch_size, -1))
|
354 |
+
input_ids = np.concatenate(
|
355 |
+
[input_ids, np.full((batch_size, 1), self.eos_token_id, dtype=np.int32)], axis=-1
|
356 |
+
)
|
357 |
+
return input_ids
|
358 |
+
|
359 |
+
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
|
360 |
+
def shift_tokens_right(self, input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
|
361 |
+
"""
|
362 |
+
Shift input ids one token to the right.
|
363 |
+
"""
|
364 |
+
shifted_input_ids = np.zeros_like(input_ids)
|
365 |
+
shifted_input_ids[:, 1:] = input_ids[:, :-1]
|
366 |
+
shifted_input_ids[:, 0] = decoder_start_token_id
|
367 |
+
|
368 |
+
shifted_input_ids = np.where(
|
369 |
+
shifted_input_ids == -100, pad_token_id, shifted_input_ids)
|
370 |
+
return shifted_input_ids
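A tiny standalone example of the shift above, using hypothetical ids (`pad_token_id=0`, `decoder_start_token_id=0`, `-100` as the ignored-label marker):

```python
import numpy as np

labels = np.array([[13, 7, 25, 1, -100, -100]])  # hypothetical label ids
shifted = np.zeros_like(labels)
shifted[:, 1:] = labels[:, :-1]
shifted[:, 0] = 0                                # decoder_start_token_id (assumed 0)
shifted = np.where(shifted == -100, 0, shifted)  # replace ignored labels with pad_token_id
print(shifted)  # [[ 0 13  7 25  1  0]]
```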
|
371 |
+
|
372 |
+
def random_spans_noise_mask(self, length):
|
373 |
+
"""This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/
|
374 |
+
blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
|
375 |
+
Noise mask consisting of random spans of noise tokens.
|
376 |
+
The number of noise tokens and the number of noise spans and non-noise spans
|
377 |
+
are determined deterministically as follows:
|
378 |
+
num_noise_tokens = round(length * noise_density)
|
379 |
+
num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
|
380 |
+
Spans alternate between non-noise and noise, beginning with non-noise.
|
381 |
+
Subject to the above restrictions, all masks are equally likely.
|
382 |
+
Args:
|
383 |
+
length: an int32 scalar (length of the incoming token sequence)
|
384 |
+
noise_density: a float - approximate density of output mask
|
385 |
+
mean_noise_span_length: a number
|
386 |
+
Returns:
|
387 |
+
a boolean tensor with shape [length]
|
388 |
+
"""
|
389 |
+
|
390 |
+
orig_length = length
|
391 |
+
|
392 |
+
num_noise_tokens = int(np.round(length * self.noise_density))
|
393 |
+
# avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
|
394 |
+
num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
|
395 |
+
num_noise_spans = int(
|
396 |
+
np.round(num_noise_tokens / self.mean_noise_span_length))
|
397 |
+
|
398 |
+
# avoid degeneracy by ensuring positive number of noise spans
|
399 |
+
num_noise_spans = max(num_noise_spans, 1)
|
400 |
+
num_nonnoise_tokens = length - num_noise_tokens
|
401 |
+
|
402 |
+
# pick the lengths of the noise spans and the non-noise spans
|
403 |
+
def _random_segmentation(num_items, num_segments):
|
404 |
+
"""Partition a sequence of items randomly into non-empty segments.
|
405 |
+
Args:
|
406 |
+
num_items: an integer scalar > 0
|
407 |
+
num_segments: an integer scalar in [1, num_items]
|
408 |
+
Returns:
|
409 |
+
a Tensor with shape [num_segments] containing positive integers that add
|
410 |
+
up to num_items
|
411 |
+
"""
|
412 |
+
mask_indices = np.arange(num_items - 1) < (num_segments - 1)
|
413 |
+
np.random.shuffle(mask_indices)
|
414 |
+
first_in_segment = np.pad(mask_indices, [[1, 0]])
|
415 |
+
segment_id = np.cumsum(first_in_segment)
|
416 |
+
# count length of sub segments assuming that list is sorted
|
417 |
+
_, segment_length = np.unique(segment_id, return_counts=True)
|
418 |
+
return segment_length
|
419 |
+
|
420 |
+
noise_span_lengths = _random_segmentation(
|
421 |
+
num_noise_tokens, num_noise_spans)
|
422 |
+
nonnoise_span_lengths = _random_segmentation(
|
423 |
+
num_nonnoise_tokens, num_noise_spans)
|
424 |
+
|
425 |
+
interleaved_span_lengths = np.reshape(
|
426 |
+
np.stack([nonnoise_span_lengths, noise_span_lengths],
|
427 |
+
axis=1), [num_noise_spans * 2]
|
428 |
+
)
|
429 |
+
span_starts = np.cumsum(interleaved_span_lengths)[:-1]
|
430 |
+
span_start_indicator = np.zeros((length,), dtype=np.int8)
|
431 |
+
span_start_indicator[span_starts] = True
|
432 |
+
span_num = np.cumsum(span_start_indicator)
|
433 |
+
is_noise = np.equal(span_num % 2, 1)
|
434 |
+
|
435 |
+
return is_noise[:orig_length]
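For inspection outside the DataModule, the same span-corruption masking idea can be replayed with plain numpy; this is a toy re-implementation under the same defaults, not the method itself.

```python
import numpy as np

def toy_spans_noise_mask(length, noise_density=0.15, mean_noise_span_length=3):
    # Toy sketch of the masking scheme above, for inspection only.
    num_noise = min(max(int(round(length * noise_density)), 1), length - 1)
    num_spans = max(int(round(num_noise / mean_noise_span_length)), 1)

    def random_segmentation(num_items, num_segments):
        # Partition num_items into num_segments non-empty segments at random.
        cuts = np.arange(num_items - 1) < (num_segments - 1)
        np.random.shuffle(cuts)
        first_in_segment = np.pad(cuts, [[1, 0]])
        return np.unique(np.cumsum(first_in_segment), return_counts=True)[1]

    noise_lens = random_segmentation(num_noise, num_spans)
    keep_lens = random_segmentation(length - num_noise, num_spans)
    interleaved = np.stack([keep_lens, noise_lens], axis=1).reshape(-1)
    starts = np.cumsum(interleaved)[:-1]
    indicator = np.zeros(length, dtype=np.int8)
    indicator[starts] = 1
    return np.cumsum(indicator) % 2 == 1  # spans alternate, starting with non-noise

print(toy_spans_noise_mask(32).astype(int))
```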
|
436 |
+
|
437 |
+
|
438 |
+
class TaskT5Dataset(Dataset):
|
439 |
+
def __init__(self, data_path, args):
|
440 |
+
super().__init__()
|
441 |
+
self.max_length = args.max_seq_length
|
442 |
+
if args.tokenizer_type == 't5_tokenizer':
|
443 |
+
self.tokenizer = MT5Tokenizer.from_pretrained(args.pretrained_model_path)
|
444 |
+
else:
|
445 |
+
self.tokenizer = BertTokenizer.from_pretrained(args.pretrained_model_path)
|
446 |
+
self.data = self.load_data(data_path)
|
447 |
+
|
448 |
+
def __len__(self):
|
449 |
+
return len(self.data)
|
450 |
+
|
451 |
+
def __getitem__(self, index):
|
452 |
+
return self.encode(self.data[index])
|
453 |
+
|
454 |
+
def load_data(self, data_path):
|
455 |
+
samples = []
|
456 |
+
with open(data_path, 'r', encoding='utf8') as f:
|
457 |
+
lines = f.readlines()
|
458 |
+
for line in tqdm(lines):
|
459 |
+
samples.append(json.loads(line))
|
460 |
+
return samples
|
461 |
+
|
462 |
+
def encode(self, item):
|
463 |
+
if item["textb"] != "":
|
464 |
+
text = item['question'] + ','.join(item['choice'])+'。' + f"""{item["texta"]}""" + f"""{item["textb"]}"""
|
465 |
+
else:
|
466 |
+
text = f"""{item["question"]}""" + ",".join(item["choice"]) + "。" + f"""{item["texta"]}"""
|
467 |
+
label = item['answer']
|
468 |
+
encode_dict = self.tokenizer.encode_plus(text, max_length=self.max_length, padding='max_length',
|
469 |
+
truncation=True, return_tensors='pt')
|
470 |
+
decode_dict = self.tokenizer.encode_plus(label, max_length=16, padding='max_length',
|
471 |
+
truncation=True)
|
472 |
+
|
473 |
+
answer_token = []
|
474 |
+
max_label_len = 0
|
475 |
+
choice_encode = []  # used to determine the maximum length the model may generate
|
476 |
+
for a in item['choice']:
|
477 |
+
answer_encode = self.tokenizer.encode(a)
|
478 |
+
choice_encode.append(answer_encode)
|
479 |
+
if len(answer_encode) > max_label_len:
|
480 |
+
max_label_len = len(answer_encode)
|
481 |
+
for an in answer_encode:
|
482 |
+
if an not in answer_token:
|
483 |
+
answer_token.append(an)
|
484 |
+
|
485 |
+
# bad_words_ids = [[i] for i in range(self.tokenizer.vocab_size) if i not in answer_token]  # do not generate these tokens
|
486 |
+
|
487 |
+
# while len(bad_words_ids)<self.tokenizer.vocab_size:
|
488 |
+
# bad_words_ids.append(bad_words_ids[0])
|
489 |
+
|
490 |
+
# bad_words_ids = [[423],[67],[878]]
|
491 |
+
|
492 |
+
encode_sent = encode_dict['input_ids'].squeeze()
|
493 |
+
attention_mask = encode_dict['attention_mask'].squeeze()
|
494 |
+
target = decode_dict['input_ids']
|
495 |
+
labels = torch.tensor(target)
|
496 |
+
labels[labels == self.tokenizer.pad_token_id] = -100
|
497 |
+
|
498 |
+
return {
|
499 |
+
"input_ids": torch.tensor(encode_sent).long(),
|
500 |
+
"attention_mask": torch.tensor(attention_mask).float(),
|
501 |
+
"labels": torch.tensor(target).long(),
|
502 |
+
"force_words_ids": answer_token,
|
503 |
+
}
|
504 |
+
|
505 |
+
|
506 |
+
class TaskT5DataModel(pl.LightningDataModule):
|
507 |
+
@staticmethod
|
508 |
+
def add_data_specific_args(parent_args):
|
509 |
+
parser = parent_args.add_argument_group('TaskT5DataModel')
|
510 |
+
parser.add_argument('--dataset_num_workers', default=8, type=int)
|
511 |
+
parser.add_argument('--dataloader_num_workers', default=4, type=int)
|
512 |
+
parser.add_argument(
|
513 |
+
'--train_data_path', default='wudao_180g_mt5_tokenized', type=str)
|
514 |
+
parser.add_argument(
|
515 |
+
'--valid_data_path', default='wudao_180g_mt5_tokenized', type=str)
|
516 |
+
parser.add_argument('--train_batchsize', default=2, type=int)
|
517 |
+
parser.add_argument('--valid_batchsize', default=2, type=int)
|
518 |
+
parser.add_argument('--train_split_size', default=None, type=float)
|
519 |
+
parser.add_argument('--tokenizer_type', default='t5_tokenizer', choices=['t5_tokenizer', 'bert_tokenizer'])
|
520 |
+
parser.add_argument('--text_column_name', default='text')
|
521 |
+
parser.add_argument('--remove_columns', nargs='+', default=[])
|
522 |
+
return parent_args
|
523 |
+
|
524 |
+
def __init__(self, args):
|
525 |
+
super().__init__()
|
526 |
+
self.save_hyperparameters(args)
|
527 |
+
self.train_dataset = TaskT5Dataset(args.train_data_path, args)
|
528 |
+
self.valid_dataset = TaskT5Dataset(args.valid_data_path, args)
|
529 |
+
|
530 |
+
def train_dataloader(self):
|
531 |
+
from fengshen.data.universal_datamodule.universal_sampler import PretrainingSampler
|
532 |
+
from fengshen.data.universal_datamodule.universal_datamodule import get_consume_samples
|
533 |
+
# Use a custom sampler so that resumed training picks up the data at the correct position
|
534 |
+
consumed_samples = get_consume_samples(self)
|
535 |
+
# batch_sampler = PretrainingRandomSampler(
|
536 |
+
batch_sampler = PretrainingSampler(
|
537 |
+
total_samples=len(self.train_dataset),
|
538 |
+
consumed_samples=consumed_samples,
|
539 |
+
micro_batch_size=self.hparams.train_batchsize,
|
540 |
+
data_parallel_rank=self.trainer.global_rank,
|
541 |
+
data_parallel_size=self.trainer.world_size,
|
542 |
+
)
|
543 |
+
# epoch=self.trainer.current_epoch
|
544 |
+
# )
|
545 |
+
return DataLoader(
|
546 |
+
self.train_dataset,
|
547 |
+
batch_sampler=batch_sampler,
|
548 |
+
pin_memory=True,
|
549 |
+
num_workers=self.hparams.dataloader_num_workers
|
550 |
+
)
|
551 |
+
|
552 |
+
def val_dataloader(self):
|
553 |
+
sampler = torch.utils.data.distributed.DistributedSampler(
|
554 |
+
self.valid_dataset, shuffle=False)
|
555 |
+
return DataLoader(
|
556 |
+
self.valid_dataset,
|
557 |
+
sampler=sampler,
|
558 |
+
shuffle=False,
|
559 |
+
batch_size=self.hparams.valid_batchsize,
|
560 |
+
pin_memory=True,
|
561 |
+
num_workers=self.hparams.dataloader_num_workers
|
562 |
+
)
|
fengshen/data/t5_dataloader/t5_gen_datasets.py
ADDED
@@ -0,0 +1,391 @@
1 |
+
# -*- encoding: utf-8 -*-
|
2 |
+
'''
|
3 |
+
@File : t5_gen_datasets.py
|
4 |
+
@Time : 2022/10/24 19:29
|
5 |
+
@Author : He Junqing
|
6 |
+
@Version : 1.0
|
7 |
+
@Contact : hejunqing@idea.edu.cn
|
8 |
+
@License : (C)Copyright 2022-2023, CCNL-IDEA
|
9 |
+
'''
|
10 |
+
|
11 |
+
|
12 |
+
from transformers import (
|
13 |
+
BertTokenizer,
|
14 |
+
MT5Config,
|
15 |
+
MT5Tokenizer,
|
16 |
+
MT5ForConditionalGeneration,
|
17 |
+
)
|
18 |
+
import torch
|
19 |
+
from torch.utils.data import Dataset, DataLoader
|
20 |
+
from torch.nn.utils.rnn import pad_sequence
|
21 |
+
import pytorch_lightning as pl
|
22 |
+
import numpy as np
|
23 |
+
import sys
|
24 |
+
|
25 |
+
sys.path.append("../../")
|
26 |
+
|
27 |
+
special_token_dict = {
|
28 |
+
"additional_special_tokens": [
|
29 |
+
"[CTSTART]",
|
30 |
+
"[CTEND]",
|
31 |
+
"[SEP]",
|
32 |
+
"[KNSTART]",
|
33 |
+
"[KNEND]",
|
34 |
+
]
|
35 |
+
}
|
36 |
+
|
37 |
+
|
38 |
+
class DialogDataset(Dataset):
|
39 |
+
def __init__(self, data_path, args, data, load_data_type=1) -> None:
|
40 |
+
super().__init__()
|
41 |
+
|
42 |
+
if args.tokenizer_type == "t5_tokenizer":
|
43 |
+
self.tokenizer = MT5Tokenizer.from_pretrained(
|
44 |
+
args.pretrained_model_path)
|
45 |
+
if len(self.tokenizer) == 32596:
|
46 |
+
self.tokenizer.add_special_tokens(special_token_dict)
|
47 |
+
print(
|
48 |
+
"add special tokens to tokenizer,vocab size:",
|
49 |
+
len(self.tokenizer)
|
50 |
+
)
|
51 |
+
self.model = MT5ForConditionalGeneration.from_pretrained(
|
52 |
+
args.pretrained_model_path
|
53 |
+
)
|
54 |
+
self.model.resize_token_embeddings(len(self.tokenizer))
|
55 |
+
self.model.save_pretrained(args.new_vocab_path)
|
56 |
+
self.tokenizer.save_pretrained(
|
57 |
+
args.new_vocab_path)
|
58 |
+
else:
|
59 |
+
self.tokenizer = BertTokenizer.from_pretrained(
|
60 |
+
args.pretrained_model_path)
|
61 |
+
|
62 |
+
self.load_data_type = load_data_type
|
63 |
+
self.data_split = data
|
64 |
+
self.num_workers = args.preprocessing_num_workers
|
65 |
+
self.max_seq_length = args.max_seq_length
|
66 |
+
self.max_knowledge_length = args.max_knowledge_length
|
67 |
+
self.max_target_length = args.max_target_length
|
68 |
+
|
69 |
+
# tokenizer config
|
70 |
+
self.config = MT5Config.from_pretrained(args.pretrained_model_path)
|
71 |
+
self.decoder_start_token_id = self.config.decoder_start_token_id
|
72 |
+
self.eos_token_id = self.config.eos_token_id
|
73 |
+
self.vocab_size = self.config.vocab_size
|
74 |
+
# print(self.tokenizer.decode([2]))
|
75 |
+
|
76 |
+
# load from raw data or hf dataset
|
77 |
+
|
78 |
+
if self.load_data_type == 0:
|
79 |
+
self.data = self.load_data(data_path)
|
80 |
+
elif self.load_data_type == 1:
|
81 |
+
self.data = self.load_packed_data(data_path)
|
82 |
+
else: # for testing
|
83 |
+
self.data = data_path
|
84 |
+
|
85 |
+
def load_packed_data(self, data_path):
|
86 |
+
from fengshen.data.fs_datasets import load_dataset
|
87 |
+
|
88 |
+
samples = load_dataset(data_path,
|
89 |
+
num_proc=self.num_workers)[self.data_split]
|
90 |
+
tokenized_samples = samples.map(
|
91 |
+
self.regular_tokenize, batched=False,
|
92 |
+
num_proc=self.num_workers
|
93 |
+
)
|
94 |
+
|
95 |
+
return tokenized_samples
|
96 |
+
|
97 |
+
def load_data(self, data_path):
|
98 |
+
"""
|
99 |
+
load data from raw data
|
100 |
+
return untokenized data
|
101 |
+
"""
|
102 |
+
from datasets import load_dataset
|
103 |
+
|
104 |
+
ds = load_dataset("json", data_files=data_path)['train']
|
105 |
+
samples = ds.map(self.regular_tokenize, batched=False, num_proc=self.num_workers
|
106 |
+
)
|
107 |
+
return samples
|
108 |
+
|
109 |
+
def __getitem__(self, index):
|
110 |
+
return self.data[index]
|
111 |
+
|
112 |
+
def __len__(self):
|
113 |
+
return len(self.data)
|
114 |
+
|
115 |
+
def regular_tokenize(self, sample):
|
116 |
+
# print(len(sample['context']))
|
117 |
+
context_ids = self.tokenizer(
|
118 |
+
sample["context"],
|
119 |
+
add_special_tokens=True,
|
120 |
+
return_attention_mask=False,
|
121 |
+
return_token_type_ids=True,
|
122 |
+
)
|
123 |
+
|
124 |
+
context_types = self.get_token_type(
|
125 |
+
sample["context"], context_ids["token_type_ids"]
|
126 |
+
)
|
127 |
+
# print('context',sample['context'])
|
128 |
+
# print('context_ids',context_ids['input_ids'])
|
129 |
+
knowledge_ids = self.tokenizer.encode(
|
130 |
+
sample["knowledge"], add_special_tokens=False
|
131 |
+
)
|
132 |
+
# print('knowledge_ids',knowledge_ids)
|
133 |
+
if isinstance(knowledge_ids, int):
|
134 |
+
knowledge_ids = [knowledge_ids]
|
135 |
+
target_ids = self.tokenizer.encode(
|
136 |
+
sample["target"],
|
137 |
+
add_special_tokens=False,
|
138 |
+
max_length=self.max_target_length - 1,
|
139 |
+
truncation=True,
|
140 |
+
)
|
141 |
+
# print('target',sample['target'])
|
142 |
+
# print('target_ids',target_ids)
|
143 |
+
# print('decode target',self.tokenizer.decode(target_ids))
|
144 |
+
# truncate
|
145 |
+
|
146 |
+
knowledge_ids = (
|
147 |
+
[self.tokenizer.convert_tokens_to_ids("[KNSTART]")]
|
148 |
+
+ knowledge_ids[: self.max_knowledge_length - 2]
|
149 |
+
+ [self.tokenizer.convert_tokens_to_ids("[KNEND]")]
|
150 |
+
)
|
151 |
+
l_kn = len(knowledge_ids)
|
152 |
+
knowledge_types = [2] * l_kn
|
153 |
+
|
154 |
+
flatten_context = []
|
155 |
+
for line in context_ids["input_ids"]:
|
156 |
+
flatten_context.extend(line)
|
157 |
+
l_ct = min(len(flatten_context), self.max_seq_length - l_kn - 2)
|
158 |
+
context_ids = (
|
159 |
+
[self.tokenizer.convert_tokens_to_ids("[CTSTART]")]
|
160 |
+
+ flatten_context[-l_ct:]
|
161 |
+
+ [self.tokenizer.convert_tokens_to_ids("[CTEND]")]
|
162 |
+
)
|
163 |
+
|
164 |
+
context_types = context_types[-l_ct:] + [0]
|
165 |
+
context_types.insert(0, context_types[0])
|
166 |
+
assert len(context_ids) == len(
|
167 |
+
context_types
|
168 |
+
), "len of context ids and token types unmatch, context:{},ids:{} types:{},len {}:{}".format(
|
169 |
+
sample["context"],
|
170 |
+
context_ids,
|
171 |
+
context_types,
|
172 |
+
len(context_ids),
|
173 |
+
len(context_types),
|
174 |
+
)
|
175 |
+
|
176 |
+
try:
|
177 |
+
target_ids = target_ids + [self.eos_token_id]
|
178 |
+
except Exception:
|
179 |
+
print(sample["target"], target_ids, self.eos_token_id)
|
180 |
+
|
181 |
+
tokenized = {}
|
182 |
+
tokenized["input_ids"] = np.array(context_ids + knowledge_ids, dtype=np.int32)
|
183 |
+
tokenized["token_types"] = np.array(
|
184 |
+
context_types + knowledge_types, dtype=np.int32
|
185 |
+
)
|
186 |
+
tokenized["attention_mask"] = np.ones(
|
187 |
+
len(context_types + knowledge_types), dtype=np.int8
|
188 |
+
)
|
189 |
+
tokenized["labels"] = np.array(target_ids, dtype=np.int32)
|
190 |
+
|
191 |
+
return tokenized
|
192 |
+
|
193 |
+
def get_token_type(self, context, tokentypes=None):
|
194 |
+
# the token_type_ids returned by the tokenizer are all zero, so rebuild speaker-turn types here
|
195 |
+
context_token_types = []
|
196 |
+
for i, line in enumerate(context):
|
197 |
+
if tokentypes:
|
198 |
+
if i % 2 == 0:
|
199 |
+
token_type = [0] * len(tokentypes[i])
|
200 |
+
else:
|
201 |
+
token_type = [1] * len(tokentypes[i])
|
202 |
+
else:
|
203 |
+
if i % 2 == 0:
|
204 |
+
token_type = [0] * (1 + len(line))
|
205 |
+
else:
|
206 |
+
token_type = [1] * (1 + len(line))
|
207 |
+
|
208 |
+
context_token_types.extend(token_type)
|
209 |
+
|
210 |
+
return context_token_types
|
211 |
+
|
212 |
+
|
213 |
+
class DialogDataModel(pl.LightningDataModule):
|
214 |
+
@staticmethod
|
215 |
+
def add_data_specific_args(parent_args):
|
216 |
+
parser = parent_args.add_argument_group("SuperviseT5DataModel")
|
217 |
+
parser.add_argument("--dataset_num_workers", default=8, type=int)
|
218 |
+
parser.add_argument("--dataloader_num_workers", default=4, type=int)
|
219 |
+
parser.add_argument("--train_data_path", default="dialog_4g_test", type=str)
|
220 |
+
parser.add_argument(
|
221 |
+
"--valid_data_path", default="wudao_180g_mt5_tokenized", type=str
|
222 |
+
)
|
223 |
+
parser.add_argument("--train_batchsize", default=2, type=int)
|
224 |
+
parser.add_argument("--valid_batchsize", default=2, type=int)
|
225 |
+
parser.add_argument("--max_seq_length", default=512, type=int)
|
226 |
+
parser.add_argument("--max_knowledge_length", default=128, type=int)
|
227 |
+
parser.add_argument("--max_target_length", default=128, type=int)
|
228 |
+
|
229 |
+
return parent_args
|
230 |
+
|
231 |
+
def __init__(self, args):
|
232 |
+
super().__init__()
|
233 |
+
self.save_hyperparameters(args)
|
234 |
+
self.load_data(args)
|
235 |
+
self.epochs = args.max_epochs
|
236 |
+
|
237 |
+
def load_data(self, args):
|
238 |
+
if args.train_split_size is not None:
|
239 |
+
from fengshen.data.fs_datasets import load_dataset
|
240 |
+
|
241 |
+
data_splits = load_dataset(
|
242 |
+
args.train_data_path, num_proc=args.dataset_num_workers
|
243 |
+
)
|
244 |
+
train_split = data_splits['train']
|
245 |
+
test_split = data_splits['test']
|
246 |
+
print('train:', train_split, '\ntest_data:', test_split)
|
247 |
+
self.train_dataset = DialogDataset(
|
248 |
+
args.train_data_path, args, load_data_type=1, data="train"
|
249 |
+
)
|
250 |
+
self.test_dataset = DialogDataset(
|
251 |
+
args.train_data_path, args, load_data_type=1, data="test"
|
252 |
+
)
|
253 |
+
else:
|
254 |
+
self.train_data = DialogDataset(
|
255 |
+
args.train_data_path, args, load_data_type=1
|
256 |
+
)
|
257 |
+
|
258 |
+
self.config = MT5Config.from_pretrained(args.pretrained_model_path)
|
259 |
+
self.pad_token_id = self.config.pad_token_id
|
260 |
+
self.decoder_start_token_id = self.config.decoder_start_token_id
|
261 |
+
print("bos id:", self.decoder_start_token_id)
|
262 |
+
|
263 |
+
def collate_fn(self, samples):
|
264 |
+
batch = {
|
265 |
+
k: [
|
266 |
+
torch.tensor(samples[i][k], dtype=torch.int64)
|
267 |
+
for i in range(len(samples))
|
268 |
+
]
|
269 |
+
for k in ["input_ids", "token_types", "attention_mask", "labels"]
|
270 |
+
}
|
271 |
+
|
272 |
+
# print(batch)
|
273 |
+
for k, v in batch.items():
|
274 |
+
if k != "labels":
|
275 |
+
batch[k] = pad_sequence(
|
276 |
+
v, batch_first=True, padding_value=self.pad_token_id
|
277 |
+
)
|
278 |
+
else:
|
279 |
+
batch[k] = pad_sequence(v, batch_first=True, padding_value=-100)
|
280 |
+
batch["decoder_input_ids"] = torch.tensor(
|
281 |
+
self.shift_tokens_right(
|
282 |
+
batch["labels"], self.pad_token_id, self.decoder_start_token_id
|
283 |
+
),
|
284 |
+
dtype=torch.long,
|
285 |
+
)
|
286 |
+
return batch
|
287 |
+
|
288 |
+
def shift_tokens_right(
|
289 |
+
self, input_ids: np.array, pad_token_id: int, decoder_start_token_id: int
|
290 |
+
) -> np.ndarray:
|
291 |
+
"""
|
292 |
+
Shift input ids one token to the right.
|
293 |
+
"""
|
294 |
+
shifted_input_ids = np.zeros_like(input_ids)
|
295 |
+
shifted_input_ids[:, 1:] = input_ids[:, :-1]
|
296 |
+
shifted_input_ids[:, 0] = decoder_start_token_id
|
297 |
+
|
298 |
+
shifted_input_ids = np.where(
|
299 |
+
shifted_input_ids == -100, pad_token_id, shifted_input_ids
|
300 |
+
)
|
301 |
+
return shifted_input_ids
|
302 |
+
|
303 |
+
def train_dataloader(self):
|
304 |
+
from fengshen.data.universal_datamodule.universal_sampler import (
|
305 |
+
PretrainingRandomSampler,
|
306 |
+
)
|
307 |
+
from fengshen.data.universal_datamodule.universal_datamodule import (
|
308 |
+
get_consume_samples,
|
309 |
+
)
|
310 |
+
|
311 |
+
# Use a custom sampler so that resumed training picks up the data at the correct position
|
312 |
+
consumed_samples = get_consume_samples(self)
|
313 |
+
batch_sampler = PretrainingRandomSampler(
|
314 |
+
epoch=self.epochs,
|
315 |
+
total_samples=len(self.train_dataset),
|
316 |
+
consumed_samples=consumed_samples,
|
317 |
+
micro_batch_size=self.hparams.train_batchsize,
|
318 |
+
data_parallel_rank=self.trainer.global_rank, # gpu idx
|
319 |
+
data_parallel_size=self.trainer.world_size, # gpu num
|
320 |
+
)
|
321 |
+
return DataLoader(
|
322 |
+
self.train_dataset,
|
323 |
+
batch_sampler=batch_sampler,
|
324 |
+
pin_memory=True,
|
325 |
+
num_workers=self.hparams.dataloader_num_workers,
|
326 |
+
collate_fn=self.collate_fn,
|
327 |
+
)
|
328 |
+
|
329 |
+
def val_dataloader(self):
|
330 |
+
sampler = torch.utils.data.distributed.DistributedSampler(
|
331 |
+
self.test_dataset, shuffle=False
|
332 |
+
)
|
333 |
+
return DataLoader(
|
334 |
+
self.test_dataset,
|
335 |
+
sampler=sampler,
|
336 |
+
shuffle=False,
|
337 |
+
batch_size=self.hparams.valid_batchsize,
|
338 |
+
pin_memory=True,
|
339 |
+
num_workers=self.hparams.dataloader_num_workers,
|
340 |
+
collate_fn=self.collate_fn,
|
341 |
+
)
|
342 |
+
|
343 |
+
def predict_dataloader(self):
|
344 |
+
sampler = torch.utils.data.distributed.DistributedSampler(
|
345 |
+
self.test_dataset, shuffle=False
|
346 |
+
)
|
347 |
+
return DataLoader(
|
348 |
+
self.test_dataset,
|
349 |
+
sampler=sampler,
|
350 |
+
shuffle=False,
|
351 |
+
batch_size=self.hparams.valid_batchsize,
|
352 |
+
pin_memory=True,
|
353 |
+
num_workers=self.hparams.dataloader_num_workers,
|
354 |
+
collate_fn=self.collate_fn,
|
355 |
+
)
|
356 |
+
|
357 |
+
|
358 |
+
if __name__ == "__main__":
|
359 |
+
# test
|
360 |
+
import argparse
|
361 |
+
|
362 |
+
total_parser = argparse.ArgumentParser("DATASET parser")
|
363 |
+
total_parser.add_argument(
|
364 |
+
"--tokenizer_type",
|
365 |
+
default="t5_tokenizer",
|
366 |
+
choices=["bert_tokenizer", "t5_tokenizer"],
|
367 |
+
)
|
368 |
+
total_parser.add_argument("--preprocessing_num_workers", default="10", type=int)
|
369 |
+
total_parser.add_argument(
|
370 |
+
"--new_vocab_path",
|
371 |
+
default="/cognitive_comp/hejunqing/projects/Dialog_pretrain/randeng_t5_newvocab_784M",
|
372 |
+
type=str,
|
373 |
+
)
|
374 |
+
total_parser.add_argument("--train_split_size", default=0.995, type=int)
|
375 |
+
total_parser.add_argument(
|
376 |
+
"--pretrained_model_path",
|
377 |
+
default="/cognitive_comp/hejunqing/projects/Dialog_pretrain/randeng_t5_newvocab_784M",
|
378 |
+
)
|
379 |
+
total_parser = DialogDataModel.add_data_specific_args(total_parser)
|
380 |
+
args = total_parser.parse_args()
|
381 |
+
dl = DialogDataModel(args)
|
382 |
+
|
383 |
+
for i in range(5):
|
384 |
+
for batch in dl.train_dataloader():
|
385 |
+
print(batch)
|
386 |
+
print(batch["input_ids"])
|
387 |
+
print(batch["token_types"])
|
388 |
+
print(batch["decoder_input_ids"])
|
389 |
+
print(batch["labels"])
|
390 |
+
|
391 |
+
print("test finish")
|
fengshen/data/taiyi_stable_diffusion_datasets/taiyi_datasets.py
ADDED
@@ -0,0 +1,173 @@
1 |
+
from torch.utils.data import Dataset, ConcatDataset
|
2 |
+
import os
|
3 |
+
from concurrent.futures import ProcessPoolExecutor
|
4 |
+
import pandas as pd
|
5 |
+
|
6 |
+
|
7 |
+
def add_data_args(parent_args):
|
8 |
+
parser = parent_args.add_argument_group('taiyi stable diffusion data args')
|
9 |
+
# Multiple paths may be passed in; each is loaded separately
|
10 |
+
parser.add_argument(
|
11 |
+
"--datasets_path", type=str, default=None, required=True, nargs='+',
|
12 |
+
help="A folder containing the training data of instance images.",
|
13 |
+
)
|
14 |
+
parser.add_argument(
|
15 |
+
"--datasets_type", type=str, default=None, required=True, choices=['txt', 'csv', 'fs_datasets'], nargs='+',
|
16 |
+
help="dataset type, txt or csv, same len as datasets_path",
|
17 |
+
)
|
18 |
+
parser.add_argument(
|
19 |
+
"--resolution", type=int, default=512,
|
20 |
+
help=(
|
21 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
22 |
+
" resolution"
|
23 |
+
),
|
24 |
+
)
|
25 |
+
parser.add_argument(
|
26 |
+
"--center_crop", action="store_true", default=False,
|
27 |
+
help="Whether to center crop images before resizing to resolution"
|
28 |
+
)
|
29 |
+
parser.add_argument("--thres", type=float, default=0.2)
|
30 |
+
return parent_args
|
31 |
+
|
32 |
+
|
33 |
+
class TXTDataset(Dataset):
|
34 |
+
# Txt dataset reader, mainly targeting the Zero23m dataset.
|
35 |
+
def __init__(self,
|
36 |
+
foloder_name,
|
37 |
+
thres=0.2):
|
38 |
+
super().__init__()
|
39 |
+
# print(f'Loading folder data from {foloder_name}.')
|
40 |
+
self.image_paths = []
|
41 |
+
'''
|
42 |
+
These score files have not been open-sourced yet
|
43 |
+
score_data = pd.read_csv(os.path.join(foloder_name, 'score.csv'))
|
44 |
+
img_path2score = {score_data['image_path'][i]: score_data['score'][i]
|
45 |
+
for i in range(len(score_data))}
|
46 |
+
'''
|
47 |
+
# print(img_path2score)
|
48 |
+
# Only file paths are stored here, to avoid excessive initialization time.
|
49 |
+
for each_file in os.listdir(foloder_name):
|
50 |
+
if each_file.endswith('.jpg'):
|
51 |
+
self.image_paths.append(os.path.join(foloder_name, each_file))
|
52 |
+
|
53 |
+
# print('Done loading data. Len of images:', len(self.image_paths))
|
54 |
+
|
55 |
+
def __len__(self):
|
56 |
+
return len(self.image_paths)
|
57 |
+
|
58 |
+
def __getitem__(self, idx):
|
59 |
+
img_path = str(self.image_paths[idx])
|
60 |
+
caption_path = img_path.replace('.jpg', '.txt')  # the image and its caption share the same base name
|
61 |
+
with open(caption_path, 'r') as f:
|
62 |
+
caption = f.read()
|
63 |
+
return {'img_path': img_path, 'caption': caption}
|
64 |
+
|
65 |
+
|
66 |
+
# NOTE: to speed up data loading, keep the original reader and parallelize it externally. 30min -> 3min
|
67 |
+
class CSVDataset(Dataset):
|
68 |
+
def __init__(self,
|
69 |
+
input_filename,
|
70 |
+
image_root,
|
71 |
+
img_key,
|
72 |
+
caption_key,
|
73 |
+
thres=0.2):
|
74 |
+
super().__init__()
|
75 |
+
# logging.debug(f'Loading csv data from {input_filename}.')
|
76 |
+
print(f'Loading csv data from {input_filename}.')
|
77 |
+
self.images = []
|
78 |
+
self.captions = []
|
79 |
+
|
80 |
+
if input_filename.endswith('.csv'):
|
81 |
+
# print(f"Load Data from{input_filename}")
|
82 |
+
df = pd.read_csv(input_filename, index_col=0, on_bad_lines='skip')
|
83 |
+
print(f'file {input_filename} datalen {len(df)}')
|
84 |
+
# the image path may need minor adjustment depending on how the dataset is organized
|
85 |
+
self.images.extend(df[img_key].tolist())
|
86 |
+
self.captions.extend(df[caption_key].tolist())
|
87 |
+
self.image_root = image_root
|
88 |
+
|
89 |
+
def __len__(self):
|
90 |
+
return len(self.images)
|
91 |
+
|
92 |
+
def __getitem__(self, idx):
|
93 |
+
img_path = os.path.join(self.image_root, str(self.images[idx]))
|
94 |
+
return {'img_path': img_path, 'caption': self.captions[idx]}
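A minimal, self-contained sketch of the `CSVDataset` contract above (an index column plus `img_key`/`caption_key` columns, with image paths resolved lazily under `image_root`); it assumes it runs in the same module so `CSVDataset` is in scope, and all paths and values are hypothetical.

```python
import os
import tempfile
import pandas as pd

# Hypothetical toy CSV matching the columns CSVDataset expects.
tmp = tempfile.mkdtemp()
csv_path = os.path.join(tmp, "part0.csv")
pd.DataFrame({"name": ["00000001.jpg"],
              "caption": ["a cat wearing a hat"]}).to_csv(csv_path)

ds = CSVDataset(csv_path,
                image_root=os.path.join(tmp, "images"),  # the image file itself is not opened here
                img_key="name",
                caption_key="caption")
print(len(ds), ds[0])  # 1 {'img_path': '.../images/00000001.jpg', 'caption': 'a cat wearing a hat'}
```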
|
95 |
+
|
96 |
+
|
97 |
+
def if_final_dir(path: str) -> bool:
|
98 |
+
# a directory that directly contains at least one file is treated as a leaf directory
|
99 |
+
for f in os.scandir(path):
|
100 |
+
if f.is_file():
|
101 |
+
return True
|
102 |
+
return False
|
103 |
+
|
104 |
+
|
105 |
+
def process_pool_read_txt_dataset(args,
|
106 |
+
input_root=None,
|
107 |
+
thres=0.2):
|
108 |
+
p = ProcessPoolExecutor(max_workers=20)
|
109 |
+
all_datasets = []
|
110 |
+
res = []
|
111 |
+
|
112 |
+
# recursively traverse every subdirectory under the given path
|
113 |
+
def traversal_files(path: str):
|
114 |
+
list_subfolders_with_paths = [f.path for f in os.scandir(path) if f.is_dir()]
|
115 |
+
for dir_path in list_subfolders_with_paths:
|
116 |
+
if if_final_dir(dir_path):
|
117 |
+
res.append(p.submit(TXTDataset,
|
118 |
+
dir_path,
|
119 |
+
thres))
|
120 |
+
else:
|
121 |
+
traversal_files(dir_path)
|
122 |
+
traversal_files(input_root)
|
123 |
+
p.shutdown()
|
124 |
+
for future in res:
|
125 |
+
all_datasets.append(future.result())
|
126 |
+
dataset = ConcatDataset(all_datasets)
|
127 |
+
return dataset
|
128 |
+
|
129 |
+
|
130 |
+
def process_pool_read_csv_dataset(args,
|
131 |
+
input_root,
|
132 |
+
thres=0.20):
|
133 |
+
# here input_filename is a directory containing a CSV file
|
134 |
+
all_csvs = os.listdir(os.path.join(input_root, 'release'))
|
135 |
+
image_root = os.path.join(input_root, 'images')
|
136 |
+
# csv_with_score = [each for each in all_csvs if 'score' in each]
|
137 |
+
all_datasets = []
|
138 |
+
res = []
|
139 |
+
p = ProcessPoolExecutor(max_workers=150)
|
140 |
+
for path in all_csvs:
|
141 |
+
each_csv_path = os.path.join(input_root, 'release', path)
|
142 |
+
res.append(p.submit(CSVDataset,
|
143 |
+
each_csv_path,
|
144 |
+
image_root,
|
145 |
+
img_key="name",
|
146 |
+
caption_key="caption",
|
147 |
+
thres=thres))
|
148 |
+
p.shutdown()
|
149 |
+
for future in res:
|
150 |
+
all_datasets.append(future.result())
|
151 |
+
dataset = ConcatDataset(all_datasets)
|
152 |
+
return dataset
|
153 |
+
|
154 |
+
|
155 |
+
def load_data(args, global_rank=0):
|
156 |
+
assert len(args.datasets_path) == len(args.datasets_type), \
|
157 |
+
"datasets_path num not equal to datasets_type"
|
158 |
+
all_datasets = []
|
159 |
+
for path, type in zip(args.datasets_path, args.datasets_type):
|
160 |
+
if type == 'txt':
|
161 |
+
all_datasets.append(process_pool_read_txt_dataset(
|
162 |
+
args, input_root=path, thres=args.thres))
|
163 |
+
elif type == 'csv':
|
164 |
+
all_datasets.append(process_pool_read_csv_dataset(
|
165 |
+
args, input_root=path, thres=args.thres))
|
166 |
+
elif type == 'fs_datasets':
|
167 |
+
from fengshen.data.fs_datasets import load_dataset
|
168 |
+
all_datasets.append(load_dataset(path, num_proc=args.num_workers,
|
169 |
+
thres=args.thres, global_rank=global_rank)['train'])
|
170 |
+
else:
|
171 |
+
raise ValueError('unsupported dataset type: %s' % type)
|
172 |
+
print(f'load dataset {type} {path} len {len(all_datasets[-1])}')
|
173 |
+
return {'train': ConcatDataset(all_datasets)}
|
fengshen/data/task_dataloader/__init__.py
ADDED
@@ -0,0 +1,3 @@
1 |
+
# coding=utf-8
|
2 |
+
from .task_datasets import LCSTSDataModel, LCSTSDataset
|
3 |
+
__all__ = ['LCSTSDataModel', 'LCSTSDataset']
|
fengshen/data/task_dataloader/medicalQADataset.py
ADDED
@@ -0,0 +1,137 @@
1 |
+
# coding=utf8
|
2 |
+
import os
|
3 |
+
import pytorch_lightning as pl
|
4 |
+
from torch.utils.data import DataLoader, Dataset
|
5 |
+
from tqdm import tqdm
|
6 |
+
from transformers import AutoTokenizer
|
7 |
+
|
8 |
+
|
9 |
+
class GPT2QADataset(Dataset):
|
10 |
+
'''
|
11 |
+
Dataset used for the Yuyuan medical QA task.
|
12 |
+
Only supports small datasets; it may be slow on large datasets.
|
13 |
+
For large datasets, please use mmap datasets (work in progress).
|
14 |
+
'''
|
15 |
+
|
16 |
+
def __init__(self, data_path, name, args):
|
17 |
+
super().__init__()
|
18 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
19 |
+
args.pretrained_model_path)
|
20 |
+
if self.tokenizer.pad_token is None:
|
21 |
+
self.tokenizer.add_special_tokens({'pad_token': '<|endoftext|>'})
|
22 |
+
self.data_size = os.path.getsize(data_path)/1024/1024/1024
|
23 |
+
self.data_type_name = name
|
24 |
+
self.data = self.load_data(data_path)
|
25 |
+
self.max_seq_length = args.max_seq_length
|
26 |
+
|
27 |
+
def __len__(self):
|
28 |
+
return len(self.data)
|
29 |
+
|
30 |
+
def __getitem__(self, index):
|
31 |
+
return self.encode(self.data[index])
|
32 |
+
|
33 |
+
def load_data(self, data_path):
|
34 |
+
# show a progress bar while loading
|
35 |
+
if self.data_size <= 5:
|
36 |
+
with open(data_path, "rt", encoding='utf8') as f:
|
37 |
+
lines = f.readlines()
|
38 |
+
total_num = len(lines)
|
39 |
+
data_gen = lines
|
40 |
+
else:
|
41 |
+
data_gen = open(data_path, "rt", encoding='utf8')
|
42 |
+
total_num = None
|
43 |
+
|
44 |
+
data = []
|
45 |
+
with tqdm(total=total_num, desc=f'{self.data_type_name}处理进度', mininterval=0.3) as bar:
|
46 |
+
for idx, line in enumerate(data_gen):
|
47 |
+
data.append(self.data_parse(line))
|
48 |
+
bar.update()
|
49 |
+
|
50 |
+
if self.data_size > 5:
|
51 |
+
data_gen.close()
|
52 |
+
return data
|
53 |
+
|
54 |
+
def data_parse(self, line):
|
55 |
+
"""
|
56 |
+
Parse a raw line (supports different data formats).
|
57 |
+
"""
|
58 |
+
dic = eval(line.strip())
|
59 |
+
return dic
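For reference, each raw line is expected to be a Python dict literal with `Question` and `answer` keys (the keys used by `encode` below); the sketch uses a hypothetical record, and `ast.literal_eval` is shown only as a safer drop-in for `eval` on such literals.

```python
import ast

# Hypothetical raw line in the format data_parse expects.
line = "{'Question': 'What should I do about a mild cold?', 'answer': 'Rest and drink plenty of water.'}\n"
record = ast.literal_eval(line.strip())  # safer alternative to eval() for dict literals
print(record["Question"], record["answer"])
```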
|
60 |
+
|
61 |
+
def encode(self, item):
|
62 |
+
"""
|
63 |
+
Convert a sample into model training inputs.
|
64 |
+
"""
|
65 |
+
inputs_dict = self.tokenizer.encode_plus(item['Question']+item['answer'],
|
66 |
+
max_length=self.max_seq_length, padding='max_length',
|
67 |
+
truncation=True, return_tensors='pt')
|
68 |
+
target = inputs_dict['input_ids']
|
69 |
+
labels = target.clone().detach()
|
70 |
+
labels[target == self.tokenizer.pad_token_id] = -100
|
71 |
+
return {
|
72 |
+
"input_ids": inputs_dict['input_ids'].squeeze(),
|
73 |
+
"attention_mask": inputs_dict['attention_mask'].squeeze(),
|
74 |
+
"labels": labels.squeeze(),
|
75 |
+
"question": item['Question'],
|
76 |
+
"answer": item['answer']
|
77 |
+
}
|
78 |
+
|
79 |
+
|
80 |
+
class GPT2QADataModel(pl.LightningDataModule):
|
81 |
+
@staticmethod
|
82 |
+
def add_data_specific_args(parent_args):
|
83 |
+
parser = parent_args.add_argument_group('GPT2QADataModel')
|
84 |
+
parser.add_argument('--data_dir', type=str, required=True)
|
85 |
+
parser.add_argument('--num_workers', default=2, type=int)
|
86 |
+
parser.add_argument('--train_data', default='train.txt', type=str)
|
87 |
+
parser.add_argument('--valid_data', default='valid.txt', type=str)
|
88 |
+
parser.add_argument('--test_data', default='test.txt', type=str)
|
89 |
+
parser.add_argument('--train_batchsize', type=int, required=True)
|
90 |
+
parser.add_argument('--valid_batchsize', type=int, required=True)
|
91 |
+
parser.add_argument('--max_seq_length', default=1024, type=int)
|
92 |
+
return parent_args
|
93 |
+
|
94 |
+
def __init__(self, args):
|
95 |
+
super().__init__()
|
96 |
+
self.args = args
|
97 |
+
self.train_batchsize = args.train_batchsize
|
98 |
+
self.valid_batchsize = args.valid_batchsize
|
99 |
+
if not args.do_eval_only:
|
100 |
+
self.train_data = GPT2QADataset(os.path.join(
|
101 |
+
args.data_dir, args.train_data), '训练集', args)
|
102 |
+
self.valid_data = GPT2QADataset(os.path.join(
|
103 |
+
args.data_dir, args.valid_data), '验证集', args)
|
104 |
+
self.test_data = GPT2QADataset(os.path.join(
|
105 |
+
args.data_dir, args.test_data), '测试集', args)
|
106 |
+
|
107 |
+
def train_dataloader(self):
|
108 |
+
return DataLoader(
|
109 |
+
self.train_data, shuffle=True,
|
110 |
+
batch_size=self.train_batchsize,
|
111 |
+
pin_memory=False, num_workers=self.args.num_workers)
|
112 |
+
|
113 |
+
def val_dataloader(self):
|
114 |
+
return DataLoader(self.valid_data, shuffle=False,
|
115 |
+
batch_size=self.valid_batchsize,
|
116 |
+
pin_memory=False, num_workers=self.args.num_workers)
|
117 |
+
|
118 |
+
def predict_dataloader(self):
|
119 |
+
return DataLoader(self.test_data, shuffle=False,
|
120 |
+
batch_size=self.valid_batchsize, pin_memory=False,
|
121 |
+
num_workers=self.args.num_workers)
|
122 |
+
|
123 |
+
|
124 |
+
if __name__ == '__main__':
|
125 |
+
import argparse
|
126 |
+
modelfile = '/cognitive_comp/wuziwei/pretrained_model_hf/medical_v2'
|
127 |
+
datafile = '/cognitive_comp/wuziwei/task-data/medical_qa/medical_qa_train.txt'
|
128 |
+
parser = argparse.ArgumentParser(description='hf test', allow_abbrev=False)
|
129 |
+
group = parser.add_argument_group(title='test args')
|
130 |
+
group.add_argument('--pretrained-model-path', type=str, default=modelfile,
|
131 |
+
help='Number of transformer layers.')
|
132 |
+
group.add_argument('--max-seq-length', type=int, default=1024)
|
133 |
+
args = parser.parse_args()
|
134 |
+
|
135 |
+
testml = GPT2QADataset(datafile, 'medical_qa', args=args)
|
136 |
+
|
137 |
+
print(testml[10])
|
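A quick illustration of the labeling scheme used in `GPT2QADataset.encode` above: the labels are a copy of `input_ids` with padding positions replaced by -100, the value HuggingFace-style language-model losses ignore. The sketch below is not part of the diff; the token ids and pad id are invented for illustration.

```python
# Minimal sketch (assumed toy values): how the -100 label mask behaves.
import torch

pad_token_id = 0                                  # hypothetical pad id
input_ids = torch.tensor([[5, 8, 3, 0, 0]])       # a padded sequence
labels = input_ids.clone().detach()
labels[input_ids == pad_token_id] = -100          # -100 positions are skipped by the LM loss
print(labels)                                     # tensor([[   5,    8,    3, -100, -100]])
```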
fengshen/data/task_dataloader/task_datasets.py
ADDED
@@ -0,0 +1,206 @@
# coding=utf8
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer
import json
import torch
import pytorch_lightning as pl
import os


class AbstractCollator:
    """
    collator for the summary task
    """

    def __init__(self, tokenizer, max_enc_length, max_dec_length, prompt):
        self.tokenizer = tokenizer
        self.max_enc_length = max_enc_length
        self.max_dec_length = max_dec_length
        self.prompt = prompt

    def __call__(self, samples):

        labels = []
        attn_mask = []
        # decoder_attn_mask = []
        source_inputs = []
        for sample in samples:
            encode_dict = self.tokenizer.encode_plus(
                self.prompt + sample['text'],
                max_length=self.max_enc_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt')
            decode_dict = self.tokenizer.encode_plus(
                sample['summary'],
                max_length=self.max_dec_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt')
            source_inputs.append(encode_dict['input_ids'].squeeze())
            labels.append(decode_dict['input_ids'].squeeze())
            attn_mask.append(encode_dict['attention_mask'].squeeze())
            # decoder_attn_mask.append(decode_dict['attention_mask'].squeeze())
        # labels = torch.tensor(decode_dict['input'])

        source_inputs = torch.stack(source_inputs)
        labels = torch.stack(labels)
        attn_mask = torch.stack(attn_mask)
        # decoder_attn_mask = torch.stack(decoder_attn_mask)
        # decode_input_idxs = shift_tokens_right(labels, self.tokenizer.pad_token_id, self.tokenizer.pad_token_id)
        # everything after the first eos token is masked out of the loss
        end_token_index = torch.where(labels == self.tokenizer.eos_token_id)[1]
        for idx, end_idx in enumerate(end_token_index):
            labels[idx][end_idx + 1:] = -100

        return {
            "input_ids": source_inputs,
            "attention_mask": attn_mask,
            "labels": labels,
            "text": [sample['text'] for sample in samples],
            "summary": [sample['summary'] for sample in samples]
        }


class LCSTSDataset(Dataset):
    '''
    Dataset used for the LCSTS summary task.
    '''

    def __init__(self, data_path, args):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(
            args.pretrained_model_path, use_fast=False)
        self.data = self.load_data(data_path)
        self.prompt = args.prompt
        self.max_enc_length = args.max_enc_length
        self.max_dec_length = args.max_dec_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.encode(self.data[index])

    def load_data(self, data_path):
        with open(data_path, "r", encoding='utf8') as f:
            lines = f.readlines()
        samples = []
        for line in tqdm(lines):
            obj = json.loads(line)
            source = obj['text']
            target = obj['summary']
            samples.append({
                "text": source,
                "summary": target
            })
        return samples

    def cal_data(self, data_path):
        with open(data_path, "r", encoding='utf8') as f:
            lines = f.readlines()
        samples = []
        enc_sizes = []
        dec_sizes = []
        for line in tqdm(lines):
            obj = json.loads(line.strip())
            source = obj['text']
            target = obj['summary']
            enc_input_ids = self.tokenizer.encode(source)
            target = self.tokenizer.encode(target)
            enc_sizes.append(len(enc_input_ids))
            dec_sizes.append(len(target)-1)
            samples.append({
                "enc_input_ids": enc_input_ids,
                "dec_input_ids": target[:-1],
                "label_ids": target[1:]
            })
        max_enc_len = max(enc_sizes)
        max_dec_len = max(dec_sizes)
        import numpy as np
        # mean of len(enc_input_ids): 74.68041911345998
        # mean of len(dec_input_ids): 14.02265483791283
        # max of len(enc_input_ids): 132
        # max of len(dec_input_ids): 31
        print('mean of len(enc_input_ids):', np.mean(enc_sizes),
              'mean of len(dec_input_ids):', np.mean(dec_sizes),
              'max of len(enc_input_ids):', max_enc_len,
              'max of len(dec_input_ids):', max_dec_len)
        return samples

    def encode(self, item):
        encode_dict = self.tokenizer.encode_plus(
            self.prompt + item['text'],
            max_length=self.max_enc_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt')
        decode_dict = self.tokenizer.encode_plus(
            item['summary'],
            max_length=self.max_dec_length,
            padding='max_length',
            truncation=True)

        target = decode_dict['input_ids']
        # print('encode_dict shape:', encode_dict['input_ids'].shape)
        labels = torch.tensor(target)
        # mask padding positions on the tensor so they are ignored by the loss
        labels[labels == self.tokenizer.pad_token_id] = -100
        return {
            "input_ids": encode_dict['input_ids'].squeeze(),
            "attention_mask": encode_dict['attention_mask'].squeeze(),
            "labels": labels.squeeze(),
            "text": item['text'],
            "summary": item['summary']
        }


class LCSTSDataModel(pl.LightningDataModule):
    @staticmethod
    def add_data_specific_args(parent_args):
        parser = parent_args.add_argument_group('LCSTSDataModel')
        parser.add_argument(
            '--data_dir', default='/cognitive_comp/ganruyi/data_datasets_LCSTS_LCSTS/', type=str)
        parser.add_argument('--num_workers', default=8, type=int)
        parser.add_argument('--train_data', default='train.jsonl', type=str)
        parser.add_argument('--valid_data', default='valid.jsonl', type=str)
        parser.add_argument('--test_data', default='test_public.jsonl', type=str)
        parser.add_argument('--train_batchsize', default=128, type=int)
        parser.add_argument('--valid_batchsize', default=128, type=int)
        parser.add_argument('--max_enc_length', default=128, type=int)
        parser.add_argument('--max_dec_length', default=30, type=int)
        parser.add_argument('--prompt', default='summarize:', type=str)
        return parent_args

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.train_batchsize = args.train_batchsize
        self.valid_batchsize = args.valid_batchsize
        if not args.do_eval_only:
            self.train_data = LCSTSDataset(os.path.join(
                args.data_dir, args.train_data), args)
            self.valid_data = LCSTSDataset(os.path.join(
                args.data_dir, args.valid_data), args)
            self.test_data = LCSTSDataset(os.path.join(
                args.data_dir, args.test_data), args)

    def train_dataloader(self):
        return DataLoader(self.train_data,
                          shuffle=True,
                          batch_size=self.train_batchsize,
                          pin_memory=False,
                          num_workers=self.args.num_workers)

    def val_dataloader(self):
        return DataLoader(self.valid_data,
                          shuffle=False,
                          batch_size=self.valid_batchsize,
                          pin_memory=False,
                          num_workers=self.args.num_workers)

    def predict_dataloader(self):
        return DataLoader(self.test_data,
                          shuffle=False,
                          batch_size=self.valid_batchsize,
                          pin_memory=False,
                          num_workers=self.args.num_workers)
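A small illustration of how `AbstractCollator.__call__` above masks decoder labels after the first eos token. This sketch is not part of the diff; the token ids and `eos_token_id` are invented, and it assumes each row contains exactly one eos token.

```python
# Minimal sketch (assumed toy values): mask everything after the first eos with -100.
import torch

eos_token_id = 2
labels = torch.tensor([[7, 9, 2, 1, 1],   # 2 = eos, 1 = pad
                       [4, 2, 1, 1, 1]])
end_token_index = torch.where(labels == eos_token_id)[1]
for idx, end_idx in enumerate(end_token_index):
    labels[idx][end_idx + 1:] = -100
print(labels)
# tensor([[   7,    9,    2, -100, -100],
#         [   4,    2, -100, -100, -100]])
```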
fengshen/data/universal_datamodule/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .universal_datamodule import UniversalDataModule
from .universal_sampler import PretrainingSampler, PretrainingRandomSampler

__all__ = ['UniversalDataModule', 'PretrainingSampler', 'PretrainingRandomSampler']
fengshen/data/universal_datamodule/universal_datamodule.py
ADDED
@@ -0,0 +1,165 @@
from pytorch_lightning import LightningDataModule
from typing import Optional

from torch.utils.data import DataLoader, DistributedSampler


def get_consume_samples(data_model: LightningDataModule) -> int:
    if hasattr(data_model.trainer.lightning_module, 'consumed_samples'):
        consumed_samples = data_model.trainer.lightning_module.consumed_samples
        print('get consumed samples from model: {}'.format(consumed_samples))
    else:
        world_size = data_model.trainer.world_size
        consumed_samples = max(0, data_model.trainer.global_step - 1) * \
            data_model.hparams.train_batchsize * world_size * data_model.trainer.accumulate_grad_batches
        print('calculate consumed samples: {}'.format(consumed_samples))
    return consumed_samples


class UniversalDataModule(LightningDataModule):
    @staticmethod
    def add_data_specific_args(parent_args):
        parser = parent_args.add_argument_group('Universal DataModule')
        parser.add_argument('--num_workers', default=8, type=int)
        parser.add_argument('--dataloader_workers', default=2, type=int)
        parser.add_argument('--train_batchsize', default=16, type=int)
        parser.add_argument('--val_batchsize', default=16, type=int)
        parser.add_argument('--test_batchsize', default=16, type=int)
        parser.add_argument('--datasets_name', type=str, default=None)
        parser.add_argument('--train_datasets_field', type=str, default='train')
        parser.add_argument('--val_datasets_field', type=str, default='validation')
        parser.add_argument('--test_datasets_field', type=str, default='test')
        parser.add_argument('--train_file', type=str, default=None)
        parser.add_argument('--val_file', type=str, default=None)
        parser.add_argument('--test_file', type=str, default=None)
        parser.add_argument('--raw_file_type', type=str, default='json')
        parser.add_argument('--sampler_type', type=str,
                            choices=['single',
                                     'random'],
                            default='random')
        return parent_args

    def __init__(
        self,
        tokenizer,
        collate_fn,
        args,
        datasets=None,
        **kwargs,
    ):
        super().__init__()
        # If no dataset name is passed in, the internal datasets can be replaced
        # from outside the object with whatever the model needs.
        if datasets is not None:
            self.datasets = datasets
        elif args.datasets_name is not None:
            from fengshen.data.fs_datasets import load_dataset
            print('---------begin to load datasets {}'.format(args.datasets_name))
            self.datasets = load_dataset(
                args.datasets_name, num_proc=args.num_workers)
            print('---------end loading datasets {}'.format(args.datasets_name))
        else:
            print('---------begin to load datasets from local file')
            from datasets import load_dataset
            self.datasets = load_dataset(args.raw_file_type,
                                         data_files={
                                             args.train_datasets_field: args.train_file,
                                             args.val_datasets_field: args.val_file,
                                             args.test_datasets_field: args.test_file})
            print('---------end loading datasets from local file')

        self.tokenizer = tokenizer
        self.collate_fn = collate_fn
        self.save_hyperparameters(args)

    def get_custom_sampler(self, ds):
        from .universal_sampler import PretrainingRandomSampler
        from .universal_sampler import PretrainingSampler
        world_size = self.trainer.world_size
        consumed_samples = get_consume_samples(self)
        # use the user default sampler
        if self.hparams.sampler_type == 'random':
            return PretrainingRandomSampler(
                total_samples=len(ds),
                # consumed_samples calculated from global steps
                consumed_samples=consumed_samples,
                micro_batch_size=self.hparams.train_batchsize,
                data_parallel_rank=self.trainer.global_rank,
                data_parallel_size=world_size,
                epoch=self.trainer.current_epoch,
            )
        elif self.hparams.sampler_type == 'single':
            return PretrainingSampler(
                total_samples=len(ds),
                # consumed_samples calculated from global steps
                consumed_samples=consumed_samples,
                micro_batch_size=self.hparams.train_batchsize,
                data_parallel_rank=self.trainer.global_rank,
                data_parallel_size=world_size,
            )
        else:
            raise Exception('Unknown sampler type: {}'.format(self.hparams.sampler_type))

    def setup(self, stage: Optional[str] = None) -> None:
        return

    def train_dataloader(self):
        ds = self.datasets[self.hparams.train_datasets_field]

        collate_fn = self.collate_fn
        if hasattr(ds, 'collate_fn'):
            collate_fn = ds.collate_fn

        if self.hparams.replace_sampler_ddp is False:
            return DataLoader(
                ds,
                batch_sampler=self.get_custom_sampler(ds),
                num_workers=self.hparams.dataloader_workers,
                collate_fn=collate_fn,
                pin_memory=True,
            )
        return DataLoader(
            ds,
            batch_size=self.hparams.train_batchsize,
            num_workers=self.hparams.dataloader_workers,
            collate_fn=collate_fn,
            pin_memory=True,
        )

    def val_dataloader(self):
        ds = self.datasets[self.hparams.val_datasets_field]
        collate_fn = self.collate_fn
        if hasattr(ds, 'collate_fn'):
            collate_fn = ds.collate_fn

        return DataLoader(
            ds,
            batch_size=self.hparams.val_batchsize,
            shuffle=False,
            num_workers=self.hparams.dataloader_workers,
            collate_fn=collate_fn,
            sampler=DistributedSampler(
                ds, shuffle=False),
            pin_memory=True,
        )

        # return DataLoader(
        #     ds, shuffle=False, batch_size=self.hparams.val_batchsize, pin_memory=False, collate_fn=collate_fn,
        # )

    def test_dataloader(self):
        ds = self.datasets[self.hparams.test_datasets_field]

        collate_fn = self.collate_fn
        if collate_fn is None and hasattr(ds, 'collater'):
            collate_fn = ds.collater

        return DataLoader(
            ds,
            batch_size=self.hparams.test_batchsize,
            shuffle=False,
            num_workers=self.hparams.dataloader_workers,
            collate_fn=collate_fn,
            sampler=DistributedSampler(
                ds, shuffle=False),
            pin_memory=True,
        )
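The fallback branch of `get_consume_samples` above recovers how many samples have been consumed purely from trainer state: `max(0, global_step - 1) * train_batchsize * world_size * accumulate_grad_batches`. A small worked example with invented numbers, not taken from the repository:

```python
# Hypothetical numbers, purely to illustrate the arithmetic above.
global_step = 101
train_batchsize = 16            # micro-batch size per data-parallel rank
world_size = 8                  # number of data-parallel ranks
accumulate_grad_batches = 4     # gradient accumulation steps

consumed_samples = max(0, global_step - 1) * train_batchsize * world_size * accumulate_grad_batches
print(consumed_samples)         # 100 * 16 * 8 * 4 = 51200
```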
fengshen/data/universal_datamodule/universal_sampler.py
ADDED
@@ -0,0 +1,125 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataloaders."""


import torch


class PretrainingSampler:

    def __init__(self, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size, drop_last=True):
        # Keep a copy of input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        self.drop_last = drop_last

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.consumed_samples < self.total_samples, \
            'no samples left to consume: {}, {}'.format(self.consumed_samples,
                                                        self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data size: {}, ' \
            '{}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self):
        return self.total_samples // self.micro_batch_times_data_parallel_size

    def get_start_end_idx(self):
        start_idx = self.data_parallel_rank * self.micro_batch_size
        end_idx = start_idx + self.micro_batch_size
        return start_idx, end_idx

    def __iter__(self):
        batch = []
        # Last batch will be dropped if drop_last is not set False
        for idx in range(self.consumed_samples, self.total_samples):
            batch.append(idx)
            if len(batch) == self.micro_batch_times_data_parallel_size:
                start_idx, end_idx = self.get_start_end_idx()
                yield batch[start_idx:end_idx]
                batch = []

        # Check the last partial batch and see drop_last is set
        if len(batch) > 0 and not self.drop_last:
            start_idx, end_idx = self.get_start_end_idx()
            yield batch[start_idx:end_idx]


class PretrainingRandomSampler:

    def __init__(self, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size, epoch):
        # Keep a copy of input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.data_parallel_size = data_parallel_size
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        self.last_batch_size = \
            self.total_samples % self.micro_batch_times_data_parallel_size
        self.epoch = epoch

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data size: {}, ' \
            '{}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self):
        return self.total_samples // self.micro_batch_times_data_parallel_size

    def __iter__(self):
        active_total_samples = self.total_samples - self.last_batch_size
        current_epoch_samples = self.consumed_samples % active_total_samples
        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

        # data sharding and random sampling
        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
            * self.micro_batch_size
        bucket_offset = current_epoch_samples // self.data_parallel_size
        start_idx = self.data_parallel_rank * bucket_size

        g = torch.Generator()
        g.manual_seed(self.epoch)
        random_idx = torch.randperm(bucket_size, generator=g).tolist()
        idx_range = [start_idx + x for x in random_idx[bucket_offset:]]

        batch = []
        # Last batch if not complete will be dropped.
        for idx in idx_range:
            batch.append(idx)
            if len(batch) == self.micro_batch_size:
                self.consumed_samples += self.micro_batch_times_data_parallel_size
                yield batch
                batch = []

    def set_epoch(self, epoch):
        self.epoch = epoch
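A short sketch of how `PretrainingSampler` above shards one global batch across data-parallel ranks: each rank collects `micro_batch_size * data_parallel_size` indices and then yields only its own slice. This is not part of the diff; the sizes are invented, and it assumes the fengshen package from this commit is importable.

```python
# Minimal sketch (assumed toy sizes): per-rank micro-batches from a global index stream.
from fengshen.data.universal_datamodule import PretrainingSampler

for rank in range(2):                      # pretend the data-parallel world size is 2
    sampler = PretrainingSampler(
        total_samples=8, consumed_samples=0,
        micro_batch_size=2, data_parallel_rank=rank,
        data_parallel_size=2)
    print(rank, list(sampler))
# 0 [[0, 1], [4, 5]]
# 1 [[2, 3], [6, 7]]
```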
fengshen/examples/DAVAE/generate.py
ADDED
@@ -0,0 +1,36 @@
# -*- encoding: utf-8 -*-
'''
Copyright 2022 The International Digital Economy Academy (IDEA). CCNL team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@File    :   generate.py
@Time    :   2022/11/04 19:17
@Author  :   Liang Yuxin
@Version :   1.0
@Contact :   liangyuxin@idea.edu.cn
@License :   (C)Copyright 2022-2023, CCNL-IDEA
'''
# here put the import lib

import torch
from fengshen.models.DAVAE.DAVAEModel import DAVAEModel
from transformers import BertTokenizer, T5Tokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder_tokenizer = BertTokenizer.from_pretrained("IDEA-CCNL/Randeng-DAVAE-1.2B-General-Chinese")
decoder_tokenizer = T5Tokenizer.from_pretrained(
    "IDEA-CCNL/Randeng-DAVAE-1.2B-General-Chinese",
    eos_token='<|endoftext|>', pad_token='<pad>', extra_ids=0)
decoder_tokenizer.add_special_tokens({'bos_token': '<bos>'})
vae_model = DAVAEModel.from_pretrained("IDEA-CCNL/Randeng-DAVAE-1.2B-General-Chinese").to(device)
input_texts = [
    "针对电力系统中的混沌振荡对整个互联电网的危害问题,提出了一种基于非线性光滑函数的滑模控制方法.",
    "超市面积不算大.挺方便附近的居民购买的. 生活用品也比较齐全.价格适用中.",
]
output_texts = vae_model.simulate_batch(encoder_tokenizer, decoder_tokenizer, input_texts)
print(output_texts)
fengshen/examples/FastDemo/README.md
ADDED
@@ -0,0 +1,105 @@
# Build a quick demo for your model with "streamlit"

Before building the demo, a few things need to be in place:
- the model has finished training
- the model's input parameters are settled
- the streamlit library is installed; `pip install streamlit` is all it takes.

A streamlit script is launched with `streamlit run demo.py`, which starts a demo page right away, and the page refreshes in real time as the script code changes. So if you have no experience yet, create a `demo.py` file, add code step by step following the tutorial below, and watch how the page renders. Now for the practical part; the details are explained in the code comments!

### Step 1: imports
```python
import streamlit as st
# import other packages as your demo needs
```
[streamlit](https://streamlit.io) is a Python framework for building machine learning, deep learning and data visualization demos. It requires no web development experience: if you can write Python, you can build your demo efficiently.

### Step 2: page metadata and layout configuration

```python
st.set_page_config(
    page_title="余元医疗问答",  # browser tab title
    page_icon=":shark:",  # browser tab icon
    layout="wide",  # page layout
    initial_sidebar_state="expanded",  # initial layout of the left sidebar
    # entries shown behind the page menu button
    menu_items={
        'Get Help': 'https://www.extremelycoolapp.com/help',
        'Report a bug': "https://www.extremelycoolapp.com/bug",
        'About': "# This is a header. This is an *extremely* cool app!"
    }
)
```
This step can be skipped; add these settings if you want a more personalized app.

### Step 3: set the demo title
```python
st.title('Demo for MedicalQA')
```
Every streamlit widget comes with a sensible default style on the page.

### Step 4: configure the demo parameters

```python
# the sidebar is used here as the parameter-configuration panel
st.sidebar.header("参数配置")
# create a form inside the sidebar; every form needs a title and a submit button
sbform = st.sidebar.form("固定参数设置")
# slider is a slider widget for numeric parameters
n_sample = sbform.slider("设置返回条数",min_value=1,max_value=10,value=3)
text_length = sbform.slider('生成长度:',min_value=32,max_value=512,value=64,step=32)
text_level = sbform.slider('文本多样性:',min_value=0.1,max_value=1.0,value=0.9,step=0.1)
# number_input can also be used for numeric parameters
model_id = sbform.number_input('选择模型号:',min_value=0,max_value=13,value=13,step=1)
# selectbox is a selection widget; only the configured options can be chosen
trans = sbform.selectbox('选择翻译内核',['百度通用','医疗生物'])
# submit the form; the parameter values only take effect after submission
sbform.form_submit_button("提交配置")

# parameter configuration in the main page, also one of the main parts of the demo
form = st.form("参数设置")
# this is a QA demo, so the user's question is collected with the text_input widget
input_text = form.text_input('请输入你的问题:',value='',placeholder='例如:糖尿病的症状有哪些?')
form.form_submit_button("提交")
```
That completes the basic parameter configuration of the demo.

### Step 5: model prediction
```python
# define a forward-prediction function
# @st.cache(suppress_st_warning=True)
def generate_qa(input_text,n_sample,model_id='7',length=64,translator='baidu',level=0.7):
    # here the model is served as an API built with FastAPI
    URL = 'http://192.168.190.63:6605/qa'
    data = {
        "text":input_text,"n_sample":n_sample,
        "model_id":model_id,"length":length,
        'translator':translator,'level':level
    }
    r = requests.get(URL,params=data)
    return r.text
# model prediction results
results = generate_qa(input_text,n_sample,model_id=str(model_id),
                      translator=translator,length=text_length,level=text_level)
```
A note here: because the machine hosting the demo has no GPU, the model is deployed in the background with FastAPI. If the demo machine can host the model directly, the prediction code can simply live here, with no separate deployment and no API call. One thing to watch out for in that case: every streamlit run executes the script from top to bottom, so the model could be loaded over and over again. That is what the `st.cache` widget is for: when nothing has changed, the result of this step is cached instead of recomputed, so efficiency does not suffer (see the short caching sketch after this file).

### Step 6: show the results
```python
with st.spinner('老夫正在思考中🤔...'):
    if input_text:
        results = generate_qa(input_text,n_sample,model_id=str(model_id),
                              translator=translator,length=text_length,level=text_level)
        for idx,item in enumerate(eval(results),start=1):
            st.markdown(f"""
            **候选回答「{idx}」:**\n
            """)
            st.info('中文:%s'%item['fy_next_sentence'])
            st.info('英文:%s'%item['next_sentence'])
```
streamlit has a rich set of widgets for displaying content in different formats; for text you can use the `st.markdown` widget as well as `st.text` and `st.write`. More widgets and features are covered in the official docs: https://docs.streamlit.io

With that, a complete demo is done. It looks like this:

![](./image/demo.png)

The full code is available at: `Fengshenbang-LM/fengshen/examples/FastDemo/YuyuanQA.py`
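As a supplement to the caching note in step 5 above, here is a minimal sketch of wrapping an expensive model load in `st.cache` so reruns of the script reuse the cached object. This sketch is not part of the repository: the model path is a placeholder, and it assumes the legacy `st.cache` decorator available when this README was written (newer streamlit versions provide `st.cache_resource` instead).

```python
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer


@st.cache(allow_output_mutation=True)  # cache the loaded objects across script reruns
def load_model(model_path):
    # model_path is hypothetical; point it at your own checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)
    return tokenizer, model


tokenizer, model = load_model("your-model-or-path")  # only executed once per argument value
```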
fengshen/examples/FastDemo/YuyuanQA.py
ADDED
@@ -0,0 +1,71 @@
import requests
import langid
import streamlit as st
from translate import baiduTranslatorMedical
from translate import baiduTranslator

langid.set_languages(['en', 'zh'])
lang_dic = {'zh': 'en', 'en': 'zh'}

st.set_page_config(
    page_title="余元医疗问答",
    page_icon=":shark:",
    # layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        'Get Help': 'https://www.extremelycoolapp.com/help',
        'Report a bug': "https://www.extremelycoolapp.com/bug",
        'About': "# This is a header. This is an *extremely* cool app!"
    }
)
st.title('Demo for MedicalQA')


st.sidebar.header("参数配置")
sbform = st.sidebar.form("固定参数设置")
n_sample = sbform.slider("设置返回条数", min_value=1, max_value=10, value=3)
text_length = sbform.slider('生成长度:', min_value=32, max_value=512, value=64, step=32)
text_level = sbform.slider('文本多样性:', min_value=0.1, max_value=1.0, value=0.9, step=0.1)
model_id = sbform.number_input('选择模型号:', min_value=0, max_value=13, value=13, step=1)
trans = sbform.selectbox('选择翻译内核', ['百度通用', '医疗生物'])
sbform.form_submit_button("配置")


form = st.form("参数设置")
input_text = form.text_input('请输入你的问题:', value='', placeholder='例如:糖尿病的症状有哪些?')
if trans == '百度通用':
    translator = 'baidu_common'
else:
    translator = 'baidu'
if input_text:
    lang = langid.classify(input_text)[0]
    if translator == 'baidu':
        st.write('**你的问题是:**', baiduTranslatorMedical(input_text, src=lang, dest=lang_dic[lang]).text)
    else:
        st.write('**你的问题是:**', baiduTranslator(input_text, src=lang, dest=lang_dic[lang]).text)

form.form_submit_button("提交")

# @st.cache(suppress_st_warning=True)


def generate_qa(input_text, n_sample, model_id='7', length=64, translator='baidu', level=0.7):
    # st.write('调用了generate函数')
    URL = 'http://192.168.190.63:6605/qa'
    data = {"text": input_text, "n_sample": n_sample, "model_id": model_id,
            "length": length, 'translator': translator, 'level': level}
    r = requests.get(URL, params=data)
    return r.text
# my_bar = st.progress(80)


with st.spinner('老夫正在思考中🤔...'):
    if input_text:
        results = generate_qa(input_text, n_sample, model_id=str(model_id),
                              translator=translator, length=text_length, level=text_level)
        for idx, item in enumerate(eval(results), start=1):
            st.markdown(f"""
            **候选回答「{idx}」:**\n
            """)
            st.info('中文:%s' % item['fy_next_sentence'])
            st.info('英文:%s' % item['next_sentence'])
fengshen/examples/FastDemo/image/demo.png
ADDED
fengshen/examples/GAVAE/generate.py
ADDED
@@ -0,0 +1,23 @@
import torch
from transformers import BertTokenizer, T5Tokenizer
from fengshen.models.GAVAE.GAVAEModel import GAVAEModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder_tokenizer = BertTokenizer.from_pretrained("IDEA-CCNL/Randeng-GAVAE-1.2B-Augmentation-Chinese")
decoder_tokenizer = T5Tokenizer.from_pretrained(
    "IDEA-CCNL/Randeng-GAVAE-1.2B-Augmentation-Chinese",
    eos_token='<|endoftext|>', pad_token='<pad>', extra_ids=0)
decoder_tokenizer.add_special_tokens({'bos_token': '<bos>'})
input_texts = [
    "非常好的一个博物馆,是我所有去过的博物馆里感觉最正规的一家,凭有效证件可以入馆,可以自助免费存小件物品,讲解员和馆内外的工作人员也非常认真,其他的服务人员也很热情,非常好的!馆内的藏品也让人非常震撼!希望继续保持~",
    "这是我来长沙最最期待的一定要去的地方,总算今天特地去瞻仰千古遗容了,开车到门口大屏幕显示着门票已发完的字样,心里一惊以为今天是白来了。但进了停车场才知道凭停车卡和有效身份证里面也能领,停车还不花钱,真好。",
    "地方很大 很气派~~可以逛很久~~~去的时候是免费的~不过要安检~~~里面的马王堆~幸追夫人~还是很不错的~~~~去的时候有一个吴越文化特别展~~~东西也很多~~~~~很好看",
    "我们到达的时候是下午3点,门票已经发完了。当时正焦虑的不知道怎么办才好,门卫大哥给我们俩补办了门票,这才得以入馆。非常感谢!绝对不虚此行!相当震撼的展览!原来古人也化妆,还有假发。记忆最深的是那个藕汤。可惜真颜已不得见。",
    "去过三次,个人认为这是长沙最值得去的地方,博物馆的重点就是辛追,遗憾的是,每次去我都会感到悲哀,虽然我三次去的时候都要门票,但是每次看到辛追,都觉得现代的人类不应该挖她出来,除了第一次我觉得辛追像刚死去一样,后来两次我觉得太惨不忍睹了。建议大家要去就早去,以后肯定越来越腐烂",
    "上大学时候去的,当时学生证是半价25,后来凭有效证件就不要钱了。非常喜欢的一家博物馆,里面可看的东西很多,当然最吸引我的就是那个辛追夫人和“素纱单衣”,果然不是盖的~里面的讲解员大部分都是师大学历史类的,非常专业和有耐心。虽然不在长沙了,不过对那里还是很有感情的,赞~~~",
    "这两年也有很多机会去博物馆。。。不过还是想说湖南省博物馆是非常有特色的。。。应该说整个展览分成两个部分吧。。。一个部分是马王堆的主体展。。。另一个就是湖南的一些考古发现。。。其实来省博大部分的游客还是冲着马王堆来的吧。。。博物馆也很有心的为每一批游客安排了讲解员。。。从马王堆的发现到马王堆出土文物的介绍再到最后棺木和辛追的介绍。。。真是上了一节很生动的历史课。",
    "网上订票去的,还是很顺利的就进去了,里面挺清净的,外围的环境也不错,还有鸽子可以喂。那天不是很闹,兜了一圈感觉还是很顺畅的,老娘娘和金缕玉衣挺震撼的。到此一游还是挺需要的",
]
gavae_model = GAVAEModel.from_pretrained("IDEA-CCNL/Randeng-GAVAE-1.2B-Augmentation-Chinese").to(device)
gavae_model.train_gan(encoder_tokenizer, decoder_tokenizer, input_texts)
# n: number of samples to generate
texts = gavae_model.generate(n=5)
print(texts)
fengshen/examples/PPVAE/generate.py
ADDED
@@ -0,0 +1,24 @@
import torch
from transformers import BertTokenizer, T5Tokenizer
from fengshen.models.PPVAE.pluginVAE import PPVAEModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder_tokenizer = BertTokenizer.from_pretrained("IDEA-CCNL/Randeng-PPVAE-1.2B-Augmentation-Chinese")
decoder_tokenizer = T5Tokenizer.from_pretrained(
    "IDEA-CCNL/Randeng-PPVAE-1.2B-Augmentation-Chinese",
    eos_token='<|endoftext|>', pad_token='<pad>', extra_ids=0)
decoder_tokenizer.add_special_tokens({'bos_token': '<bos>'})
ppvae_model = PPVAEModel.from_pretrained("IDEA-CCNL/Randeng-PPVAE-1.2B-Augmentation-Chinese").to(device)
input_texts = [
    "非常好的一个博物馆,是我所有去过的博物馆里感觉最正规的一家,凭有效证件可以入馆,可以自助免费存小件物品,讲解员和馆内外的工作人员也非常认真,其他的服务人员也很热情,非常好的!馆内的藏品也让人非常震撼!希望继续保持~",
    "这是我来长沙最最期待的一定要去的地方,总算今天特地去瞻仰千古遗容了,开车到门口大屏幕显示着门票已发完的字样,心里一惊以为今天是白来了。但进了停车场才知道凭停车卡和有效身份证里面也能领,停车还不花钱,真好。",
    "地方很大 很气派~~可以逛很久~~~去的时候是免费的~不过要安检~~~里面的马王堆~幸追夫人~还是很不错的~~~~去的时候有一个吴越文化特别展~~~东西也很多~~~~~很好看",
    "我们到达的时候是下午3点,门票已经发完了。当时正焦虑的不知道怎么办才好,门卫大哥给我们俩补办了门票,这才得以入馆。非常感谢!绝对不虚此行!相当震撼的展览!原来古人也化妆,还有假发。记忆最深的是那个藕汤。可惜真颜已不得见。",
    "去过三次,个人认为这是长沙最值得去的地方,博物馆的重点就是辛追,遗憾的是,每次去我都会感到悲哀,虽然我三次去的时候都要门票,但是每次看到辛追,都觉得现代的人类不应该挖她出来,除了第一次我觉得辛追像刚死去一样,后来两次我觉得太惨不忍睹了。建议大家要去就早去,以后肯定越来越腐烂",
    "上大学时候去的,当时学生证是半价25,后来凭有效证件就不要钱了。非常喜欢的一家博物馆,里面可看的东西很多,当然最吸引我的就是那个辛追夫人和“素纱单衣”,果然不是盖的~里面的讲解员大部分都是师大学历史类的,非常专业和有耐心。虽然不在长沙了,不过对那里还是很有感情的,赞~~~",
    "这两年也有很多机会去博物馆。。。不过还是想说湖南省博物馆是非常有特色的。。。应该说整个展览分成两个部分吧。。。一个部分是马王堆的主体展。。。另一个就是湖南的一些考古发现。。。其实来省博大部分的游客还是冲着马王堆来的吧。。。博物馆也很有心的为每一批游客安排了讲解员。。。从马王堆的发现到马王堆出土文物的介绍再到最后棺木和辛追的介绍。。。真是上了一节很生动的历史课。",
    "网上订票去的,还是很顺利的就进去了,里面挺清净的,外围的环境也不错,还有鸽子可以喂。那天不是很闹,兜了一圈感觉还是很顺畅的,老娘娘和金缕玉衣挺震撼的。到此一游还是挺需要的",
]

ppvae_model.train_plugin(encoder_tokenizer, decoder_tokenizer, input_texts, negative_samples=None)
# n: number of samples to generate
texts = ppvae_model.generate(n=5)
print(texts)
fengshen/examples/classification/demo_classification_afqmc_erlangshen_offload.sh
ADDED
@@ -0,0 +1,103 @@
MODEL_NAME="IDEA-CCNL/Erlangshen-MegatronBert-1.3B"

TEXTA_NAME=sentence1
TEXTB_NAME=sentence2
LABEL_NAME=label
ID_NAME=id

BATCH_SIZE=1
VAL_BATCH_SIZE=1
ZERO_STAGE=3
config_json="./ds_config.json"

cat <<EOT > $config_json
{
  "train_micro_batch_size_per_gpu": $BATCH_SIZE,
  "steps_per_print": 1000,
  "gradient_clipping": 1,
  "zero_optimization": {
    "stage": ${ZERO_STAGE},
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    },
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9
  },
  "zero_allow_untested_optimizer": false,
  "fp16": {
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "activation_checkpointing": {
    "partition_activations": false,
    "contiguous_memory_optimization": false
  },
  "wall_clock_breakdown": false
}
EOT

export PL_DEEPSPEED_CONFIG_PATH=$config_json

DATA_ARGS="\
        --dataset_name IDEA-CCNL/AFQMC \
        --train_batchsize $BATCH_SIZE \
        --valid_batchsize $VAL_BATCH_SIZE \
        --max_length 128 \
        --texta_name $TEXTA_NAME \
        --textb_name $TEXTB_NAME \
        --label_name $LABEL_NAME \
        --id_name $ID_NAME \
        "

MODEL_ARGS="\
        --learning_rate 1e-5 \
        --weight_decay 1e-1 \
        --warmup_ratio 0.01 \
        --num_labels 2 \
        --model_type huggingface-auto \
        "

MODEL_CHECKPOINT_ARGS="\
        --monitor val_acc \
        --save_top_k 3 \
        --mode max \
        --every_n_train_steps 0 \
        --save_weights_only True \
        --dirpath . \
        --filename model-{epoch:02d}-{val_acc:.4f} \
        "


TRAINER_ARGS="\
        --max_epochs 67 \
        --gpus 1 \
        --num_nodes 1 \
        --strategy deepspeed_stage_${ZERO_STAGE}_offload \
        --gradient_clip_val 1.0 \
        --check_val_every_n_epoch 1 \
        --val_check_interval 1.0 \
        --precision 16 \
        --default_root_dir . \
        "

options=" \
        --pretrained_model_path $MODEL_NAME \
        $DATA_ARGS \
        $MODEL_ARGS \
        $MODEL_CHECKPOINT_ARGS \
        $TRAINER_ARGS \
        "

python3 finetune_classification.py $options