Upload 27 files

- .gitattributes +1 -0
- LICENSE +201 -0
- README.md +109 -3
- README_zh.md +0 -0
- checkpoints/MSAGPT-DPO/1/mp_rank_00_model_states.pt +3 -0
- checkpoints/MSAGPT-DPO/latest +1 -0
- checkpoints/MSAGPT-DPO/model_config.json +16 -0
- checkpoints/MSAGPT-SFT/1/mp_rank_00_model_states.pt +3 -0
- checkpoints/MSAGPT-SFT/latest +1 -0
- checkpoints/MSAGPT-SFT/model_config.json +16 -0
- checkpoints/MSAGPT/1/mp_rank_00_model_states.pt +3 -0
- checkpoints/MSAGPT/latest +1 -0
- checkpoints/MSAGPT/model_config.json +16 -0
- cli_sat.py +136 -0
- model_utils/__init__.py +2 -0
- model_utils/model_msagpt.py +30 -0
- model_utils/model_proteinglm_clm.py +428 -0
- msa_input +4 -0
- requirements.txt +3 -0
- resources/app_case.png +0 -0
- resources/demo.gif +3 -0
- resources/overall_frame.png +0 -0
- scripts/cli_sat.sh +62 -0
- utils/__init__.py +4 -0
- utils/chat.py +371 -0
- utils/strategies.py +229 -0
- utils/tokenization.py +213 -0
- utils/utils.py +7 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+resources/demo.gif filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,201 @@
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,3 +1,109 @@
# MSAGPT

<table>
<tr>
<td>
<h2>MSAGPT</h2>
<p>📖 Paper: <a href="xxx">MSAGPT: Neural Prompting Protein Structure Prediction via MSA Generative Pre-Training</a></p>
<p><b>MSAGPT</b> is a powerful protein language model (PLM) with 3 billion parameters. It is released in three versions (MSAGPT, MSAGPT-SFT, and MSAGPT-DPO), <b>supporting zero-shot and few-shot MSA generation</b>.</p>
<p><b>MSAGPT achieves state-of-the-art structure prediction performance in MSA-scarce scenarios</b>.</p>
</td>
</tr>
</table>


## Overall Framework
<p align="center">
<img src="resources/overall_frame.png" alt="Overall framework of MSAGPT" style="display: block; margin: auto; width: 90%;">
</p>

## Visualized Cases
Visualization of improved structure prediction compared with natural MSA.
<font color=orange>Yellow</font>: Ground truth;
<font color=purple>Purple</font>: Predictions based on MSA generated by MSAGPT;
<font color=cyan>Cyan</font>: Predictions based on natural MSA.

<p align="center">
<img src="resources/app_case.png" alt="Visualized structure prediction cases" style="display: block; margin: auto; width: 90%;">
</p>


## Get Started

### Option 1: Deploy MSAGPT by yourself

We currently support CLI-based model inference.

First, install the dependencies.

```bash
# CUDA >= 11.8
pip install -r requirements.txt
```

#### Model List
You can manually download the necessary weights, then unzip them and place them in the **checkpoints** folder.

| Model      | Type | Seq Length | Download |
|------------|------|------------|----------|
| MSAGPT     | Base | 16K        | [🤗 Huggingface](https://cloud.tsinghua.edu.cn/f/ebfc954a4cd24cef9243/?dl=1) [🔨 SwissArmyTransformer](https://cloud.tsinghua.edu.cn/f/ebfc954a4cd24cef9243/?dl=1) |
| MSAGPT-SFT | SFT  | 16K        | [🤗 Huggingface](https://cloud.tsinghua.edu.cn/f/ebfc954a4cd24cef9243/?dl=1) [🔨 SwissArmyTransformer](https://cloud.tsinghua.edu.cn/f/32da3eadf6e042aab2fa/?dl=1) |
| MSAGPT-DPO | RLHF | 16K        | [🤗 Huggingface](https://cloud.tsinghua.edu.cn/f/ebfc954a4cd24cef9243/?dl=1) [🔨 SwissArmyTransformer](https://cloud.tsinghua.edu.cn/f/ebfc954a4cd24cef9243/?dl=1) |

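After unzipping, each model directory should match the checkpoint layout shipped in this repository, shown here for MSAGPT-DPO:

```
checkpoints/MSAGPT-DPO/
├── latest                          # iteration tag, here "1"
├── model_config.json               # SwissArmyTransformer model configuration
└── 1/
    └── mp_rank_00_model_states.pt  # model weights (~5.7 GB, tracked with Git LFS)
```
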
#### Situation 1.1 CLI (SAT version)

Run the CLI demo via:

```bash
# Online Chat
bash scripts/cli_sat.sh --from_pretrained ./checkpoints/MSAGPT-DPO --input-source chat --stream_chat --max-gen-length 1024
```

The program runs interactively in the command line. To generate virtual MSAs, enter the protein sequence (optionally followed by a few MSAs as a prompt, connected by "\<M\>") and press Enter. For example, in "PEGKQGDPGIPGEPGPPGPPGPQGARGPPG\<M\>VTVEFVNSCLIGDMGVDGPPGQQGQPGPPG", "PEGKQGDPGIPGEPGPPGPPGPQGARGPPG" is the main sequence and "VTVEFVNSCLIGDMGVDGPPGQQGQPGPPG" is the MSA prompt. Enter `stop` to end the program. The chat CLI looks like:
<p align="center">
<img src="resources/demo.gif" alt="Chat CLI demo" style="display: block; margin: auto; width: 90%;">
</p>


You can also enable offline generation by setting **--input-source \<your input file\>** and **--output-path \<your output path\>**.
We provide an example input file, *msa_input*: one query sequence per line, optionally followed by few-shot MSA prompts joined with "\<M\>".
```bash
# Offline Generation
bash scripts/cli_sat.sh --from_pretrained ./checkpoints/MSAGPT-DPO --input-source <your input file> --output-path <your output path> --max-gen-length 1024
```

#### Situation 1.2 CLI (Huggingface version)
(TODO)

#### Situation 1.3 Web Demo
(TODO)

### Option 2: Finetuning MSAGPT

(TODO)

### Hardware requirement

* Model Inference:
For BF16: 1 * A100 (80G)

* Finetuning:

For BF16: 4 * A100 (80G) *[Recommended]*.


## License

The code in this repository is open source under the [Apache-2.0 license](./LICENSE).

If you find our work helpful, please consider citing our paper:

```
@article{chen2024msagpt,
  title={MSAGPT: Neural Prompting Protein Structure Prediction via MSA Generative Pre-Training},
  author={Chen, Bo and Bei, Zhilei and Cheng, Xingyi and Li, Pan and Tang, Jie and Song, Le},
  journal={arXiv preprint arXiv:2406.05347},
  year={2024}
}
```
README_zh.md
ADDED
File without changes
|
checkpoints/MSAGPT-DPO/1/mp_rank_00_model_states.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f3507871a00564c0be3a697678f521eccee2efb2d77577b0bc009d766b8f02a4
size 5721204666
checkpoints/MSAGPT-DPO/latest
ADDED
@@ -0,0 +1 @@
1
checkpoints/MSAGPT-DPO/model_config.json
ADDED
@@ -0,0 +1,16 @@
{
    "model_class": "MSAGPT",
    "tokenizer_type": "ProteinTokenizer",
    "num_layers": 36,
    "hidden_size": 2560,
    "inner_hidden_size": 6832,
    "num_attention_heads": 40,
    "vocab_size": 128,
    "layernorm_order": "post",
    "model_parallel_size": 1,
    "max_sequence_length": 2048,
    "untie_head": true,
    "head_num": 2,
    "moe": false,
    "expert": 1
}
checkpoints/MSAGPT-SFT/1/mp_rank_00_model_states.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:19b7a79194615affec18617b2854602f2b77f053b80b44b31f6fd79bfb38ae68
size 5721204666
checkpoints/MSAGPT-SFT/latest
ADDED
@@ -0,0 +1 @@
1
checkpoints/MSAGPT-SFT/model_config.json
ADDED
@@ -0,0 +1,16 @@
{
    "model_class": "MSAGPT",
    "tokenizer_type": "ProteinTokenizer",
    "num_layers": 36,
    "hidden_size": 2560,
    "inner_hidden_size": 6832,
    "num_attention_heads": 40,
    "vocab_size": 128,
    "layernorm_order": "post",
    "model_parallel_size": 1,
    "max_sequence_length": 2048,
    "untie_head": true,
    "head_num": 2,
    "moe": false,
    "expert": 1
}
checkpoints/MSAGPT/1/mp_rank_00_model_states.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:daaec07dca52dda4eaee8442d02c9c0f821a5e8ad81cbd280490f50f8f16e205
size 5721204666
checkpoints/MSAGPT/latest
ADDED
@@ -0,0 +1 @@
1
checkpoints/MSAGPT/model_config.json
ADDED
@@ -0,0 +1,16 @@
{
    "model_class": "MSAGPT",
    "tokenizer_type": "ProteinTokenizer",
    "num_layers": 36,
    "hidden_size": 2560,
    "inner_hidden_size": 6832,
    "num_attention_heads": 40,
    "vocab_size": 128,
    "layernorm_order": "post",
    "model_parallel_size": 1,
    "max_sequence_length": 2048,
    "untie_head": true,
    "head_num": 2,
    "moe": false,
    "expert": 1
}
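The three configs above are identical. As a rough sanity check (a back-of-the-envelope sketch that ignores bias and layer-norm parameters), the config is consistent with the "3 billion parameters" quoted in the README:

```python
# Approximate parameter count implied by model_config.json (estimate only).
layers, h, inner, vocab, untie_heads = 36, 2560, 6832, 128, 2

attn = 3 * h * h + h * h                   # QKV projection + output projection
mlp = h * (2 * inner) + inner * h          # GEGLU up-projection (doubled width) + down-projection
total = layers * (attn + mlp) + vocab * h + untie_heads * (h * 2 * h)  # + embedding + untied output heads
print(f"~{total / 1e9:.2f}B parameters")   # ~2.86B
```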
cli_sat.py
ADDED
@@ -0,0 +1,136 @@
import os
import torch
import stat
import re
import time
import argparse
import numpy as np

from functools import partial
from typing import List, Tuple

import torch.distributed as dist
from sat.helpers import print_rank0
from sat import mpu, get_args, get_tokenizer
from utils import AdvancedBaseStrategy, BeamSearchStrategy
from model_utils import MSAGPT, FineTuneMSAGPT
from utils import chat_api


if __name__ == "__main__":
    py_parser = argparse.ArgumentParser(add_help=False)
    py_parser.add_argument("--sampling-strategy", type=str, default="BaseStrategy", help="Type of sampling strategy.")
    py_parser.add_argument("--min-gen-length", type=int, default=0, help="The minimum length each blank should generate.")
    py_parser.add_argument("--max-gen-length", type=int, default=512, help="The maximum length each blank should generate.")
    py_parser.add_argument("--is-valid", action="store_true", help="Validation mode.")
    py_parser.add_argument("--print-all-beams", action="store_true", help="Print all output generated by the beam search strategy.")
    py_parser.add_argument("--multiline_stream", action="store_true", help="Streaming multiline output.")
    py_parser.add_argument("--no-gap", action="store_true", help="Do not generate gaps.")
    py_parser.add_argument("--from_pretrained", type=str, default="./checkpoints/MSAGPT", help="Pretrained checkpoint.")
    py_parser.add_argument("--chinese", action="store_true", help="Chinese interface.")
    py_parser.add_argument("--stream_chat", action="store_true", help="Streaming output.")

    py_parser = MSAGPT.add_model_specific_args(py_parser)
    known, args_list = py_parser.parse_known_args()
    args = get_args(args_list)
    args = argparse.Namespace(**vars(args), **vars(known))
    model, args = MSAGPT.from_pretrained(args.from_pretrained, args, overwrite_args={'model_parallel_size': args.model_parallel_size} if args.model_parallel_size != 1 else {})
    model.eval()
    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    if torch.cuda.is_available():
        model = model.to('cuda')
    from utils import proteinglm_tokenizer
    tokenizer = proteinglm_tokenizer()

    end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
    # Get rid of all invalid tokens
    invalid_slices = [0, 26, 28, 29, 30, 31, 32]
    if args.no_gap:
        invalid_slices.append(tokenizer.TokenToId('-'))
    if args.sampling_strategy == "BaseStrategy":
        assert not args.print_all_beams, "BaseStrategy doesn't support printing all beams."
        strategy = AdvancedBaseStrategy(
            batch_size=1, invalid_slices=invalid_slices, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, min_gen_length=args.min_gen_length, no_repeat_ngram_size=args.no_repeat_ngram_size, end_tokens=end_tokens
        )
    elif args.sampling_strategy == "BeamSearchStrategy":
        strategy = BeamSearchStrategy(
            1,
            args.num_beams,
            length_penalty=args.length_penalty,
            consider_end=True,
            end_tokens=end_tokens,
            invalid_slices=invalid_slices,
            no_repeat_ngram_size=args.no_repeat_ngram_size,
            min_gen_length=args.min_gen_length,
            deterministic=True
        )
    else:
        raise ValueError(f"unknown strategy {args.sampling_strategy}")

    if args.input_source == 'chat':
        if args.chinese:
            if rank == 0:
                print('欢迎使用 MSAGPT-CLI ,输入需要生成虚拟MSA的蛋白序列(或加上少量MSA作为prompt,以"<M>"相连),例如:"PEGKQGDPGIPGEPGPPGPPGPQGARGPPG<M>VTVEFVNSCLIGDMGVDGPPGQQGQPGPPG",其中"PEGKQGDPGIPGEPGPPGPPGPQGARGPPG"为主序列,"VTVEFVNSCLIGDMGVDGPPGQQGQPGPPG"为MSA prompt。 stop 终止程序'.center(20, "*"))
        else:
            if rank == 0:
                print('Welcome to MSAGPT-CLI. Enter the protein sequence you need to generate virtual MSAs (or add a few MSAs as a prompt, connected by "<M>"), for example: "PEGKQGDPGIPGEPGPPGPPGPQGARGPPG<M>VTVEFVNSCLIGDMGVDGPPGQQGQPGPPG", where "PEGKQGDPGIPGEPGPPGPPGPQGARGPPG" is the main sequence, and "VTVEFVNSCLIGDMGVDGPPGQQGQPGPPG" are MSA prompts. Type "stop" to end the program.'.center(20, "*"))
        with torch.no_grad():
            while True:
                if args.chinese:
                    if rank == 0:
                        protein_input = input("请输入需要生成虚拟MSA的蛋白序列(或加上少量MSA作为prompt,以'<M>'相连):")
                    else:
                        protein_input = None
                else:
                    if rank == 0:
                        protein_input = input("Enter the protein sequence you need to generate virtual MSAs (or add a few MSAs as a prompt, connected by '<M>'): ")
                    else:
                        protein_input = None
                if world_size > 1:
                    # broadcast_object_list is the collective available in torch.distributed;
                    # wrap the string in a list so non-zero ranks receive the value.
                    obj_list = [protein_input]
                    torch.distributed.broadcast_object_list(obj_list, src=0)
                    protein_input = obj_list[0]
                assert protein_input is not None
                protein_input = protein_input.strip()

                if protein_input == 'stop':
                    break

                try:
                    response = chat_api(
                        args=args,
                        query=protein_input,
                        model=model,
                        tokenizer=tokenizer,
                        strategy=strategy
                    )
                except Exception as e:
                    print(e)
                    break
                if rank == 0 and not args.stream_chat:
                    if args.chinese:
                        print(f"{'生成的MSA'.center(20, '*')}")
                    else:
                        print(f"{'Virtual MSA'.center(20, '*')}")
                    if args.print_all_beams:
                        for idx, gen in enumerate(response):
                            out_str = f"Beam: {idx}".center(11, '@')
                            print(out_str)
                            for _ in gen:
                                print(_)
                            print()
                    else:
                        response = response[0]
                        for _ in response:
                            print(_)
                        print()
    else:
        chat_api(
            args=args,
            model=model,
            tokenizer=tokenizer,
            strategy=strategy
        )
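The `invalid_slices` list above bans specific vocabulary ids (and the gap symbol `-` when `--no-gap` is set) during sampling. A self-contained sketch of how such a ban is typically applied inside a sampling strategy (an assumption about `AdvancedBaseStrategy`'s internals, not code from this repository):

```python
import torch

# One step of next-token logits over the 128-token vocabulary.
logits = torch.randn(1, 128)
invalid_slices = [0, 26, 28, 29, 30, 31, 32]

# Mask banned ids to -inf so they receive zero probability after softmax.
logits[..., invalid_slices] = float("-inf")
probs = torch.softmax(logits, dim=-1)
assert probs[..., invalid_slices].sum() == 0
next_token = torch.multinomial(probs, num_samples=1)
```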
model_utils/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .model_proteinglm_clm import ProteinGLMForGeneration
from .model_msagpt import MSAGPT, FineTuneMSAGPT
model_utils/model_msagpt.py
ADDED
@@ -0,0 +1,30 @@
import math
import copy
import torch
from torch.nn import functional as F
import torch.nn as nn

from .model_proteinglm_clm import ProteinGLMForGeneration


class MSAGPT(ProteinGLMForGeneration):
    def __init__(self, args, transformer=None, **kwargs):
        super().__init__(
            args,
            transformer=transformer,
            **kwargs
        )

    @classmethod
    def add_model_specific_args(cls, parser):
        group = parser.add_argument_group('MSAGPT-inference', 'MSAGPT inference Configurations')
        return super().add_model_specific_args(parser)

class FineTuneMSAGPT(MSAGPT):
    def __init__(self, args, transformer=None, **kwargs):
        super().__init__(
            args,
            transformer=transformer,
            **kwargs
        )
        pass
model_utils/model_proteinglm_clm.py
ADDED
@@ -0,0 +1,428 @@
import math
import copy
import torch
from torch.nn import functional as F
import torch.nn as nn
import contextlib

from sat import mpu
from sat.transformer_defaults import standard_attention, attention_fn_default
from sat.mpu.utils import split_tensor_along_last_dim, divide
from sat.mpu.layers import ColumnParallelLinear
from sat.model.base_model import BaseModel, BaseMixin
from sat.model.position_embedding import RotaryEmbedding
from sat.model.position_embedding import apply_rotary_pos_emb_index
from sat.ops import LayerNorm


class RotaryEmbeddingMixin(BaseMixin):
    def __init__(
        self,
        fp16,
        hidden_size,
        num_attention_heads,
        model_parallel_size,
        rotary_embedding_2d=True,
    ):
        super().__init__()
        hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
        self.hidden_size_per_attention_head = hidden_size_per_attention_head
        self.rotary_embedding_2d = rotary_embedding_2d
        self.num_attention_heads_per_partition = divide(num_attention_heads, model_parallel_size)
        self.rotary_emb = RotaryEmbedding(
            # hidden_size_per_attention_head,
            hidden_size_per_attention_head // 2
            if rotary_embedding_2d
            else hidden_size_per_attention_head,
            base=10000,
            precision=torch.half if fp16 else torch.bfloat16,
            learnable=False,
            device=torch.cuda.current_device(),
        )

    def attention_forward(self, hidden_states, mask, **kw_args):
        attn = self.transformer.layers[kw_args["layer_id"]].attention
        attention_fn = attention_fn_default
        if "attention_fn" in attn.hooks:
            attention_fn = attn.hooks["attention_fn"]

        # [seq, b, 3 * hn * np]
        mixed_raw_layer = attn.query_key_value(hidden_states)

        # [seq, b, (np * 3 * hn)] --> [seq, b, np, 3 * hn]
        new_tensor_shape = mixed_raw_layer.size()[:-1] + (
            self.num_attention_heads_per_partition,
            3 * self.hidden_size_per_attention_head,
        )
        mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape)

        # [sq, b, np, hn]
        (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
        # print(key_layer.shape)
        dropout_fn = attn.attention_dropout if attn.training else None
        if self.rotary_embedding_2d:
            q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1))
            k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1))
            cos, sin = self.rotary_emb(q1, seq_len=kw_args["position_ids"].max() + 1)
            position_ids, block_position_ids = \
                kw_args["position_ids"][:, 0, :].transpose(0, 1).contiguous(), \
                kw_args["position_ids"][:, 1, :].transpose(0, 1).contiguous()
            q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids)
            q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids)
            query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1))
            key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1))
        else:
            kw_args["position_ids"] = kw_args["position_ids"].transpose(0, 1)
            cos, sin = self.rotary_emb(value_layer, seq_len=kw_args["position_ids"].max() + 1)
            query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, kw_args["position_ids"])

        context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, **kw_args)
        output = attn.dense(context_layer)

        if attn.training:
            output = attn.output_dropout(output)

        return output


class GEGLU(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.activation_fn = F.gelu

    def forward(self, x):
        # dim=-1 breaks in jit for pt<1.10
        x1, x2 = x.chunk(2, dim=(x.ndim - 1))
        return x1 * self.activation_fn(x2)


class DeepNormWithGLUMixin(BaseMixin):
    def __init__(self, num_layers, hidden_size, inner_hidden_size=None):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        if inner_hidden_size is None:
            inner_hidden_size = 4 * hidden_size * 2 // 3
        self.inner_hidden_size = inner_hidden_size

    def reinit(self):
        for layer in self.transformer.layers:
            del layer.mlp.dense_h_to_4h
            layer.mlp.dense_h_to_4h = ColumnParallelLinear(
                self.hidden_size,
                2 * self.inner_hidden_size,
                gather_output=False,
                bias=True,
                params_dtype=torch.half,
                module=self,
                name="dense_h_to_4h",
                skip_init=True,
            )
            del layer.mlp.activation_func
            layer.mlp.activation_func = GEGLU()

    def layer_forward(self, hidden_states, mask, *args, **kw_args):
        """
        hidden_states: [seq_len, batch, hidden_size]
        mask: [(1, 1), seq_len, seq_len]
        """
        layer = self.transformer.layers[kw_args["layer_id"]]
        # Layer norm at the beginning of the transformer layer.

        attention_input = layer.input_layernorm(hidden_states)

        # Self attention.
        attention_output = layer.attention(attention_input, mask, **kw_args)

        # Residual connection.
        alpha = (2 * self.num_layers) ** 0.5
        hidden_states = attention_input * alpha + attention_output

        mlp_input = layer.post_attention_layernorm(hidden_states)

        # MLP.
        mlp_output = layer.mlp(mlp_input, **kw_args)

        # Second residual connection.
        output = mlp_input * alpha + mlp_output

        return output


class SelfAttentionWithFP32SoftmaxMixin(BaseMixin):
    def __init__(self, fp16, hidden_size, num_attention_heads, model_parallel_size):
        super().__init__()
        self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
        self.hidden_size_per_partition = divide(hidden_size, model_parallel_size)
        self.scale_mask_softmax = None
        self.fp16 = fp16

    @staticmethod
    def attention_mask_func(attention_scores, attention_mask):
        attention_scores.masked_fill_(attention_mask, -10000.0)
        return attention_scores

    def attention_fn(
        self,
        query_layer,
        key_layer,
        value_layer,
        attention_mask,
        attention_dropout=None,
        log_attention_weights=None,
        scaling_attention_score=True,
        mems=None,
        **kwargs
    ):

        mem = mems[kwargs["layer_id"]] if mems is not None else None

        # seqlen, batch, head, hidden_size
        seq_len, b, nh, hidden_size = key_layer.shape

        # stack, seqlen, b, head, hidden
        # b, seqlen, stack, head, hidden
        cache_kv = (
            torch.stack((key_layer, value_layer))
            .permute(2, 1, 0, 3, 4)
            .detach()
            .contiguous()
            .view(b, seq_len, nh * hidden_size * 2)
        )
        kwargs["output_this_layer"]["mem_kv"] = cache_kv

        if mem is not None:  # the first time, mem is None
            # might change batch_size
            # b, seqlen, stack, head, hidden -> stack, seqlen, b, head, hidden
            mem = mem.expand(b, -1, -1).reshape(b, mem.shape[1], 2, nh, hidden_size).permute(2, 1, 0, 3, 4)
            memk, memv = mem[0], mem[1]
            key_layer = torch.cat((memk, key_layer), dim=0)
            value_layer = torch.cat((memv, value_layer), dim=0)

        # check if use flash attention
        is_low_triangle = (attention_mask == ~torch.ones_like(attention_mask, dtype=torch.bool).tril()).all()
        is_full = (attention_mask is None) or (attention_mask == 0).all()
        if int(torch.__version__.split('.')[0]) >= 2 and (is_full or is_low_triangle):
            # Pytorch 2.0 attention uses very much memory if attention_mask is float, and has NaN bug if attention_mask is None.
            dropout_p = 0. if attention_dropout is None or not attention_dropout.training else attention_dropout.p
            # [b, np, sq, hn]
            query_layer, key_layer, value_layer = query_layer.permute(1, 2, 0, 3).contiguous(), key_layer.permute(1, 2, 0, 3).contiguous(), value_layer.permute(1, 2, 0, 3).contiguous()
            batch_size, num_query_heads = query_layer.shape[:2]  # [b, np, s, hn]
            num_kv_heads = key_layer.shape[1]  # [b, np, s, hn]
            key_layer = key_layer.unsqueeze(2).expand(-1, -1, num_query_heads // num_kv_heads, -1, -1).contiguous().view(batch_size, num_query_heads, *key_layer.shape[2:])
            value_layer = value_layer.unsqueeze(2).expand(-1, -1, num_query_heads // num_kv_heads, -1, -1).contiguous().view(batch_size, num_query_heads, *value_layer.shape[2:])

            if dropout_p > 0 and mpu.get_cuda_rng_tracker is not None:
                context = mpu.get_cuda_rng_tracker().fork()
            else:
                context = contextlib.nullcontext()

            with context:
                context_layer = torch.nn.functional.scaled_dot_product_attention(
                    query_layer, key_layer, value_layer,
                    attn_mask=None,
                    dropout_p=dropout_p,
                    is_causal=not is_full
                )

            # [sq, b, np, hn]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

            # [sq, b, np, hn] --> [sq, b, hp]
            new_context_layer_shape = context_layer.size()[:-2] + (-1,)
            context_layer = context_layer.view(*new_context_layer_shape)
            return context_layer

        else:
            # standard attention
            # [b, np, sq, sk]
            output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))

            query_key_layer_scaling_coeff = float(kwargs["layer_id"] + 1)

            if scaling_attention_score:
                query_layer = query_layer / (math.sqrt(self.hidden_size_per_attention_head) * query_key_layer_scaling_coeff)
            # ===================================
            # Raw attention scores. [b, np, s, s]
            # ===================================
            # [sq, b, np, hn] -> [sq, b * np, hn]
            query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
            # [sk, b, np, hn] -> [sk, b * np, hn]
            key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

            matmul_result = torch.empty(
                output_size[0] * output_size[1],
                output_size[2],
                output_size[3],
                dtype=query_layer.dtype,
                device=torch.cuda.current_device(),
            )

            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=1.0,
            )

            # change view to [b, np, sq, sk]
            attention_scores = matmul_result.view(*output_size)

            if not (attention_mask.shape[-2] == 1 and (attention_mask > 0).all()):
                # if auto-regressive, skip
                attention_scores.masked_fill_(attention_mask.bool(), -float("inf"))

            attention_scores = attention_scores.float()
            attention_scores = attention_scores * query_key_layer_scaling_coeff

            attention_probs = F.softmax(attention_scores, dim=-1)

            if self.fp16:
                attention_probs = attention_probs.half()
            else:
                attention_probs = attention_probs.bfloat16()

            if attention_dropout is not None:
                if mpu.get_cuda_rng_tracker() is not None:
                    with mpu.get_cuda_rng_tracker().fork():
                        attention_probs = attention_dropout(attention_probs)
                else:
                    attention_probs = attention_dropout(attention_probs)

            # =========================
            # Context layer. [sq, b, hp]
            # =========================

            # value_layer -> context layer.
            # [sk, b, np, hn] --> [b, np, sq, hn]

            # context layer shape: [b, np, sq, hn]
            output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))

            # change view [sk, b * np, hn]
            value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)

            # change view [b * np, sq, sk]
            attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
            # matmul: [b * np, sq, hn]

            context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

            # change view [b, np, sq, hn]
            context_layer = context_layer.view(*output_size)

            # [b, np, sq, hn] --> [sq, b, np, hn]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

            # [sq, b, np, hn] --> [sq, b, hp]
            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
            context_layer = context_layer.view(*new_context_layer_shape)
            return context_layer


class FinalForwardMixin(BaseMixin):
    def __init__(self):
        super().__init__()

    def final_forward(self, logits, **kw_args):
        return F.linear(logits, self.transformer.word_embeddings.weight).transpose(0, 1).contiguous()


class UntieFinalForwardMixin(BaseMixin):
    def __init__(self, hidden_size, vocab_size, untie_head_num, layernorm_epsilon=1.0e-5):
        super().__init__()

        self.lm_head = nn.ModuleList()
        for i in range(untie_head_num):
            self.lm_head.append(
                ColumnParallelLinear(
                    hidden_size,
                    2 * hidden_size,
                    gather_output=True,
                    bias=False,
                    module=self,
                    name=f"lm_head.{i}",
                )
            )  # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias.

        self.head_layernorm = nn.ModuleList()
        for i in range(untie_head_num):
            self.head_layernorm.append(
                LayerNorm(
                    hidden_size,
                    eps=layernorm_epsilon
                )
            )
        self.activation_func = GEGLU()

    def final_forward(self, logits, **kwargs):
        logits = self.lm_head[1](logits)
        logits = self.activation_func(logits)
        logits = self.head_layernorm[1](logits)
        return F.linear(logits, self.transformer.word_embeddings.weight).transpose(0, 1).contiguous()


class NonePositionEmbedding(BaseMixin):
    def __init__(self):
        super().__init__()

    def position_embedding_forward(self, position_ids, output_cross_layer, **kw_args):
        return None


class WordEmbedding(BaseMixin):
    def __init__(self):
        super().__init__()

    def word_embedding_forward(self, input_ids, output_cross_layer, **kw_args):
        return self.transformer.word_embeddings(input_ids).transpose(0, 1)


class ProteinGLMForGeneration(BaseModel):
    def __init__(self, args, transformer=None, **kwargs):
        super().__init__(
            args,
            transformer=transformer,
            **kwargs
        )
        self.add_mixin("glu-deepnorm", DeepNormWithGLUMixin(args.num_layers, args.hidden_size, args.inner_hidden_size))
        self.add_mixin(
            "fp32-softmax",
            SelfAttentionWithFP32SoftmaxMixin(args.fp16, args.hidden_size, args.num_attention_heads, args.model_parallel_size),
        )
        if args.untie_head:
            self.add_mixin("final-forward", UntieFinalForwardMixin(args.hidden_size, args.vocab_size, args.head_num))
        else:
            self.add_mixin("final-forward", FinalForwardMixin())
        self.add_mixin("non-position-embedding", NonePositionEmbedding())
        del self.transformer.position_embeddings
        self.add_mixin("word-embedding", WordEmbedding())
        self.add_mixin(
            "rotary-embedding",
            RotaryEmbeddingMixin(
                args.fp16,
                args.hidden_size,
                args.num_attention_heads,
                args.model_parallel_size,
                args.rotary_embedding_2d
            ),
        )
        self.get_mixin("glu-deepnorm").reinit()

    @classmethod
    def add_model_specific_args(cls, parser):
        group = parser.add_argument_group('ProteinGLMForGeneration', 'ProteinGLMForGeneration Configurations')
        group.add_argument('--untie-head', action='store_true', help='untie-heads')
        group.add_argument('--head-num', default=1, type=int, help='head>1')
        group.add_argument('--infer-type', default=1, type=int, help='1 for Generation')
        group.add_argument('--rotary-embedding-2d', action='store_true',
                           help='If set, use 2D rotary embedding for ProtenGLM.')
        return super().add_model_specific_args(parser)
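For readers of `RotaryEmbeddingMixin` above: the two rows of `position_ids` it consumes are built in `utils/chat.py` (`get_masks_and_position_ids`), with row 0 restarting inside every MSA homolog and row 1 holding the homolog's index. A minimal sketch (simplified; the real code reserves two leading special-token slots for \<gMASK\>/\<SOP\>):

```python
import numpy as np

# Simplified illustration of the 2D position ids used by the 2D rotary embedding.
query_len, max_msa_num = 5, 3
length = 1 + query_len * max_msa_num
position_ids = np.zeros(length, dtype=int)
block_position_ids = np.zeros(length, dtype=int)
for msa_idx in range(max_msa_num):
    start = 1 + msa_idx * query_len
    position_ids[start:start + query_len] = np.arange(query_len)   # position within one homolog
    block_position_ids[start:start + query_len] = msa_idx          # index of the homolog in the MSA
print(np.stack((position_ids, block_position_ids)))
# [[0 0 1 2 3 4 0 1 2 3 4 0 1 2 3 4]
#  [0 0 0 0 0 0 1 1 1 1 1 2 2 2 2 2]]
```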
msa_input
ADDED
@@ -0,0 +1,4 @@
PPGPPGPPGKPGANGLSGERGPPGPPGPPG
SYEDQNSLLKMICQQVEAIKKEMQELKLNS<M>-AEDHKTILQMICQQVEALKNEMQEMKLNS<M>-AEDQKSLLQMICQQVEALKNEMHEMKLNS
MGSSHHHHHHSSGLVPRGSHMGAATPAERDAILLDLVRGQVAAVLGHASGEDIEPGRAFKNLGFDSLTAVELRDRLGAATGHKLPATIVFDYPNPTALAQHLRAAVL
MGSSHHHHHHSSGLVPRGSHMGAATPAERDAILLDLVRGQVAAVLGHASGEDIEPGRAFKNLGFDSLTAVELRDRLGAATGHKLPATIVFDYPNPTALAQHLRAAVL<M>-------------ITPSVESLRDLPRSERREALETLVVTEFKTALLMTEQDDLPLDESYFDLGLTSLTVNDLKQRLESLLSREIDGTLLFNSPTVQRLLDHLEEDV-
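Each line above is a query sequence, optionally followed by few-shot MSA prompts joined with "\<M\>" (a "-" in a prompt is an alignment gap). A minimal parsing sketch (plain string handling, not code from this repository):

```python
# Split one msa_input line into the query and its MSA prompts.
line = "SYEDQNSLLKMICQQVEAIKKEMQELKLNS<M>-AEDHKTILQMICQQVEALKNEMQEMKLNS<M>-AEDQKSLLQMICQQVEALKNEMHEMKLNS"
query, *msa_prompts = line.strip().split("<M>")
print(query)         # main (query) sequence
print(msa_prompts)   # few-shot MSA prompts
```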
requirements.txt
ADDED
@@ -0,0 +1,3 @@
numpy==1.24.1
SwissArmyTransformer==0.4.11
torch==2.1.0.dev20230822+cu118
resources/app_case.png
ADDED
resources/demo.gif
ADDED
resources/overall_frame.png
ADDED
scripts/cli_sat.sh
ADDED
@@ -0,0 +1,62 @@
#!/bin/bash

script_path=$(realpath $0)
script_dir=$(dirname $script_path)
main_dir=$(dirname $script_dir)

MP_SIZE=1
# MODEL_NAME="MSAGPT-"
# MODEL_NAME="MSAGPT-dpo"


SEED=12345
MAX_GEN_LENGTH=128
MIN_GEN_LENGTH=0

# BeamSearchStrategy args
NUM_BEAMS=4
LENGTH_PENALTY=1.0
NO_REPEAT_NGRAM=0

# BaseStrategy args
TEMP=0.8
TOPK=0
TOPP=0.9


PORT=19865

MODEL_ARGS="--bf16 \
    --skip-init \
    --mode finetune \
    --rotary-embedding-2d"

# --mode inference \ TODO: sat ds_config bug?

GENERATION_ARGS="--seed $SEED \
    --sampling-strategy BaseStrategy \
    --max-gen-length $MAX_GEN_LENGTH \
    --min-gen-length $MIN_GEN_LENGTH \
    --num-beams $NUM_BEAMS \
    --length-penalty $LENGTH_PENALTY \
    --no-repeat-ngram-size $NO_REPEAT_NGRAM \
    --multiline_stream \
    --temperature $TEMP \
    --top_k $TOPK \
    --top_p $TOPP
"
# --sampling-strategy BeamSearchStrategy \
# --no-gap


OPTIONS_NCCL="NCCL_DEBUG=VERSION NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2 CUDA_LAUNCH_BLOCKING=0"

ARGS="${main_dir}/cli_sat.py \
    $MODEL_ARGS \
    $GENERATION_ARGS \
    $*"

run_cmd="${OPTIONS_NCCL} torchrun --nproc_per_node $MP_SIZE --master_port=$PORT ${ARGS}"
echo ${run_cmd}
eval ${run_cmd}
set +x
utils/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .strategies import AdvancedBaseStrategy, BeamSearchStrategy
from .tokenization import proteinglm_tokenizer
from .chat import chat_api
from .utils import move_cursor_up
utils/chat.py
ADDED
@@ -0,0 +1,371 @@
import os
import torch
import stat
import re
import time
import argparse
import numpy as np

from functools import partial
from typing import List, Tuple

import torch.distributed as dist
from sat.helpers import print_rank0
from sat import mpu, get_args, get_tokenizer
from sat.generation.utils import timed_name, generate_continually
from sat.generation.autoregressive_sampling import update_mems, get_masks_and_position_ids_default

from .utils import move_cursor_up, move_cursor_down


def get_masks_and_position_ids(seq, msa_len, max_gen_length, gmask=False):
    context_length = seq.shape[1]
    query_len = msa_len
    max_msa_num = (max_gen_length - 2) // query_len
    max_gen_length = max_msa_num * query_len + 2
    tokens = torch.nn.functional.pad(seq, (0, max_gen_length - context_length), mode="constant", value=-1)
    attention_mask = torch.ones((1, tokens.shape[-1], tokens.shape[-1]), device=tokens.device)
    attention_mask.tril_()
    attention_mask.unsqueeze_(1)
    attention_mask = (attention_mask < 0.5).bool()
    # <gMASK> + <SOP>
    position_ids = np.zeros(max_gen_length, dtype=int)
    block_position_ids = np.zeros(max_gen_length, dtype=int)
    pre = 0
    for msa_idx in range(max_msa_num):
        position_ids[(1 + pre): (1 + pre + query_len)] = np.arange(query_len, dtype=int)
        block_position_ids[(1 + pre): (1 + pre + query_len)] = msa_idx
        pre += query_len
    position_ids = np.stack((position_ids, block_position_ids), axis=0)
    position_ids = torch.from_numpy(position_ids).to(tokens.device)
    position_ids = position_ids.unsqueeze(0)
    return tokens, attention_mask, position_ids


def generation_sequence(
    model,
    seqs,
    strategy,
    max_memory_length=100000,
    get_masks_and_position_ids=get_masks_and_position_ids,
    stream=False,
    mems=None,
    **kw_args
):
    '''
    seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
    mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
        cache, should be first mems.shape[1] parts of context_tokens.
        mems are the first-level citizens here, but we don't assume what is memorized.
        input mems are used for multi-phase generation.
    '''
    assert len(seqs.shape) == 2
    # build the initial tokens, attention_mask, and position_ids
    batch_size, context_length = seqs.shape
    seqs, attention_mask, position_ids = get_masks_and_position_ids(seqs)
    tokens = seqs[..., :context_length]
    # initialize generation
    counter = context_length  # Last fixed index is ``counter''
    index = 0 if mems is None else mems.shape[2]  # Next forward starting index, also the length of cache.
    num_beams = 1
    # step-by-step generation
    while counter < seqs.shape[1] - 1:
        # Now, we want to generate seq[counter + 1],
        # token[:, index: counter+1] needs forwarding.
        # forward
        tokens = tokens.reshape(batch_size * num_beams, -1)
        mems = mems.reshape(mems.shape[0], batch_size * num_beams, mems.shape[-2], mems.shape[-1]) if mems is not None else None
        model.eval()
        with torch.no_grad():
            logits, *output_per_layers = model(
                tokens[:, index:],
                position_ids[..., index: counter],
                attention_mask[..., index: counter, :counter],  # TODO memlen
                mems=mems,
                **kw_args
            )
        mem_kv = [o['mem_kv'] for o in output_per_layers]
        mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
        logits = logits[:, -1]
        index = counter
        counter += 1
        logits = logits.reshape(batch_size, num_beams, -1)
        tokens = tokens.reshape(batch_size, num_beams, -1)
        mems = mems.reshape(mems.shape[0], batch_size, num_beams, mems.shape[-2], mems.shape[-1])
        tokens, mems = strategy.forward(logits, tokens, mems)
        if len(tokens.shape) == 3 and num_beams == 1:
            num_beams = tokens.shape[1]
            position_ids = position_ids.unsqueeze(1).expand(batch_size, num_beams, 2, -1).reshape(batch_size * num_beams, 2, -1)
            attention_mask_shape = attention_mask.shape[-3:]
            attention_mask = attention_mask.unsqueeze(1).expand(batch_size, num_beams, -1, -1, -1).reshape(
                batch_size * num_beams, *attention_mask_shape)
        if strategy.is_done:
            break
    return strategy.finalize(tokens, mems)


def stream_generation_sequence(
    model,
    seqs,
    strategy,
    max_memory_length=100000,
    get_masks_and_position_ids=get_masks_and_position_ids,
    stream=False,
    mems=None,
    **kw_args
):
    '''
    seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
    mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
        cache, should be first mems.shape[1] parts of context_tokens.
        mems are the first-level citizens here, but we don't assume what is memorized.
        input mems are used for multi-phase generation.
    '''
    assert len(seqs.shape) == 2
    # build the initial tokens, attention_mask, and position_ids
    batch_size, context_length = seqs.shape
    seqs, attention_mask, position_ids = get_masks_and_position_ids(seqs)
    tokens = seqs[..., :context_length]
    # initialize generation
    counter = context_length  # Last fixed index is ``counter''
    index = 0 if mems is None else mems.shape[2]  # Next forward starting index, also the length of cache.
    num_beams = 1
    # step-by-step generation
    while counter < seqs.shape[1] - 1:
        # Now, we want to generate seq[counter + 1],
        # token[:, index: counter+1] needs forwarding.
        # forward
        tokens = tokens.reshape(batch_size * num_beams, -1)
        mems = mems.reshape(mems.shape[0], batch_size * num_beams, mems.shape[-2], mems.shape[-1]) if mems is not None else None
        model.eval()
        with torch.no_grad():
            logits, *output_per_layers = model(
                tokens[:, index:],
                position_ids[..., index: counter],
                attention_mask[..., index: counter, :counter],  # TODO memlen
                mems=mems,
                **kw_args
            )
        mem_kv = [o['mem_kv'] for o in output_per_layers]
        mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
        logits = logits[:, -1]
        index = counter
        counter += 1
        logits = logits.reshape(batch_size, num_beams, -1)
        tokens = tokens.reshape(batch_size, num_beams, -1)
        mems = mems.reshape(mems.shape[0], batch_size, num_beams, mems.shape[-2], mems.shape[-1])
        tokens, mems = strategy.forward(logits, tokens, mems, is_first=False)
        if len(tokens.shape) == 3 and num_beams == 1:
            num_beams = tokens.shape[1]
            position_ids = position_ids.unsqueeze(1).expand(batch_size, num_beams, 2, -1).reshape(batch_size * num_beams, 2, -1)
            attention_mask_shape = attention_mask.shape[-3:]
            attention_mask = attention_mask.unsqueeze(1).expand(batch_size, num_beams, -1, -1, -1).reshape(
                batch_size * num_beams, *attention_mask_shape)
        yield tokens, mems
        if strategy.is_done:
            break


def autoregressive_sampling(args, raw_text: str, model, tokenizer, strategy, stream=False) -> Tuple[List[str], List[str], List[List[str]]]:
    # add MASK
    generation_mask = "[gMASK]"
    seq = []
    msa_len = len(raw_text[0]) + 1
    seq += [tokenizer.get_command(generation_mask)] + [tokenizer.get_command("sop")]
    for each in raw_text:
        seq += tokenizer.tokenize(each) + [tokenizer.get_command('<M>')]

    output_list = [seq]
    num_output = args.num_beams if args.sampling_strategy == "BeamSearchStrategy" else 1
    seq = output_list[0]
    # detect mask position
    mask_token = tokenizer.get_command(generation_mask)
    mask_position = seq.index(mask_token)

    last_pos, answers, blanks, output_list = (
        [0] * num_output,
        ["" for _ in range(num_output)],
        [[] for _ in range(num_output)],
        []
    )
    icl_msas = len(raw_text)
    input_seq = torch.tensor(
        [seq],
        dtype=torch.long,
        device=args.device,
    )
    if args.stream_chat:
        if args.chinese:
            print(f"{'生成的MSA'.center(20, '*')}", flush=True)
        else:
            print(f"{'Virtual MSA'.center(20, '*')}", flush=True)
        output_stream = stream_generation_sequence(
            model=model,
            seqs=input_seq,
            strategy=strategy,
            get_masks_and_position_ids=partial(
                get_masks_and_position_ids,
                msa_len=msa_len,
                max_gen_length=args.max_gen_length,
                gmask=True
            )
        )
        offset = -1
        for tmp_res, mems in output_stream:
            if isinstance(tmp_res, torch.Tensor):
                output = tmp_res.tolist()
                output_list = output[0]
            for i in range(len(output_list)):
                output = output_list[i].tolist() if isinstance(output_list[i], torch.Tensor) else output_list[i]
                bog = output.index(tokenizer.get_command("sop"))
                try:
                    unfinished = output.index(-1)
                except ValueError:
                    unfinished = len(output)
                output_list[i] = output[:mask_position] + output[bog + 1 : unfinished]
            for i, output in enumerate(output_list):
                if output[-1] == tokenizer.get_command("eos"):
                    output = output[:-1]
                answers[i] = tokenizer.detokenize(output)
            tmp_ret = answers[0]  # only the first sequence is streamed.
            if mpu.get_model_parallel_rank() == 0:
                if not args.multiline_stream:
                    vit_msa = tmp_ret[offset if offset > 0 else -1:]
                    print(vit_msa, end='', flush=True)
                    offset = len(tmp_ret)
                else:
                    print_len = 0
                    vit_msa = tmp_ret.split('[<M>]')[icl_msas:]
                    vit_msa = [_ for _ in vit_msa if len(_) > 0]
                    for _ in vit_msa:
                        print(_)
                        print_len += 1
                    move_cursor_up(print_len)

        move_cursor_down(print_len)
        print('\n')
        output = strategy.finalize(tmp_res, mems)[0]
    else:
        output, _ = generation_sequence(
            model=model,
            seqs=input_seq,
            strategy=strategy,
            get_masks_and_position_ids=partial(
                get_masks_and_position_ids,
                msa_len=msa_len,
                max_gen_length=args.max_gen_length,
                gmask=True
            )
        )
    last_pos, answers, blanks, output_list = (
        [0] * num_output,
        ["" for _ in range(num_output)],
        [[] for _ in range(num_output)],
        []
    )
    if isinstance(output, torch.Tensor):  # different strategies
        output = output.tolist()
    output = output[0]  # batch_size = 1
    output_list.extend(output)
    # clip -1s and fill back generated things into seq
    for i in range(len(output_list)):
        output = output_list[i].tolist() if isinstance(output_list[i], torch.Tensor) else output_list[i]
        try:
            unfinished = output.index(-1)
        except ValueError:
            unfinished = len(output)
        # if output[unfinished - 1] in strategy.end_tokens:
        #     unfinished -= 1
        bog = output.index(tokenizer.get_command("sop"))

        prefix = tokenizer.detokenize(output[last_pos[i] : mask_position])
        blank = tokenizer.detokenize(output[bog + 1 : unfinished])
        blanks[i].append(blank)
        last_pos[i] = mask_position + unfinished - (bog + 1)
        output_list[i] = output[:mask_position] + output[bog + 1 : unfinished]

    for i, output in enumerate(output_list):
        if output[-1] == tokenizer.get_command("eos"):
            output = output[:-1]
        answers[i] = tokenizer.detokenize(output)
    return answers


def offline_generation(args, temp, top_p, top_k, func):
    os.makedirs(args.output_path, exist_ok=True)
    with open(args.input_source, 'r', encoding="utf-8") as fin:
        inputs = fin.readlines()
    output_path = os.path.join(args.output_path, f"tmp_{temp}_p_{top_p}_k_{top_k}")
    fout = open(output_path, 'w')
    start_time = time.time()
    for line_no, raw_text in enumerate(inputs):
        if line_no % mpu.get_data_parallel_world_size() != mpu.get_data_parallel_rank():
            continue
        rk = dist.get_rank()
        raw_text = raw_text.strip()
        raw_text = raw_text.split('<M>')
        main_seq = raw_text[0]

        msa_len = len(main_seq) + 1
        icl_msas = len(raw_text)
        require_min_gen_length = msa_len * (icl_msas + 1) + 2
        if args.max_gen_length < require_min_gen_length:
            args.max_gen_length = require_min_gen_length  # at least generate 1 msa.

        if mpu.get_model_parallel_rank() == 0:
            print(f'Processing No. {line_no} on model group {rk} input main seq: "{main_seq}" few-shot prompt: "{"<M>".join(raw_text[1:])}"')
        if len(raw_text) == 0:
            continue
        ret = func(raw_text)
        if mpu.get_model_parallel_rank() == 0:
            if args.print_all_beams:
                for idx, vit_msa in enumerate(ret):
                    vit_msa = vit_msa.split('[<M>]')[icl_msas:]
                    vit_msa = [_ for _ in vit_msa if len(_) > 0]
                    vit_msa_len = len(vit_msa)
                    vit_msa_str = '<M>'.join(vit_msa)
                    print('Beam: {} #Virtual Length:{} | MSA: "{}" | (Temp, P, K)=({}, {}, {}) | Taken time {:.2f}'.format(idx, vit_msa_len, vit_msa_str, temp, top_p, top_k, time.time() - start_time), flush=True)
            else:
                vit_msa = ret[0]
                vit_msa = vit_msa.split('[<M>]')[icl_msas:]
                vit_msa = [_ for _ in vit_msa if len(_) > 0]
                vit_msa_len = len(vit_msa)
                vit_msa_str = '<M>'.join(vit_msa)
                fout.write(f"{vit_msa_str}" + '\n')
                print('#Virtual Length:{} | MSA: "{}" | (Temp, P, K)=({}, {}, {}) | Taken time {:.2f}'.format(vit_msa_len, vit_msa_str, temp, top_p, top_k, time.time() - start_time), flush=True)
            print()
            fout.flush()
        dist.barrier()
    fout.close()


def online_generation(args, query, temp, top_p, top_k, func):
    raw_text = query.strip()
    raw_text = raw_text.split('<M>')
    main_seq = raw_text[0]
    msa_len = len(main_seq) + 1
    icl_msas = len(raw_text)
    require_min_gen_length = msa_len * (icl_msas + 1) + 2
    if args.max_gen_length < require_min_gen_length:
        args.max_gen_length = require_min_gen_length  # at least generate 1 msa.
    ret = func(raw_text)
    response = []
    if mpu.get_model_parallel_rank() == 0:
        for idx, vit_msa in enumerate(ret):
            vit_msa = vit_msa.split('[<M>]')[icl_msas:]
            vit_msa = [_ for _ in vit_msa if len(_) > 0]
            response.append(vit_msa)
    return response


def chat_api(args, model, tokenizer, strategy, query=None):  # TODO: stream chat
    if args.input_source == 'chat':
        assert query is not None
        ret = online_generation(args, query, temp=args.temperature, top_p=args.top_p, top_k=args.top_k, func=partial(autoregressive_sampling, args, model=model, tokenizer=tokenizer, strategy=strategy))
        return ret
    else:
        assert not args.stream_chat, "Offline generation does not support streaming output."
        offline_generation(args, temp=args.temperature, top_p=args.top_p, top_k=args.top_k, func=partial(autoregressive_sampling, args, model=model, tokenizer=tokenizer, strategy=strategy))
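
The `chat_api` entry point above dispatches to `online_generation` for interactive queries and to `offline_generation` when `args.input_source` names a file. A minimal usage sketch, illustrative only: it assumes `args`, `model`, `tokenizer`, and `strategy` have already been built elsewhere (as `cli_sat.py` does), and the query uses `<M>` to separate the main sequence from optional few-shot MSA prompts, following the `msa_input` format.

# Illustrative sketch, not part of this commit: assumes args/model/tokenizer/strategy exist.
from utils import chat_api

args.input_source = 'chat'              # online mode; any other value is treated as an input file path
query = "MKVLLLA" + "<M>" + "MKILLLA"   # toy main sequence plus one prompt MSA
beams = chat_api(args, model, tokenizer, strategy, query=query)
for vit_msa in beams:                   # one list of generated MSA lines per returned beam
    print("<M>".join(vit_msa))
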
utils/strategies.py
ADDED
@@ -0,0 +1,229 @@
import numpy as np
import torch
import torch.nn.functional as F
from sat.generation.sampling_strategies.base_strategy import top_k_logits
from sat.mpu.initialize import get_model_parallel_world_size, get_model_parallel_src_rank, get_model_parallel_group


class AdvancedBaseStrategy:
    def __init__(self, batch_size, invalid_slices=[], temperature=1., no_repeat_ngram_size=0, top_k=200, eps=1e-4, top_p=0.0, min_gen_length=1, end_tokens=None):
        self.batch_size = batch_size
        self.invalid_slices = invalid_slices
        self.temperature = temperature
        self.topk = top_k
        self.top_p = top_p
        self.eps = eps
        self.min_gen_length = min_gen_length
        self.ngram = no_repeat_ngram_size
        if end_tokens is None:
            end_tokens = []
        self.end_tokens = end_tokens
        self.length_generated = 0
        self.cached_beam_ngram_bans = [{} for _ in range(self.batch_size)]
        self._is_done = np.zeros(self.batch_size, dtype=np.bool_)
        self._init_cache()

    @property
    def is_done(self) -> bool:
        return self._is_done.all()

    def _init_cache(self):
        self.length_generated = 0
        self.cached_beam_ngram_bans = [[{}] for _ in range(self.batch_size)]
        self._is_done = np.zeros(self.batch_size, dtype=bool)

    def forward(self, logits, tokens, mems, is_first=False, temperature=None):
        batch_size, num_beam, seq_len = tokens.shape
        seq_len = tokens.shape[-1]
        if temperature is None:
            temperature = self.temperature
        logits = logits / temperature
        if self.min_gen_length > self.length_generated:
            for end_token in self.end_tokens:
                logits[..., end_token] = -65504
        for invalid_slice in self.invalid_slices:
            logits[..., invalid_slice] = -65504
        if self.ngram > 0 and seq_len > self.ngram:
            for batch_idx in range(batch_size):
                for i in range(num_beam):
                    ngram_prefix = tokens[batch_idx, i, -(self.ngram - 1):].tolist()  # TODO ngram=1
                    for banned_index in self.cached_beam_ngram_bans[batch_idx][i].get(tuple(ngram_prefix), []):
                        logits[batch_idx, i, banned_index] = -65504
        logits = logits.view(-1, logits.size(-1))
        logits = top_k_logits(logits, self.topk, self.top_p)
        probs = F.softmax(logits.float(), dim=-1)  # float is essential, due to a bug in PyTorch

        pred = torch.multinomial(probs, num_samples=1)
        for i in range(self.batch_size):
            if i >= batch_size:
                self._is_done[i] = True
            elif self._is_done[i]:
                pred[i] = -1
            elif pred[i].item() in self.end_tokens:
                self._is_done[i] = True

        if self.ngram > 0:
            for batch_idx in range(batch_size):
                bans_continue = []
                for i in range(num_beam):
                    bans = self.cached_beam_ngram_bans[batch_idx][i].copy()
                    ngram_prefix = tuple(tokens[batch_idx, i, -(self.ngram - 1):].tolist())
                    bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (pred[batch_idx],)
                    bans_continue.append(bans)
                self.cached_beam_ngram_bans[batch_idx] = bans_continue
        tokens = torch.cat((tokens, pred.view(tokens.shape[:-1] + (1,))), dim=-1)
        self.length_generated += 1

        return tokens, mems

    def finalize(self, tokens, mems):
        self._is_done = np.zeros(self.batch_size, dtype=np.bool_)
        self._init_cache()
        return tokens, mems


class BeamSearchStrategy:
    def __init__(
        self,
        batch_size,
        num_beams,
        length_penalty=1.0,
        consider_end=False,
        end_tokens=[],
        invalid_slices=[],
        no_repeat_ngram_size=0,
        min_gen_length=0,
        deterministic=False,
    ):
        self.batch_size = batch_size
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.end_tokens = end_tokens
        self.ngram = no_repeat_ngram_size
        self.min_gen_length = min_gen_length
        self.invalid_slices = invalid_slices
        self.consider_end = consider_end
        self.deterministic = deterministic
        self._init_cache()

    def _init_cache(self):
        self.end_beams = [[] for _ in range(self.batch_size)]  # list of LongTensors
        self.end_beams_penalized_scores = [[] for _ in range(self.batch_size)]  # list of LongTensors
        self.cached_beam_scores = 0  # [batch_size]
        self.cached_beam_ngram_bans = [[{} for _ in range(self.num_beams)] for _ in range(self.batch_size)]
        self.length_generated = 0
        self._is_done = np.zeros(self.batch_size, dtype=np.bool_)

    def _add_end_beams(self, score, beam, batch_idx):
        score = score / ((5.0 + len(beam)) / 6) ** self.length_penalty  # Magic number for OpenNMT
        for i in range(len(self.end_beams[batch_idx]), -1, -1):
            if i == 0 or score < self.end_beams_penalized_scores[batch_idx][i - 1]:
                break
        self.end_beams[batch_idx].insert(i, beam)
        self.end_beams_penalized_scores[batch_idx].insert(i, score)

        self.end_beams[batch_idx] = self.end_beams[batch_idx][: self.num_beams]
        self.end_beams_penalized_scores[batch_idx] = self.end_beams_penalized_scores[batch_idx][: self.num_beams]

    @property
    def is_done(self) -> bool:
        return self._is_done.all()

    def forward(self, logits, tokens, mems):
        batch_size, num_beams, vocab_size = logits.shape
        seq_len = tokens.shape[-1]
        logits = logits.float()
        for invalid_slice in self.invalid_slices:
            logits[..., invalid_slice] = -65504
        if self.min_gen_length > self.length_generated:
            for end_token in self.end_tokens:
                logits[..., end_token] = -65504
        if self.ngram > 0 and seq_len > self.ngram:
            for batch_idx in range(batch_size):
                for i in range(num_beams):
                    ngram_prefix = tokens[batch_idx, i, -(self.ngram - 1):].tolist()  # TODO ngram=1
                    for banned_index in self.cached_beam_ngram_bans[batch_idx][i].get(tuple(ngram_prefix), []):
                        logits[batch_idx, i, banned_index] = -65504

        next_token_scores = F.log_softmax(logits, dim=-1)  # [batch_size, vocab_size]
        prev_scores = self.cached_beam_scores
        if isinstance(prev_scores, torch.Tensor):
            prev_scores = prev_scores[..., None].expand_as(next_token_scores)
        next_token_scores = next_token_scores + prev_scores

        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

        probs = F.softmax(next_token_scores, dim=-1)
        if num_beams < self.num_beams:  # First token
            probs = probs[..., :vocab_size]
        if self.deterministic:
            next_tokens = torch.topk(probs, k=(max(1, len(self.end_tokens)) + 1) * self.num_beams).indices  # [2*nb]
        else:
            next_tokens = torch.multinomial(
                probs, num_samples=(max(1, len(self.end_tokens)) + 1) * self.num_beams
            )  # [2*nb]
        next_token_scores = next_token_scores[torch.arange(batch_size).unsqueeze(1), next_tokens]
        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
        next_tokens = next_tokens[torch.arange(batch_size).unsqueeze(1), _indices]

        next_indices = torch.div(next_tokens, vocab_size, rounding_mode="trunc")
        next_tokens = next_tokens % vocab_size

        # select out end beams or continue beams
        beam_continue_batch, score_continue_batch, mems_continue_batch = [], [], []
        for batch_idx in range(batch_size):
            beam_continue = []
            scores_continue = []
            bans_continue = []
            mems_continue = []
            for i in range(len(next_tokens[batch_idx])):
                beam = torch.cat((tokens[batch_idx, next_indices[batch_idx, i]], next_tokens[batch_idx, i : i + 1]))
                if not self._is_done[batch_idx] and int(next_tokens[batch_idx, i]) in self.end_tokens:
                    self._add_end_beams(next_token_scores[batch_idx, i], beam, batch_idx)
                elif len(beam_continue) < self.num_beams:
                    beam_continue.append(beam)
                    mems_continue.append(mems[:, batch_idx, next_indices[batch_idx, i]])
                    # update caches
                    scores_continue.append(next_token_scores[batch_idx, i])
                    if self.ngram > 0:
                        bans = self.cached_beam_ngram_bans[batch_idx][next_indices[batch_idx, i]].copy()
                        # TODO ngram=1
                        ngram_prefix = tuple(tokens[batch_idx, next_indices[batch_idx, i], -(self.ngram - 1):].tolist())
                        bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[batch_idx, i],)
                        bans_continue.append(bans)
                else:
                    break
            beam_continue_batch.append(torch.stack(beam_continue))
            mems_continue_batch.append(torch.stack(mems_continue, dim=1))
            score_continue_batch.append(scores_continue)
            self.cached_beam_ngram_bans[batch_idx] = bans_continue
        tokens = torch.stack(beam_continue_batch)
        mems = torch.stack(mems_continue_batch, dim=1)
        self.cached_beam_scores = torch.tensor(score_continue_batch, device=logits.device)
        self.length_generated += 1
        for batch_idx in range(self.batch_size):
            if batch_idx >= batch_size:
                self._is_done[batch_idx] = True
            elif (
                len(self.end_beams[batch_idx]) == self.num_beams
                and self.end_beams_penalized_scores[batch_idx][-1]
                >= self.cached_beam_scores[batch_idx].max() / ((5.0 + (seq_len + 1)) / 6) ** self.length_penalty
            ):  # we are done if none of the current beams can beat the worst finished beam
                self._is_done[batch_idx] = True

        return tokens, mems

    def finalize(self, tokens, mems):
        if self.consider_end:
            batch_size, num_beams = tokens.shape[:2]
            for batch_idx in range(batch_size):
                if not self._is_done[batch_idx]:
                    for i in range(num_beams):
                        self._add_end_beams(self.cached_beam_scores[batch_idx, i], tokens[batch_idx, i], batch_idx)
            mems = None
            ret = self.end_beams[:batch_size]
        else:
            ret = tokens
        self._init_cache()
        return ret, mems
utils/tokenization.py
ADDED
@@ -0,0 +1,213 @@
from typing import Sequence, Tuple, List, Union
import itertools


class ResidueLevelTokenizer:
    """
    Tokenizer for protein residue-level tokenization.
    """

    def __init__(self, **kwargs):
        super(ResidueLevelTokenizer, self).__init__()
        self.pad_tok = ['[pad]']
        self.all_toks = self.pad_tok
        self._tokens = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', '.', '-']
        self.all_toks.extend(self._tokens)
        self._special_tokens = ['MASK', 'gMASK', 'sMASK', 'eod', 'sop', 'eop', '</s>', '<M>']
        self.set_special_tokens(self._special_tokens)
        self.special_tokens['eos'] = self.special_tokens['</s>']
        self.special_tokens['tMASK'] = self.special_tokens['MASK']

        self.all_toks.extend(self._special_tokens)
        self._vocab = {t: i for i, t in enumerate(self.all_toks)}
        self.command_token = {'[tMASK]': 'tMASK', '[MASK]': 'MASK', '[gMASK]': 'gMASK', '[sMASK]': 'sMASK'}
        # print('Building vocab.: {}'.format(self._vocab))
        # print('Special_tokens: {}'.format(self.special_tokens))
        # print('All tokens: {}'.format(self.all_toks))

    def pad_id(self):
        return self._vocab['[pad]']

    def set_special_tokens(self, special_tokens):
        """Add a list of additional tokens to the encoder.
        The additional tokens are indexed starting from the last index of the
        current vocabulary in the order of the `special_tokens` list.
        """
        if not special_tokens:
            self.special_tokens = {}
            self.special_tokens_decoder = {}
            return
        self.special_tokens = dict((tok, len(self.all_toks) + i) for i, tok in enumerate(special_tokens))
        self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}

    def __len__(self):
        return len(self._vocab)

    def EncodeAsIds(self, text, process_fn=None):
        """Convert a sequence to ids."""
        processed_text = text
        if process_fn is not None:
            processed_text = process_fn(processed_text)
        processed_text = str(processed_text)
        tokens = [self.TokenToId(c) for c in processed_text]
        return tokens

    def IdToToken(self, idx):
        if idx == 0:
            return '[pad]'
        elif idx in self.special_tokens_decoder:
            return f"[{self.special_tokens_decoder[idx]}]"
        else:
            try:
                tok = self.all_toks[idx]
            except:
                tok = '*'
            return tok

    def TokenToId(self, token):
        if token == '[pad]':
            return 0
        elif token in self.special_tokens:
            return self.special_tokens[token]
        else:
            return self._vocab[token]

    def DecodeIds(self, Ids):
        return ''.join([self.IdToToken(tok) for tok in Ids])

    def _tokenize(self, text) -> List[str]:
        return text.split()

    def tokenize(self, text, **kwargs) -> List[str]:
        """
        Inspired by https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_utils.py
        Converts a string into a sequence of tokens, using the tokenizer.

        Args:
            text (:obj:`str`):
                The sequence to be encoded.

        Returns:
            :obj:`List[str]`: The list of tokens.
        """

        def split_on_token(tok, text):
            result = []
            split_text = text.split(tok)
            for i, sub_text in enumerate(split_text):
                # AddedToken can control whitespace stripping around them.
                # We use them for GPT2 and Roberta to have different behavior depending on the special token
                # Cf. https://github.com/huggingface/transformers/pull/2778
                # and https://github.com/huggingface/transformers/issues/3788
                # We strip left and right by default
                if i < len(split_text) - 1:
                    sub_text = sub_text.rstrip()
                if i > 0:
                    sub_text = sub_text.lstrip()

                if i == 0 and not sub_text:
                    result.append(tok)
                elif i == len(split_text) - 1:
                    if sub_text:
                        result.append(sub_text)
                    else:
                        pass
                else:
                    if sub_text:
                        result.append(sub_text)
                    result.append(tok)
            return result

        def split_on_tokens(tok_list, text):
            if not text.strip():
                return []

            tokenized_text = []
            text_list = [text]
            for tok in tok_list:
                tokenized_text = []
                for sub_text in text_list:
                    if sub_text not in self._tokens:
                        tokenized_text.extend(split_on_token(tok, sub_text))
                    else:
                        tokenized_text.append(sub_text)
                text_list = tokenized_text

            return list(
                itertools.chain.from_iterable(
                    (
                        self._tokenize(token)
                        if token not in self.all_toks
                        else [token]
                        for token in tokenized_text
                    )
                )
            )

        no_split_token = self.all_toks
        tokenized_text = split_on_tokens(no_split_token, text)
        return self.convert_tokens_to_ids(tokenized_text)

    def convert_tokens_to_ids(self, tokens):
        """Converts a sequence of tokens into ids using the vocab."""
        ids = []
        for token in tokens:
            ids.append(self.TokenToId(token))
        return ids


class proteinglm_tokenizer:
    """
    Protein tokenizer based on the residue-level tokenizer.
    """

    def __init__(self):
        name = 'ProteinTokenizer'
        self.tokenizer = ResidueLevelTokenizer()
        self.special_tokens = self.tokenizer.special_tokens

    def IdToToken(self, idx):
        return self.tokenizer.IdToToken(idx)

    def TokenToId(self, token):
        return self.tokenizer.TokenToId(token)

    @property
    def vocab_size(self):
        return len(self.tokenizer)

    def decode(self, token_ids):
        return self.tokenizer.DecodeIds([token_ids])

    @property
    def eod(self):
        # the end-of-document id is the '</s>' (eos) command token
        return self.tokenizer.special_tokens['eos']

    def detokenize(self, Ids, type_token=False):
        new_tokens = self.tokenizer.DecodeIds(Ids)
        return new_tokens

    def tokenize(self, text):
        ids = self.tokenizer.tokenize(text)
        return ids

    @property
    def vocab(self):
        return self.tokenizer._vocab

    @property
    def inv_vocab(self):
        return {v: k for k, v in self.tokenizer._vocab.items()}

    @property
    def get_pad_id(self):
        return self.tokenizer.pad_id

    def get_command(self, token):
        tok = token
        if token in self.tokenizer.command_token:
            tok = self.tokenizer.command_token[token]
        return self.tokenizer.special_tokens[tok]
utils/utils.py
ADDED
@@ -0,0 +1,7 @@
def move_cursor_up(n):
    # ANSI escape code to move the cursor up by n lines
    print(f"\033[{n}A", end='')


def move_cursor_down(n):
    # ANSI escape code to move the cursor down by n lines
    print(f"\033[{n}B", end='')
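
These helpers support the multiline streaming display in `utils/chat.py`: after the partially generated MSA lines are printed, the cursor is moved back up so the next update overwrites them in place. A standalone sketch of that pattern with toy data, not model output:

# Toy illustration of the in-place redraw used by the streaming chat loop.
import time
from utils.utils import move_cursor_up, move_cursor_down

frames = [["MTYK----", "M-YKLILN"], ["MTYKLILN", "MQYKLILN"]]   # fake partial MSAs
for lines in frames:
    for line in lines:
        print(line)
    time.sleep(0.5)
    move_cursor_up(len(lines))        # jump back so the next frame overwrites these lines
move_cursor_down(len(frames[-1]))     # leave the cursor below the final frame
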