junnyu commited on
Commit
19829c3
1 Parent(s): f066d7e

update tf_model

Browse files
Files changed (2) hide show
  1. README.md +38 -13
  2. tf_model.h5 +3 -0
README.md CHANGED
@@ -2,44 +2,69 @@
2
  language: zh
3
  tags:
4
  - roformer
 
 
5
  inference: false
6
  ---
7
  ## 介绍
8
  ### tf版本
9
  https://github.com/ZhuiyiTechnology/roformer
10
 
11
- ### pytorch版本
12
  https://github.com/JunnYu/RoFormer_pytorch
13
 
14
  ## 安装
15
  ```bash
 
 
16
  pip install git+https://github.com/JunnYu/RoFormer_pytorch.git
17
  ```
18
 
19
- ## 使用
20
  ```python
21
  import torch
22
  from roformer import RoFormerForMaskedLM, RoFormerTokenizer
23
 
24
  text = "今天[MASK]很好,我[MASK]去公园玩。"
25
  tokenizer = RoFormerTokenizer.from_pretrained("junnyu/roformer_chinese_base")
26
- model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base")
27
- inputs = tokenizer(text, return_tensors="pt")
28
  with torch.no_grad():
29
- outputs = model(**inputs).logits[0]
30
- outputs_sentence = ""
31
  for i, id in enumerate(tokenizer.encode(text)):
32
  if id == tokenizer.mask_token_id:
33
- tokens = tokenizer.convert_ids_to_tokens(outputs[i].topk(k=5)[1])
34
- outputs_sentence += "[" + "||".join(tokens) + "]"
35
  else:
36
- outputs_sentence += "".join(
37
  tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True))
38
- print(outputs_sentence)
39
- # RoFormer 今天[天气||天||心情||阳光||空气]很好,我[想||要||打算||准备||喜欢]去公园玩。
40
- # PLUS WoBERT 今天[天气||阳光||天||心情||空气]很好,我[想||要||打算||准备||就]去公园玩。
41
- # WoBERT 今天[天气||阳光||天||心情||空气]很好,我[想||要||就||准备||也]去公园玩。
42
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  ## 引用
44
 
45
  Bibtex:
 
2
  language: zh
3
  tags:
4
  - roformer
5
+ - pytorch
6
+ - tf2.0
7
  inference: false
8
  ---
9
  ## 介绍
10
  ### tf版本
11
  https://github.com/ZhuiyiTechnology/roformer
12
 
13
+ ### pytorch版本+tf2.0版本
14
  https://github.com/JunnYu/RoFormer_pytorch
15
 
16
  ## 安装
17
  ```bash
18
+ pip install roformer
19
+
20
  pip install git+https://github.com/JunnYu/RoFormer_pytorch.git
21
  ```
22
 
23
+ ## pytorch使用
24
  ```python
25
  import torch
26
  from roformer import RoFormerForMaskedLM, RoFormerTokenizer
27
 
28
  text = "今天[MASK]很好,我[MASK]去公园玩。"
29
  tokenizer = RoFormerTokenizer.from_pretrained("junnyu/roformer_chinese_base")
30
+ pt_model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base")
31
+ pt_inputs = tokenizer(text, return_tensors="pt")
32
  with torch.no_grad():
33
+ pt_outputs = pt_model(**pt_inputs).logits[0]
34
+ pt_outputs_sentence = "pytorch: "
35
  for i, id in enumerate(tokenizer.encode(text)):
36
  if id == tokenizer.mask_token_id:
37
+ tokens = tokenizer.convert_ids_to_tokens(pt_outputs[i].topk(k=5)[1])
38
+ pt_outputs_sentence += "[" + "||".join(tokens) + "]"
39
  else:
40
+ pt_outputs_sentence += "".join(
41
  tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True))
42
+ print(pt_outputs_sentence)
43
+ # pytorch 今天[天气||天||心情||阳光||空气]很好,我[想||要||打算||准备||喜欢]去公园玩。
 
 
44
  ```
45
+
46
+ ## tensorflow2.0使用
47
+ ```python
48
+ import tensorflow as tf
49
+ from roformer import RoFormerTokenizer, TFRoFormerForMaskedLM
50
+ text = "今天[MASK]很好,我[MASK]去公园玩。"
51
+ tokenizer = RoFormerTokenizer.from_pretrained("junnyu/roformer_chinese_base")
52
+ tf_model = TFRoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base")
53
+ tf_inputs = tokenizer(text, return_tensors="tf")
54
+ tf_outputs = tf_model(**tf_inputs, training=False).logits[0]
55
+ tf_outputs_sentence = "tf2.0: "
56
+ for i, id in enumerate(tokenizer.encode(text)):
57
+ if id == tokenizer.mask_token_id:
58
+ tokens = tokenizer.convert_ids_to_tokens(
59
+ tf.math.top_k(tf_outputs[i], k=5)[1])
60
+ tf_outputs_sentence += "[" + "||".join(tokens) + "]"
61
+ else:
62
+ tf_outputs_sentence += "".join(
63
+ tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True))
64
+ print(tf_outputs_sentence)
65
+ # tf2.0 今天[天气||天||心情||阳光||空气]很好,我[想||要||打算||准备||喜欢]去公园玩。
66
+ ```
67
+
68
  ## 引用
69
 
70
  Bibtex:
tf_model.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e02e774b33014b21e68f3b9e9d1452c287e3a97171b2906ab42f0ab66ea96c7c
3
+ size 650288296