Commit
•
cda1e7d
0
Parent(s):
add scripts
Browse files- .gitattributes +257 -0
- convert_nllb_moe_sharded_original_checkpoint_to_pytorch.py +165 -0
- script.py +128 -0
.gitattributes
ADDED
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_moe_54b/checkpoint_2_300000-rank-67.pt filter=lfs diff=lfs merge=lfs -text
|
2 |
+
model_moe_54b/checkpoint_2_300000-rank-77.pt filter=lfs diff=lfs merge=lfs -text
|
3 |
+
model_moe_54b/checkpoint_2_300000-rank-78.pt filter=lfs diff=lfs merge=lfs -text
|
4 |
+
model_moe_54b/checkpoint_2_300000-rank-90.pt filter=lfs diff=lfs merge=lfs -text
|
5 |
+
model_moe_54b/checkpoint_2_300000-rank-115.pt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
model_moe_54b/checkpoint_2_300000-rank-170.pt filter=lfs diff=lfs merge=lfs -text
|
7 |
+
model_moe_54b/checkpoint_2_300000-rank-201.pt filter=lfs diff=lfs merge=lfs -text
|
8 |
+
model_moe_54b/checkpoint_2_300000-rank-16.pt filter=lfs diff=lfs merge=lfs -text
|
9 |
+
model_moe_54b/checkpoint_2_300000-rank-213.pt filter=lfs diff=lfs merge=lfs -text
|
10 |
+
model_moe_54b/checkpoint_2_300000-rank-212.pt filter=lfs diff=lfs merge=lfs -text
|
11 |
+
model_moe_54b/checkpoint_2_300000-rank-44.pt filter=lfs diff=lfs merge=lfs -text
|
12 |
+
model_moe_54b/checkpoint_2_300000-rank-48.pt filter=lfs diff=lfs merge=lfs -text
|
13 |
+
model_moe_54b/checkpoint_2_300000-rank-63.pt filter=lfs diff=lfs merge=lfs -text
|
14 |
+
model_moe_54b/checkpoint_2_300000-rank-81.pt filter=lfs diff=lfs merge=lfs -text
|
15 |
+
model_moe_54b/checkpoint_2_300000-rank-134.pt filter=lfs diff=lfs merge=lfs -text
|
16 |
+
model_moe_54b/checkpoint_2_300000-rank-148.pt filter=lfs diff=lfs merge=lfs -text
|
17 |
+
model_moe_54b/checkpoint_2_300000-rank-151.pt filter=lfs diff=lfs merge=lfs -text
|
18 |
+
model_moe_54b/checkpoint_2_300000-rank-217.pt filter=lfs diff=lfs merge=lfs -text
|
19 |
+
model_moe_54b/checkpoint_2_300000-rank-248.pt filter=lfs diff=lfs merge=lfs -text
|
20 |
+
model_moe_54b/checkpoint_2_300000-rank-5.pt filter=lfs diff=lfs merge=lfs -text
|
21 |
+
model_moe_54b/checkpoint_2_300000-rank-110.pt filter=lfs diff=lfs merge=lfs -text
|
22 |
+
model_moe_54b/checkpoint_2_300000-rank-112.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
model_moe_54b/checkpoint_2_300000-rank-171.pt filter=lfs diff=lfs merge=lfs -text
|
24 |
+
model_moe_54b/checkpoint_2_300000-rank-178.pt filter=lfs diff=lfs merge=lfs -text
|
25 |
+
model_moe_54b/checkpoint_2_300000-rank-185.pt filter=lfs diff=lfs merge=lfs -text
|
26 |
+
model_moe_54b/checkpoint_2_300000-shared.pt filter=lfs diff=lfs merge=lfs -text
|
27 |
+
model_moe_54b/checkpoint_2_300000-rank-93.pt filter=lfs diff=lfs merge=lfs -text
|
28 |
+
model_moe_54b/checkpoint_2_300000-rank-94.pt filter=lfs diff=lfs merge=lfs -text
|
29 |
+
model_moe_54b/checkpoint_2_300000-rank-204.pt filter=lfs diff=lfs merge=lfs -text
|
30 |
+
model_moe_54b/checkpoint_2_300000-rank-215.pt filter=lfs diff=lfs merge=lfs -text
|
31 |
+
model_moe_54b/checkpoint_2_300000-rank-252.pt filter=lfs diff=lfs merge=lfs -text
|
32 |
+
model_moe_54b/checkpoint_2_300000-rank-35.pt filter=lfs diff=lfs merge=lfs -text
|
33 |
+
model_moe_54b/checkpoint_2_300000-rank-95.pt filter=lfs diff=lfs merge=lfs -text
|
34 |
+
model_moe_54b/checkpoint_2_300000-rank-114.pt filter=lfs diff=lfs merge=lfs -text
|
35 |
+
model_moe_54b/checkpoint_2_300000-rank-199.pt filter=lfs diff=lfs merge=lfs -text
|
36 |
+
model_moe_54b/checkpoint_2_300000-rank-224.pt filter=lfs diff=lfs merge=lfs -text
|
37 |
+
model_moe_54b/checkpoint_2_300000-rank-117.pt filter=lfs diff=lfs merge=lfs -text
|
38 |
+
model_moe_54b/checkpoint_2_300000-rank-19.pt filter=lfs diff=lfs merge=lfs -text
|
39 |
+
model_moe_54b/checkpoint_2_300000-rank-238.pt filter=lfs diff=lfs merge=lfs -text
|
40 |
+
model_moe_54b/checkpoint_2_300000-rank-218.pt filter=lfs diff=lfs merge=lfs -text
|
41 |
+
model_moe_54b/checkpoint_2_300000-rank-84.pt filter=lfs diff=lfs merge=lfs -text
|
42 |
+
model_moe_54b/checkpoint_2_300000-rank-160.pt filter=lfs diff=lfs merge=lfs -text
|
43 |
+
model_moe_54b/checkpoint_2_300000-rank-188.pt filter=lfs diff=lfs merge=lfs -text
|
44 |
+
model_moe_54b/checkpoint_2_300000-rank-194.pt filter=lfs diff=lfs merge=lfs -text
|
45 |
+
model_moe_54b/checkpoint_2_300000-rank-133.pt filter=lfs diff=lfs merge=lfs -text
|
46 |
+
model_moe_54b/checkpoint_2_300000-rank-38.pt filter=lfs diff=lfs merge=lfs -text
|
47 |
+
model_moe_54b/checkpoint_2_300000-rank-83.pt filter=lfs diff=lfs merge=lfs -text
|
48 |
+
model_moe_54b/checkpoint_2_300000-rank-62.pt filter=lfs diff=lfs merge=lfs -text
|
49 |
+
model_moe_54b/checkpoint_2_300000-rank-66.pt filter=lfs diff=lfs merge=lfs -text
|
50 |
+
model_moe_54b/checkpoint_2_300000-rank-197.pt filter=lfs diff=lfs merge=lfs -text
|
51 |
+
model_moe_54b/checkpoint_2_300000-rank-27.pt filter=lfs diff=lfs merge=lfs -text
|
52 |
+
model_moe_54b/checkpoint_2_300000-rank-47.pt filter=lfs diff=lfs merge=lfs -text
|
53 |
+
model_moe_54b/checkpoint_2_300000-rank-192.pt filter=lfs diff=lfs merge=lfs -text
|
54 |
+
model_moe_54b/checkpoint_2_300000-rank-6.pt filter=lfs diff=lfs merge=lfs -text
|
55 |
+
model_moe_54b/checkpoint_2_300000-rank-89.pt filter=lfs diff=lfs merge=lfs -text
|
56 |
+
model_moe_54b/checkpoint_2_300000-rank-101.pt filter=lfs diff=lfs merge=lfs -text
|
57 |
+
model_moe_54b/checkpoint_2_300000-rank-156.pt filter=lfs diff=lfs merge=lfs -text
|
58 |
+
model_moe_54b/checkpoint_2_300000-rank-176.pt filter=lfs diff=lfs merge=lfs -text
|
59 |
+
model_moe_54b/checkpoint_2_300000-rank-49.pt filter=lfs diff=lfs merge=lfs -text
|
60 |
+
model_moe_54b/checkpoint_2_300000-rank-85.pt filter=lfs diff=lfs merge=lfs -text
|
61 |
+
model_moe_54b/checkpoint_2_300000-rank-234.pt filter=lfs diff=lfs merge=lfs -text
|
62 |
+
model_moe_54b/checkpoint_2_300000-rank-43.pt filter=lfs diff=lfs merge=lfs -text
|
63 |
+
model_moe_54b/checkpoint_2_300000-rank-64.pt filter=lfs diff=lfs merge=lfs -text
|
64 |
+
model_moe_54b/checkpoint_2_300000-rank-243.pt filter=lfs diff=lfs merge=lfs -text
|
65 |
+
model_moe_54b/checkpoint_2_300000-rank-42.pt filter=lfs diff=lfs merge=lfs -text
|
66 |
+
model_moe_54b/checkpoint_2_300000-rank-69.pt filter=lfs diff=lfs merge=lfs -text
|
67 |
+
model_moe_54b/checkpoint_2_300000-rank-18.pt filter=lfs diff=lfs merge=lfs -text
|
68 |
+
model_moe_54b/checkpoint_2_300000-rank-190.pt filter=lfs diff=lfs merge=lfs -text
|
69 |
+
model_moe_54b/checkpoint_2_300000-rank-240.pt filter=lfs diff=lfs merge=lfs -text
|
70 |
+
model_moe_54b/checkpoint_2_300000-rank-80.pt filter=lfs diff=lfs merge=lfs -text
|
71 |
+
model_moe_54b/checkpoint_2_300000-rank-137.pt filter=lfs diff=lfs merge=lfs -text
|
72 |
+
model_moe_54b/checkpoint_2_300000-rank-182.pt filter=lfs diff=lfs merge=lfs -text
|
73 |
+
model_moe_54b/checkpoint_2_300000-rank-222.pt filter=lfs diff=lfs merge=lfs -text
|
74 |
+
model_moe_54b/checkpoint_2_300000-rank-205.pt filter=lfs diff=lfs merge=lfs -text
|
75 |
+
model_moe_54b/checkpoint_2_300000-rank-236.pt filter=lfs diff=lfs merge=lfs -text
|
76 |
+
model_moe_54b/checkpoint_2_300000-rank-86.pt filter=lfs diff=lfs merge=lfs -text
|
77 |
+
model_moe_54b/checkpoint_2_300000-rank-99.pt filter=lfs diff=lfs merge=lfs -text
|
78 |
+
model_moe_54b/checkpoint_2_300000-rank-108.pt filter=lfs diff=lfs merge=lfs -text
|
79 |
+
model_moe_54b/checkpoint_2_300000-rank-15.pt filter=lfs diff=lfs merge=lfs -text
|
80 |
+
model_moe_54b/checkpoint_2_300000-rank-173.pt filter=lfs diff=lfs merge=lfs -text
|
81 |
+
model_moe_54b/checkpoint_2_300000-rank-106.pt filter=lfs diff=lfs merge=lfs -text
|
82 |
+
model_moe_54b/checkpoint_2_300000-rank-147.pt filter=lfs diff=lfs merge=lfs -text
|
83 |
+
model_moe_54b/checkpoint_2_300000-rank-244.pt filter=lfs diff=lfs merge=lfs -text
|
84 |
+
model_moe_54b/checkpoint_2_300000-rank-144.pt filter=lfs diff=lfs merge=lfs -text
|
85 |
+
model_moe_54b/checkpoint_2_300000-rank-149.pt filter=lfs diff=lfs merge=lfs -text
|
86 |
+
model_moe_54b/checkpoint_2_300000-rank-223.pt filter=lfs diff=lfs merge=lfs -text
|
87 |
+
model_moe_54b/checkpoint_2_300000-rank-59.pt filter=lfs diff=lfs merge=lfs -text
|
88 |
+
model_moe_54b/checkpoint_2_300000-rank-103.pt filter=lfs diff=lfs merge=lfs -text
|
89 |
+
model_moe_54b/checkpoint_2_300000-rank-184.pt filter=lfs diff=lfs merge=lfs -text
|
90 |
+
model_moe_54b/checkpoint_2_300000-rank-24.pt filter=lfs diff=lfs merge=lfs -text
|
91 |
+
model_moe_54b/checkpoint_2_300000-rank-232.pt filter=lfs diff=lfs merge=lfs -text
|
92 |
+
model_moe_54b/checkpoint_2_300000-rank-30.pt filter=lfs diff=lfs merge=lfs -text
|
93 |
+
model_moe_54b/checkpoint_2_300000-rank-7.pt filter=lfs diff=lfs merge=lfs -text
|
94 |
+
model_moe_54b/checkpoint_2_300000-rank-181.pt filter=lfs diff=lfs merge=lfs -text
|
95 |
+
model_moe_54b/checkpoint_2_300000-rank-196.pt filter=lfs diff=lfs merge=lfs -text
|
96 |
+
model_moe_54b/checkpoint_2_300000-rank-206.pt filter=lfs diff=lfs merge=lfs -text
|
97 |
+
model_moe_54b/checkpoint_2_300000-rank-241.pt filter=lfs diff=lfs merge=lfs -text
|
98 |
+
model_moe_54b/checkpoint_2_300000-rank-230.pt filter=lfs diff=lfs merge=lfs -text
|
99 |
+
model_moe_54b/checkpoint_2_300000-rank-246.pt filter=lfs diff=lfs merge=lfs -text
|
100 |
+
model_moe_54b/checkpoint_2_300000-rank-50.pt filter=lfs diff=lfs merge=lfs -text
|
101 |
+
model_moe_54b/checkpoint_2_300000-rank-105.pt filter=lfs diff=lfs merge=lfs -text
|
102 |
+
model_moe_54b/checkpoint_2_300000-rank-132.pt filter=lfs diff=lfs merge=lfs -text
|
103 |
+
model_moe_54b/checkpoint_2_300000-rank-14.pt filter=lfs diff=lfs merge=lfs -text
|
104 |
+
model_moe_54b/checkpoint_2_300000-rank-229.pt filter=lfs diff=lfs merge=lfs -text
|
105 |
+
model_moe_54b/checkpoint_2_300000-rank-235.pt filter=lfs diff=lfs merge=lfs -text
|
106 |
+
model_moe_54b/checkpoint_2_300000-rank-91.pt filter=lfs diff=lfs merge=lfs -text
|
107 |
+
model_moe_54b/checkpoint_2_300000-rank-65.pt filter=lfs diff=lfs merge=lfs -text
|
108 |
+
model_moe_54b/checkpoint_2_300000-rank-168.pt filter=lfs diff=lfs merge=lfs -text
|
109 |
+
model_moe_54b/checkpoint_2_300000-rank-225.pt filter=lfs diff=lfs merge=lfs -text
|
110 |
+
model_moe_54b/checkpoint_2_300000-rank-254.pt filter=lfs diff=lfs merge=lfs -text
|
111 |
+
model_moe_54b/checkpoint_2_300000-rank-220.pt filter=lfs diff=lfs merge=lfs -text
|
112 |
+
model_moe_54b/checkpoint_2_300000-rank-249.pt filter=lfs diff=lfs merge=lfs -text
|
113 |
+
model_moe_54b/checkpoint_2_300000-rank-157.pt filter=lfs diff=lfs merge=lfs -text
|
114 |
+
model_moe_54b/checkpoint_2_300000-rank-172.pt filter=lfs diff=lfs merge=lfs -text
|
115 |
+
model_moe_54b/checkpoint_2_300000-rank-186.pt filter=lfs diff=lfs merge=lfs -text
|
116 |
+
model_moe_54b/checkpoint_2_300000-rank-228.pt filter=lfs diff=lfs merge=lfs -text
|
117 |
+
model_moe_54b/checkpoint_2_300000-rank-4.pt filter=lfs diff=lfs merge=lfs -text
|
118 |
+
model_moe_54b/checkpoint_2_300000-rank-54.pt filter=lfs diff=lfs merge=lfs -text
|
119 |
+
model_moe_54b/checkpoint_2_300000-rank-61.pt filter=lfs diff=lfs merge=lfs -text
|
120 |
+
model_moe_54b/checkpoint_2_300000-rank-159.pt filter=lfs diff=lfs merge=lfs -text
|
121 |
+
model_moe_54b/checkpoint_2_300000-rank-180.pt filter=lfs diff=lfs merge=lfs -text
|
122 |
+
model_moe_54b/checkpoint_2_300000-rank-219.pt filter=lfs diff=lfs merge=lfs -text
|
123 |
+
model_moe_54b/checkpoint_2_300000-rank-153.pt filter=lfs diff=lfs merge=lfs -text
|
124 |
+
model_moe_54b/checkpoint_2_300000-rank-51.pt filter=lfs diff=lfs merge=lfs -text
|
125 |
+
model_moe_54b/checkpoint_2_300000-rank-92.pt filter=lfs diff=lfs merge=lfs -text
|
126 |
+
model_moe_54b/checkpoint_2_300000-rank-45.pt filter=lfs diff=lfs merge=lfs -text
|
127 |
+
model_moe_54b/checkpoint_2_300000-rank-52.pt filter=lfs diff=lfs merge=lfs -text
|
128 |
+
model_moe_54b/checkpoint_2_300000-rank-109.pt filter=lfs diff=lfs merge=lfs -text
|
129 |
+
model_moe_54b/checkpoint_2_300000-rank-121.pt filter=lfs diff=lfs merge=lfs -text
|
130 |
+
model_moe_54b/checkpoint_2_300000-rank-40.pt filter=lfs diff=lfs merge=lfs -text
|
131 |
+
model_moe_54b/checkpoint_2_300000-rank-107.pt filter=lfs diff=lfs merge=lfs -text
|
132 |
+
model_moe_54b/checkpoint_2_300000-rank-143.pt filter=lfs diff=lfs merge=lfs -text
|
133 |
+
model_moe_54b/checkpoint_2_300000-rank-26.pt filter=lfs diff=lfs merge=lfs -text
|
134 |
+
model_moe_54b/checkpoint_2_300000-rank-1.pt filter=lfs diff=lfs merge=lfs -text
|
135 |
+
model_moe_54b/checkpoint_2_300000-rank-102.pt filter=lfs diff=lfs merge=lfs -text
|
136 |
+
model_moe_54b/checkpoint_2_300000-rank-221.pt filter=lfs diff=lfs merge=lfs -text
|
137 |
+
model_moe_54b/checkpoint_2_300000-rank-28.pt filter=lfs diff=lfs merge=lfs -text
|
138 |
+
model_moe_54b/checkpoint_2_300000-rank-79.pt filter=lfs diff=lfs merge=lfs -text
|
139 |
+
model_moe_54b/checkpoint_2_300000-rank-100.pt filter=lfs diff=lfs merge=lfs -text
|
140 |
+
model_moe_54b/checkpoint_2_300000-rank-150.pt filter=lfs diff=lfs merge=lfs -text
|
141 |
+
model_moe_54b/checkpoint_2_300000-rank-72.pt filter=lfs diff=lfs merge=lfs -text
|
142 |
+
model_moe_54b/checkpoint_2_300000-rank-0.pt filter=lfs diff=lfs merge=lfs -text
|
143 |
+
model_moe_54b/checkpoint_2_300000-rank-131.pt filter=lfs diff=lfs merge=lfs -text
|
144 |
+
model_moe_54b/checkpoint_2_300000-rank-55.pt filter=lfs diff=lfs merge=lfs -text
|
145 |
+
model_moe_54b/checkpoint_2_300000-rank-141.pt filter=lfs diff=lfs merge=lfs -text
|
146 |
+
model_moe_54b/checkpoint_2_300000-rank-177.pt filter=lfs diff=lfs merge=lfs -text
|
147 |
+
model_moe_54b/checkpoint_2_300000-rank-29.pt filter=lfs diff=lfs merge=lfs -text
|
148 |
+
model_moe_54b/checkpoint_2_300000-rank-145.pt filter=lfs diff=lfs merge=lfs -text
|
149 |
+
model_moe_54b/checkpoint_2_300000-rank-58.pt filter=lfs diff=lfs merge=lfs -text
|
150 |
+
model_moe_54b/checkpoint_2_300000-rank-82.pt filter=lfs diff=lfs merge=lfs -text
|
151 |
+
model_moe_54b/checkpoint_2_300000-rank-87.pt filter=lfs diff=lfs merge=lfs -text
|
152 |
+
model_moe_54b/checkpoint_2_300000-rank-120.pt filter=lfs diff=lfs merge=lfs -text
|
153 |
+
model_moe_54b/checkpoint_2_300000-rank-130.pt filter=lfs diff=lfs merge=lfs -text
|
154 |
+
model_moe_54b/checkpoint_2_300000-rank-136.pt filter=lfs diff=lfs merge=lfs -text
|
155 |
+
model_moe_54b/checkpoint_2_300000-rank-226.pt filter=lfs diff=lfs merge=lfs -text
|
156 |
+
model_moe_54b/checkpoint_2_300000-rank-211.pt filter=lfs diff=lfs merge=lfs -text
|
157 |
+
model_moe_54b/checkpoint_2_300000-rank-111.pt filter=lfs diff=lfs merge=lfs -text
|
158 |
+
model_moe_54b/checkpoint_2_300000-rank-118.pt filter=lfs diff=lfs merge=lfs -text
|
159 |
+
model_moe_54b/checkpoint_2_300000-rank-20.pt filter=lfs diff=lfs merge=lfs -text
|
160 |
+
model_moe_54b/checkpoint_2_300000-rank-165.pt filter=lfs diff=lfs merge=lfs -text
|
161 |
+
model_moe_54b/checkpoint_2_300000-rank-17.pt filter=lfs diff=lfs merge=lfs -text
|
162 |
+
model_moe_54b/checkpoint_2_300000-rank-96.pt filter=lfs diff=lfs merge=lfs -text
|
163 |
+
model_moe_54b/checkpoint_2_300000-rank-231.pt filter=lfs diff=lfs merge=lfs -text
|
164 |
+
model_moe_54b/checkpoint_2_300000-rank-189.pt filter=lfs diff=lfs merge=lfs -text
|
165 |
+
model_moe_54b/checkpoint_2_300000-rank-191.pt filter=lfs diff=lfs merge=lfs -text
|
166 |
+
model_moe_54b/checkpoint_2_300000-rank-209.pt filter=lfs diff=lfs merge=lfs -text
|
167 |
+
model_moe_54b/checkpoint_2_300000-rank-164.pt filter=lfs diff=lfs merge=lfs -text
|
168 |
+
model_moe_54b/checkpoint_2_300000-rank-233.pt filter=lfs diff=lfs merge=lfs -text
|
169 |
+
model_moe_54b/checkpoint_2_300000-rank-75.pt filter=lfs diff=lfs merge=lfs -text
|
170 |
+
model_moe_54b/checkpoint_2_300000-rank-140.pt filter=lfs diff=lfs merge=lfs -text
|
171 |
+
model_moe_54b/checkpoint_2_300000-rank-154.pt filter=lfs diff=lfs merge=lfs -text
|
172 |
+
model_moe_54b/checkpoint_2_300000-rank-163.pt filter=lfs diff=lfs merge=lfs -text
|
173 |
+
model_moe_54b/checkpoint_2_300000-rank-146.pt filter=lfs diff=lfs merge=lfs -text
|
174 |
+
model_moe_54b/checkpoint_2_300000-rank-46.pt filter=lfs diff=lfs merge=lfs -text
|
175 |
+
model_moe_54b/checkpoint_2_300000-rank-210.pt filter=lfs diff=lfs merge=lfs -text
|
176 |
+
model_moe_54b/checkpoint_2_300000-rank-245.pt filter=lfs diff=lfs merge=lfs -text
|
177 |
+
model_moe_54b/checkpoint_2_300000-rank-247.pt filter=lfs diff=lfs merge=lfs -text
|
178 |
+
model_moe_54b/checkpoint_2_300000-rank-3.pt filter=lfs diff=lfs merge=lfs -text
|
179 |
+
model_moe_54b/checkpoint_2_300000-rank-32.pt filter=lfs diff=lfs merge=lfs -text
|
180 |
+
model_moe_54b/checkpoint_2_300000-rank-10.pt filter=lfs diff=lfs merge=lfs -text
|
181 |
+
model_moe_54b/checkpoint_2_300000-rank-12.pt filter=lfs diff=lfs merge=lfs -text
|
182 |
+
model_moe_54b/checkpoint_2_300000-rank-135.pt filter=lfs diff=lfs merge=lfs -text
|
183 |
+
model_moe_54b/checkpoint_2_300000-rank-9.pt filter=lfs diff=lfs merge=lfs -text
|
184 |
+
model_moe_54b/checkpoint_2_300000-rank-242.pt filter=lfs diff=lfs merge=lfs -text
|
185 |
+
model_moe_54b/checkpoint_2_300000-rank-37.pt filter=lfs diff=lfs merge=lfs -text
|
186 |
+
model_moe_54b/checkpoint_2_300000-rank-8.pt filter=lfs diff=lfs merge=lfs -text
|
187 |
+
model_moe_54b/checkpoint_2_300000-rank-25.pt filter=lfs diff=lfs merge=lfs -text
|
188 |
+
model_moe_54b/checkpoint_2_300000-rank-68.pt filter=lfs diff=lfs merge=lfs -text
|
189 |
+
model_moe_54b/checkpoint_2_300000-rank-88.pt filter=lfs diff=lfs merge=lfs -text
|
190 |
+
model_moe_54b/checkpoint_2_300000-rank-98.pt filter=lfs diff=lfs merge=lfs -text
|
191 |
+
model_moe_54b/checkpoint_2_300000-rank-179.pt filter=lfs diff=lfs merge=lfs -text
|
192 |
+
model_moe_54b/checkpoint_2_300000-rank-200.pt filter=lfs diff=lfs merge=lfs -text
|
193 |
+
model_moe_54b/checkpoint_2_300000-rank-23.pt filter=lfs diff=lfs merge=lfs -text
|
194 |
+
model_moe_54b/checkpoint_2_300000-rank-74.pt filter=lfs diff=lfs merge=lfs -text
|
195 |
+
model_moe_54b/checkpoint_2_300000-rank-152.pt filter=lfs diff=lfs merge=lfs -text
|
196 |
+
model_moe_54b/checkpoint_2_300000-rank-174.pt filter=lfs diff=lfs merge=lfs -text
|
197 |
+
model_moe_54b/checkpoint_2_300000-rank-56.pt filter=lfs diff=lfs merge=lfs -text
|
198 |
+
model_moe_54b/checkpoint_2_300000-rank-41.pt filter=lfs diff=lfs merge=lfs -text
|
199 |
+
model_moe_54b/checkpoint_2_300000-rank-128.pt filter=lfs diff=lfs merge=lfs -text
|
200 |
+
model_moe_54b/checkpoint_2_300000-rank-2.pt filter=lfs diff=lfs merge=lfs -text
|
201 |
+
model_moe_54b/checkpoint_2_300000-rank-39.pt filter=lfs diff=lfs merge=lfs -text
|
202 |
+
model_moe_54b/checkpoint_2_300000-rank-22.pt filter=lfs diff=lfs merge=lfs -text
|
203 |
+
model_moe_54b/checkpoint_2_300000-rank-122.pt filter=lfs diff=lfs merge=lfs -text
|
204 |
+
model_moe_54b/checkpoint_2_300000-rank-167.pt filter=lfs diff=lfs merge=lfs -text
|
205 |
+
model_moe_54b/checkpoint_2_300000-rank-208.pt filter=lfs diff=lfs merge=lfs -text
|
206 |
+
model_moe_54b/checkpoint_2_300000-rank-60.pt filter=lfs diff=lfs merge=lfs -text
|
207 |
+
model_moe_54b/checkpoint_2_300000-rank-126.pt filter=lfs diff=lfs merge=lfs -text
|
208 |
+
model_moe_54b/checkpoint_2_300000-rank-161.pt filter=lfs diff=lfs merge=lfs -text
|
209 |
+
model_moe_54b/checkpoint_2_300000-rank-195.pt filter=lfs diff=lfs merge=lfs -text
|
210 |
+
model_moe_54b/checkpoint_2_300000-rank-255.pt filter=lfs diff=lfs merge=lfs -text
|
211 |
+
model_moe_54b/checkpoint_2_300000-rank-71.pt filter=lfs diff=lfs merge=lfs -text
|
212 |
+
model_moe_54b/checkpoint_2_300000-rank-11.pt filter=lfs diff=lfs merge=lfs -text
|
213 |
+
model_moe_54b/checkpoint_2_300000-rank-125.pt filter=lfs diff=lfs merge=lfs -text
|
214 |
+
model_moe_54b/checkpoint_2_300000-rank-139.pt filter=lfs diff=lfs merge=lfs -text
|
215 |
+
model_moe_54b/checkpoint_2_300000-rank-207.pt filter=lfs diff=lfs merge=lfs -text
|
216 |
+
model_moe_54b/checkpoint_2_300000-rank-216.pt filter=lfs diff=lfs merge=lfs -text
|
217 |
+
model_moe_54b/checkpoint_2_300000-rank-237.pt filter=lfs diff=lfs merge=lfs -text
|
218 |
+
model_moe_54b/checkpoint_2_300000-rank-239.pt filter=lfs diff=lfs merge=lfs -text
|
219 |
+
model_moe_54b/checkpoint_2_300000-rank-251.pt filter=lfs diff=lfs merge=lfs -text
|
220 |
+
model_moe_54b/checkpoint_2_300000-rank-138.pt filter=lfs diff=lfs merge=lfs -text
|
221 |
+
model_moe_54b/checkpoint_2_300000-rank-193.pt filter=lfs diff=lfs merge=lfs -text
|
222 |
+
model_moe_54b/checkpoint_2_300000-rank-203.pt filter=lfs diff=lfs merge=lfs -text
|
223 |
+
model_moe_54b/checkpoint_2_300000-rank-124.pt filter=lfs diff=lfs merge=lfs -text
|
224 |
+
model_moe_54b/checkpoint_2_300000-rank-33.pt filter=lfs diff=lfs merge=lfs -text
|
225 |
+
model_moe_54b/checkpoint_2_300000-rank-73.pt filter=lfs diff=lfs merge=lfs -text
|
226 |
+
model_moe_54b/checkpoint_2_300000-rank-34.pt filter=lfs diff=lfs merge=lfs -text
|
227 |
+
model_moe_54b/checkpoint_2_300000-rank-113.pt filter=lfs diff=lfs merge=lfs -text
|
228 |
+
model_moe_54b/checkpoint_2_300000-rank-123.pt filter=lfs diff=lfs merge=lfs -text
|
229 |
+
model_moe_54b/checkpoint_2_300000-rank-57.pt filter=lfs diff=lfs merge=lfs -text
|
230 |
+
model_moe_54b/checkpoint_2_300000-rank-250.pt filter=lfs diff=lfs merge=lfs -text
|
231 |
+
model_moe_54b/checkpoint_2_300000-rank-253.pt filter=lfs diff=lfs merge=lfs -text
|
232 |
+
model_moe_54b/checkpoint_2_300000-rank-119.pt filter=lfs diff=lfs merge=lfs -text
|
233 |
+
model_moe_54b/checkpoint_2_300000-rank-187.pt filter=lfs diff=lfs merge=lfs -text
|
234 |
+
model_moe_54b/checkpoint_2_300000-rank-198.pt filter=lfs diff=lfs merge=lfs -text
|
235 |
+
model_moe_54b/checkpoint_2_300000-rank-142.pt filter=lfs diff=lfs merge=lfs -text
|
236 |
+
model_moe_54b/checkpoint_2_300000-rank-227.pt filter=lfs diff=lfs merge=lfs -text
|
237 |
+
model_moe_54b/checkpoint_2_300000-rank-155.pt filter=lfs diff=lfs merge=lfs -text
|
238 |
+
model_moe_54b/checkpoint_2_300000-rank-175.pt filter=lfs diff=lfs merge=lfs -text
|
239 |
+
model_moe_54b/checkpoint_2_300000-rank-36.pt filter=lfs diff=lfs merge=lfs -text
|
240 |
+
model_moe_54b/checkpoint_2_300000-rank-104.pt filter=lfs diff=lfs merge=lfs -text
|
241 |
+
model_moe_54b/checkpoint_2_300000-rank-127.pt filter=lfs diff=lfs merge=lfs -text
|
242 |
+
model_moe_54b/checkpoint_2_300000-rank-169.pt filter=lfs diff=lfs merge=lfs -text
|
243 |
+
model_moe_54b/checkpoint_2_300000-rank-158.pt filter=lfs diff=lfs merge=lfs -text
|
244 |
+
model_moe_54b/checkpoint_2_300000-rank-183.pt filter=lfs diff=lfs merge=lfs -text
|
245 |
+
model_moe_54b/checkpoint_2_300000-rank-70.pt filter=lfs diff=lfs merge=lfs -text
|
246 |
+
model_moe_54b/checkpoint_2_300000-rank-31.pt filter=lfs diff=lfs merge=lfs -text
|
247 |
+
model_moe_54b/checkpoint_2_300000-rank-97.pt filter=lfs diff=lfs merge=lfs -text
|
248 |
+
model_moe_54b/checkpoint_2_300000-rank-129.pt filter=lfs diff=lfs merge=lfs -text
|
249 |
+
model_moe_54b/checkpoint_2_300000-rank-162.pt filter=lfs diff=lfs merge=lfs -text
|
250 |
+
model_moe_54b/checkpoint_2_300000-rank-166.pt filter=lfs diff=lfs merge=lfs -text
|
251 |
+
model_moe_54b/checkpoint_2_300000-rank-21.pt filter=lfs diff=lfs merge=lfs -text
|
252 |
+
model_moe_54b/checkpoint_2_300000-rank-214.pt filter=lfs diff=lfs merge=lfs -text
|
253 |
+
model_moe_54b/checkpoint_2_300000-rank-53.pt filter=lfs diff=lfs merge=lfs -text
|
254 |
+
model_moe_54b/checkpoint_2_300000-rank-76.pt filter=lfs diff=lfs merge=lfs -text
|
255 |
+
model_moe_54b/checkpoint_2_300000-rank-116.pt filter=lfs diff=lfs merge=lfs -text
|
256 |
+
model_moe_54b/checkpoint_2_300000-rank-13.pt filter=lfs diff=lfs merge=lfs -text
|
257 |
+
model_moe_54b/checkpoint_2_300000-rank-202.pt filter=lfs diff=lfs merge=lfs -text
|
convert_nllb_moe_sharded_original_checkpoint_to_pytorch.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import argparse
|
15 |
+
import json
|
16 |
+
import os
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from torch import nn
|
20 |
+
|
21 |
+
from transformers import NllbMoeConfig, NllbMoeModel
|
22 |
+
from transformers.modeling_utils import dtype_byte_size
|
23 |
+
from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME
|
24 |
+
|
25 |
+
|
26 |
+
# 'encoder.layers.7.moe_layer.experts.0.fc2.bias', 'encoder.layers.11.moe_layer.experts.0.fc1.weight',
|
27 |
+
|
28 |
+
|
29 |
+
def remove_ignore_keys_(state_dict):
    """Strip fairseq bookkeeping entries from ``state_dict`` in place.

    These keys (version markers, positional-embedding float tensors, and the
    output projection that is re-tied from the embeddings) have no counterpart
    in the Hugging Face model, so they are dropped before conversion.
    Missing keys are ignored silently.
    """
    obsolete_keys = (
        "encoder.version",
        "decoder.version",
        "model.encoder.version",
        "model.decoder.version",
        "decoder.output_projection.weight",
        "_float_tensor",
        "encoder.embed_positions._float_tensor",
        "decoder.embed_positions._float_tensor",
    )
    for obsolete_key in obsolete_keys:
        state_dict.pop(obsolete_key, None)
|
42 |
+
|
43 |
+
|
44 |
+
def make_linear_from_emb(emb):
    """Return a bias-free ``nn.Linear`` whose weight tensor is shared with *emb*.

    The linear layer's weight is rebound directly to the embedding's weight
    (weight tying), so the shape passed to the ``nn.Linear`` constructor is
    immediately overridden by the ``.data`` assignment below.
    """
    num_embeddings, embedding_dim = emb.weight.shape
    tied_projection = nn.Linear(num_embeddings, embedding_dim, bias=False)
    tied_projection.weight.data = emb.weight.data
    return tied_projection
|
49 |
+
|
50 |
+
|
51 |
+
def rename_fairseq_keys(state_dict, expert_idx=None):
    """Rename fairseq state-dict keys to the Hugging Face NLLB-MoE layout.

    Mapping examples:
      'encoder.layers.7.moe_layer.experts.0.fc2.bias' -> 'encoder.layers.7.ffn.mlp.experts.{expert_idx}.fc2.bias'
      'encoder.layers.7.fc2.bias'                     -> 'encoder.layers.7.ffn.mlp.fc2.bias'
      'encoder.layers.7.moe_layer.gate.wg'            -> 'encoder.layers.7.ffn.mlp.router.classifier'

    Args:
        state_dict: fairseq checkpoint state dict (not mutated).
        expert_idx: index of the expert shard being converted; substituted for
            the hard-coded ``experts.0`` of each per-rank fairseq checkpoint.

    Returns:
        A new dict with renamed keys pointing at the original tensors.
    """
    new_dict = {}
    for old_key in state_dict.keys():
        key = old_key
        if "experts" in key:
            key = key.replace("moe_layer.experts.0", f"ffn.mlp.experts.{expert_idx}")
        # BUGFIX: the original used `elif "fc2":` / `elif "fc1":`, which are
        # always-true literals — every non-expert key fell into the fc2 branch
        # and the gate branch was unreachable. Membership tests restore the
        # intended dispatch.
        elif "fc2" in key:
            # BUGFIX: keep the trailing dot so '...fc2.bias' does not become
            # '...ffn.mlp.fc2bias'.
            key = key.replace(".fc2.", ".ffn.mlp.fc2.")
        elif "fc1" in key:
            key = key.replace(".fc1.", ".ffn.mlp.fc1.")
        elif "gate" in key:
            key = key.replace(".moe_layer.gate.wg", ".ffn.mlp.router.classifier")
        new_dict[key] = state_dict[old_key]
    return new_dict
|
68 |
+
|
69 |
+
|
70 |
+
def shard_on_the_fly(
    switch_checkpoint_path, dump_path, num_experts, dtype, weights_name: str = WEIGHTS_NAME
):
    """Convert the per-rank fairseq MoE checkpoints into HF-style weight shards.

    Loads one expert checkpoint at a time (`<path>-rank-<i>.pt`), renames its
    keys, and writes it out as shard `i+1`; the shared (non-expert) weights
    from `<path>-shared.pt` become the final shard. Then renames the temporary
    `-of-???` files to their final `-of-NNNNN` names and writes the
    `weight_map` index file expected by `from_pretrained`.

    NOTE(review): `dtype` is accepted but never applied — tensors are saved in
    whatever dtype the checkpoint holds; confirm whether casting was intended.
    NOTE(review): `current_block` and `shards` are assigned but unused.
    """
    # One entry per saved shard, holding that shard's key view (used to build
    # the index's weight_map later).
    sharded_state_dicts = []
    current_block = {}
    total_size = 0  # running byte count across all expert shards
    os.makedirs(dump_path, exist_ok=True)

    for expert in range(num_experts):
        expert_path = switch_checkpoint_path + f"-rank-{expert}.pt"
        expert_state = torch.load(expert_path)["model"]
        remove_ignore_keys_(expert_state)
        expert_state = rename_fairseq_keys(expert_state, expert)
        # Total shard count is unknown until the loop finishes, so save under a
        # `-of-???` placeholder name; files are renamed once the count is known.
        save_path = os.path.join(
            dump_path, weights_name.replace(".bin", f"-{len(sharded_state_dicts)+1:05d}-of-???.bin")
        )
        torch.save(expert_state, save_path)
        sharded_state_dicts.append(expert_state.keys())
        # Byte size = element count * bytes-per-element of the first tensor's
        # dtype (assumes every tensor in the shard shares one dtype).
        total_size += sum([value.numel() for key, value in expert_state.items()]) * dtype_byte_size(
            expert_state[list(expert_state)[0]].dtype
        )

    # Add the last block: the shared (dense, non-expert) weights.
    save_path = os.path.join(dump_path, weights_name.replace(".bin", f"-{len(sharded_state_dicts)+1:05d}-of-???.bin"))
    shared_weights = torch.load(switch_checkpoint_path + "-shared.pt")["model"]
    remove_ignore_keys_(shared_weights)
    shared_weights = rename_fairseq_keys(shared_weights)
    # Alias the decoder embedding as the model-level shared embedding.
    shared_weights["shared.weight"] = shared_weights["decoder.embed_tokens.weight"]

    torch.save(shared_weights, save_path)
    sharded_state_dicts.append(shared_weights.keys())

    # If we only have one shard, we return it
    # NOTE(review): in this branch the file on disk still carries the
    # `-00001-of-???` placeholder name, not `weights_name` — verify callers.
    if len(sharded_state_dicts) == 1:
        return {weights_name: sharded_state_dicts[0]}, None

    # Otherwise, let's build the index
    weight_map = {}  # parameter name -> final shard filename
    shards = {}
    for idx, shard in enumerate(sharded_state_dicts):
        # Final name embeds the now-known total shard count.
        shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
        temp_filename = os.path.join(dump_path, weights_name.replace(".bin", f"-{idx+1:05d}-of-???.bin"))
        os.rename(temp_filename, os.path.join(dump_path, shard_file))
        for key in shard:
            weight_map[key] = shard_file

    # Add the metadata
    metadata = {"total_size": total_size}
    index = {"metadata": metadata, "weight_map": weight_map}

    with open(os.path.join(dump_path, WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
        content = json.dumps(index, indent=2, sort_keys=True) + "\n"
        f.write(content)

    return metadata, index
|
125 |
+
|
126 |
+
|
127 |
+
if __name__ == "__main__":
|
128 |
+
parser = argparse.ArgumentParser()
|
129 |
+
# Required parameters
|
130 |
+
parser.add_argument(
|
131 |
+
"--nllb_moe_checkpoint_path",
|
132 |
+
default="/home/arthur_huggingface_co/fairseq/weights/checkpoints/model_moe_54b/checkpoint_2_300000",
|
133 |
+
type=str,
|
134 |
+
required=False,
|
135 |
+
help="Path to a directory containing a folder per layer. Follows the original Google format.",
|
136 |
+
)
|
137 |
+
parser.add_argument("--dtype", default="float32", type=str, required=False, help="dtype of the saved model")
|
138 |
+
parser.add_argument(
|
139 |
+
"--pytorch_dump_folder_path",
|
140 |
+
default="/home/arthur_huggingface_co/fairseq/weights/checkpoints/hf-converted-moe-54b",
|
141 |
+
type=str,
|
142 |
+
required=False,
|
143 |
+
help="Path to the output pytorch model.",
|
144 |
+
)
|
145 |
+
args = parser.parse_args()
|
146 |
+
# metadata, index = shard_on_the_fly(
|
147 |
+
# args.nllb_moe_checkpoint_path,
|
148 |
+
# args.pytorch_dump_folder_path,
|
149 |
+
# 128,
|
150 |
+
# args.dtype,
|
151 |
+
# )
|
152 |
+
|
153 |
+
|
154 |
+
config = NllbMoeConfig.from_pretrained(
|
155 |
+
"facebook/nllb-200-3.3B",
|
156 |
+
num_sparse_encoder_layers=4,
|
157 |
+
num_sparse_decoder_layers=4,
|
158 |
+
)
|
159 |
+
config.save_pretrained(args.pytorch_dump_folder_path)
|
160 |
+
|
161 |
+
|
162 |
+
model = NllbMoeModel(config)
|
163 |
+
model.save_pretrained(args.pytorch_dump_folder_path)
|
164 |
+
# model.push_to_hub("ArthurZ/nllb-moe-54b", use_auth_token="")
|
165 |
+
# model.save_pretrained(args.pytorch_dump_folder_path)
|
script.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/usr/bin/env python3
"""Load the raw fairseq NLLB-MoE 54B checkpoint and compare its next-token
predictions against the converted Hugging Face checkpoint."""
import os
from transformers import NllbTokenizer
#from megatron.initialize import initialize_megatron
from fairseq import checkpoint_utils, tasks
from transformers import NllbMoeForConditionalGeneration
import torch


import torch.distributed as dist
# Single-process "distributed" setup — presumably fairseq's MoE checkpoint
# loading expects an initialized process group even for world_size=1; confirm.
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'

# initialize the process group
dist.init_process_group("gloo", rank=0, world_size=1)


# Local paths: the raw fairseq checkpoint shards and the converted HF model.
path = "/home/arthur_huggingface_co/fairseq/weights/checkpoints/model_moe_54b"
hf_path = "/home/arthur/facebook/nllb-moe"

tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
tokenizer.save_pretrained(path)

# load the rank-0 shard, which will merge all the states
state = checkpoint_utils.load_checkpoint_to_cpu(
    os.path.join(path, "checkpoint_2_300000-rank-0.pt"),
    is_moe=True,
)
cfg = state["cfg"]
# cfg.model.moe_expert_count=256, the checkpoint has more experts than the configuration available with the
# `state`. This is strange
# cfg.model.ddp_backend = ""
# 1. build the task to make sure that the embedding layers will be built?
# There are 256 experts, not 128
from fairseq import models, quantization_utils

# build the model from the checkpoint's own model config
model = models.build_model(cfg.model, None, from_checkpoint=True)
# model = quantization_utils.quantize_model_scalar(model, args)

# load the merged state dict in the built model.
# NOTE(review): strict=False silently ignores missing/unexpected keys —
# confirm nothing important is dropped during the load.
model.load_state_dict(
    state["model"], strict=False, model_cfg=cfg.model
)
model = model.eval()
|
# Reload the tokenizer with explicit source/target languages for an
# English -> French sanity check of the fairseq model.
tokenizer = NllbTokenizer.from_pretrained(
    "facebook/nllb-200-distilled-600M", src_lang="eng_Latn", tgt_lang="fra_Latn"
)

src_text = "Life is like a box of chocolates."
tgt_text = "La vie est comme une boîte de chocolat."

model_inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt").input_ids
# Decoder prompt: token id 2 followed by the French language code —
# presumably EOS + target-language id; confirm against fairseq generation.
decoder_prompt = torch.tensor([[2, tokenizer.lang_code_to_id["fra_Latn"]]])
with torch.no_grad():
    logits = model(model_inputs, len(model_inputs), decoder_prompt)[0]
# Greedy next-token prediction from the last decoder position.
pred_next_token = torch.argmax(logits[0, -1], -1)
next_token = tokenizer.decode(pred_next_token)
print(f"Next word: {next_token}")
print("-------------")
# forward passes
def single_batch_forward_logits(prompts):
    """Tokenize ``prompts``, prepend token id 0, and return the fairseq model's logits.

    NOTE(review): the third positional argument reuses ``input_ids`` as the
    decoder input (the prompt is teacher-forced into the decoder) — confirm
    this is the intended way to probe the seq2seq model.
    """
    input_ids = tokenizer(prompts, return_tensors="pt").input_ids
    # Prepend token id 0 — presumably a BOS-like token for this vocabulary; verify.
    input_ids = torch.cat([torch.tensor([[0]]), input_ids], dim=-1)
    # (removed the no-op `input_ids = input_ids`, a leftover from a deleted `.cuda()` call)
    with torch.no_grad():
        logits = model(input_ids, len(input_ids), input_ids)[0]
    return logits
# Generating with fairseq:
# NOTE(review): this is copied example code — '/path/to/model/dir' and
# 'checkpoint100.pt' are placeholders, so these two lines will fail as-is.
from fairseq.models.transformer_lm import TransformerLanguageModel
custom_lm = TransformerLanguageModel.from_pretrained('/path/to/model/dir', 'checkpoint100.pt', tokenizer='moses', bpe='fastbpe')
custom_lm.sample('Barack Obama', beam=5)

# maybe use the /home/arthur_huggingface_co/fairseq/examples/nllb/modeling/evaluation/conf/generate_multi.yaml
# and the generate multi.py script
# There is also generate.py which contains all of the generation methods.
# Let's hack through this!
prompts = [
    "Today is a beautiful day and I want to",
    "In the city of",
    "Paris is the capital of France and",
    "Computers and mobile phones have taken",
]

print("Next word generation")
for prompt in prompts:
    print("-------------")
    print(f"Prompt: {prompt}...\n")
    logits_fsq = single_batch_forward_logits(prompt)
    # Greedy pick of the next token at the last position.
    pred_next_token = torch.argmax(logits_fsq[0, -1], -1)
    next_token = tokenizer.convert_ids_to_tokens([pred_next_token])
    # NLLB's vocabulary is SentencePiece-based, whose word-boundary marker is
    # "▁" (U+2581); the original stripped "Ġ" (the GPT-2 byte-BPE marker),
    # which never occurs in NLLB tokens, so the replace was a no-op.
    next_token = next_token[0].replace("▁", "")
    print(f"Next word: {next_token}")
    print("-------------")


# NOTE(review): everything below this exit() is dead code; drop the exit (or
# the code below) once the HF-vs-fairseq comparison is ready to run.
exit(0)
|
hf_model = NllbMoeForConditionalGeneration.from_pretrained(hf_path)

# forward hf
def forward_hf(prompts):
    """Tokenize ``prompts``, prepend token id 0, and return the HF model's logits."""
    input_ids = tokenizer(prompts, return_tensors="pt").input_ids
    # Prepend token id 0 — presumably a BOS-like token; mirrors the fairseq path.
    input_ids = torch.cat([torch.tensor([[0]]), input_ids], dim=-1)
    # (removed the no-op `input_ids = input_ids`, a leftover from a deleted `.cuda()` call)
    with torch.no_grad():
        logits = hf_model(input_ids)[0]
    return logits
print("Next word generation")
for prompt in prompts:
    logits = forward_hf(prompt)
    # Greedy pick of the next token at the last position.
    pred_next_token = torch.argmax(logits[0, -1], -1)
    next_token = tokenizer.convert_ids_to_tokens([pred_next_token])
    # NLLB tokens carry the SentencePiece marker "▁" (U+2581), not the GPT-2
    # byte-BPE marker "Ġ" the original stripped (a no-op for this vocabulary).
    next_token = next_token[0].replace("▁", "")
    print(f"Next word: {next_token}")
    print("-------------")

# NOTE(review): both loop variables are overwritten each iteration, so this
# only compares the logits of the *last* prompt between fairseq and HF.
print("Is equal:", torch.allclose(logits_fsq.cpu(), logits.cpu(), atol=1e-3))