KoichiYasuoka
committed on
Commit
•
7c71338
1
Parent(s):
9309bc0
bug fix
Browse files
ud.py
CHANGED
@@ -1,15 +1,15 @@
|
|
1 |
from transformers import TokenClassificationPipeline
|
2 |
|
3 |
class UniversalDependenciesPipeline(TokenClassificationPipeline):
|
4 |
-
def _forward(self,
|
5 |
import torch
|
6 |
-
v=
|
7 |
with torch.no_grad():
|
8 |
e=self.model(input_ids=torch.tensor([v[0:i]+[self.tokenizer.mask_token_id]+v[i+1:]+[j] for i,j in enumerate(v[1:-1],1)]))
|
9 |
-
return {"logits":e.logits[:,1:-2,:],**
|
10 |
-
def postprocess(self,
|
11 |
import numpy
|
12 |
-
e=
|
13 |
r=[1 if i==0 else -1 if j.endswith("|root") else 0 for i,j in sorted(self.model.config.id2label.items())]
|
14 |
e+=numpy.where(numpy.add.outer(numpy.identity(e.shape[0]),r)==0,0,numpy.nan)
|
15 |
g=self.model.config.label2id["X|_|goeswith"]
|
@@ -25,7 +25,7 @@ class UniversalDependenciesPipeline(TokenClassificationPipeline):
|
|
25 |
k,h=z[numpy.nanargmax(m[z,z])],numpy.nanmin(m)-numpy.nanmax(m)
|
26 |
m[:,z]+=[[0 if j in z and (i!=j or i==k) else h for i in z] for j in range(m.shape[0])]
|
27 |
h=self.chu_liu_edmonds(m)
|
28 |
-
v=[(s,e) for s,e in
|
29 |
q=[self.model.config.id2label[p[j,i]].split("|") for i,j in enumerate(h)]
|
30 |
g="aggregation_strategy" in kwargs and kwargs["aggregation_strategy"]!="none"
|
31 |
if g:
|
@@ -34,7 +34,7 @@ class UniversalDependenciesPipeline(TokenClassificationPipeline):
|
|
34 |
h=[b if i>b else b-1 for a,b in enumerate(h) if i!=a]
|
35 |
v[i-1]=(v[i-1][0],v.pop(i)[1])
|
36 |
q.pop(i)
|
37 |
-
t=
|
38 |
u="# text = "+t+"\n"
|
39 |
for i,(s,e) in enumerate(v):
|
40 |
u+="\t".join([str(i+1),t[s:e],t[s:e] if g else "_",q[i][0],"_","|".join(q[i][1:-1]),str(0 if h[i]==i else h[i]+1),q[i][-1],"_","_" if i+1<len(v) and e<v[i+1][0] else "SpaceAfter=No"])+"\n"
|
|
|
1 |
from transformers import TokenClassificationPipeline
|
2 |
|
3 |
class UniversalDependenciesPipeline(TokenClassificationPipeline):
|
4 |
+
def _forward(self,model_inputs):
|
5 |
import torch
|
6 |
+
v=model_inputs["input_ids"][0].tolist()
|
7 |
with torch.no_grad():
|
8 |
e=self.model(input_ids=torch.tensor([v[0:i]+[self.tokenizer.mask_token_id]+v[i+1:]+[j] for i,j in enumerate(v[1:-1],1)]))
|
9 |
+
return {"logits":e.logits[:,1:-2,:],**model_inputs}
|
10 |
+
def postprocess(self,model_outputs,**kwargs):
|
11 |
import numpy
|
12 |
+
e=model_outputs["logits"].numpy()
|
13 |
r=[1 if i==0 else -1 if j.endswith("|root") else 0 for i,j in sorted(self.model.config.id2label.items())]
|
14 |
e+=numpy.where(numpy.add.outer(numpy.identity(e.shape[0]),r)==0,0,numpy.nan)
|
15 |
g=self.model.config.label2id["X|_|goeswith"]
|
|
|
25 |
k,h=z[numpy.nanargmax(m[z,z])],numpy.nanmin(m)-numpy.nanmax(m)
|
26 |
m[:,z]+=[[0 if j in z and (i!=j or i==k) else h for i in z] for j in range(m.shape[0])]
|
27 |
h=self.chu_liu_edmonds(m)
|
28 |
+
v=[(s,e) for s,e in model_outputs["offset_mapping"][0].tolist() if s<e]
|
29 |
q=[self.model.config.id2label[p[j,i]].split("|") for i,j in enumerate(h)]
|
30 |
g="aggregation_strategy" in kwargs and kwargs["aggregation_strategy"]!="none"
|
31 |
if g:
|
|
|
34 |
h=[b if i>b else b-1 for a,b in enumerate(h) if i!=a]
|
35 |
v[i-1]=(v[i-1][0],v.pop(i)[1])
|
36 |
q.pop(i)
|
37 |
+
t=model_outputs["sentence"].replace("\n"," ")
|
38 |
u="# text = "+t+"\n"
|
39 |
for i,(s,e) in enumerate(v):
|
40 |
u+="\t".join([str(i+1),t[s:e],t[s:e] if g else "_",q[i][0],"_","|".join(q[i][1:-1]),str(0 if h[i]==i else h[i]+1),q[i][-1],"_","_" if i+1<len(v) and e<v[i+1][0] else "SpaceAfter=No"])+"\n"
|