shibing624
commited on
Commit
·
b410dd9
1
Parent(s):
cfbfdf4
Update README.md
Browse files
README.md
CHANGED
@@ -19,11 +19,11 @@ ChatGLM3-6B中文纠错LoRA模型
|
|
19 |
|
20 |
The overall performance of shibing624/chatglm3-6b-csc-chinese-lora on CSC **test**:
|
21 |
|
22 |
-
|prefix|input_text|
|
23 |
-
|:-- |:--- |:---
|
24 |
-
|
25 |
|
26 |
-
在CSC测试集上生成结果纠错准确率高,由于是基于
|
27 |
|
28 |
|
29 |
## Usage
|
@@ -53,21 +53,35 @@ pip install transformers
|
|
53 |
```
|
54 |
|
55 |
```python
|
56 |
-
import
|
57 |
-
from peft import PeftModel
|
58 |
-
from transformers import AutoModel, AutoTokenizer
|
59 |
|
60 |
-
|
|
|
|
|
61 |
|
62 |
-
|
63 |
-
model = PeftModel.from_pretrained(model, "shibing624/chatglm3-6b-csc-chinese-lora")
|
64 |
-
model = model.half().cuda() # fp16
|
65 |
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
67 |
-
sents = ['对下面中文拼写纠错:\n少先队员因该为老人让坐。',
|
68 |
-
'对下面中文拼写纠错:\n下个星期,我跟我朋唷打算去法国玩儿。']
|
69 |
for s in sents:
|
70 |
-
|
|
|
|
|
|
|
|
|
|
|
71 |
print(response)
|
72 |
```
|
73 |
|
|
|
19 |
|
20 |
The overall performance of shibing624/chatglm3-6b-csc-chinese-lora on CSC **test**:
|
21 |
|
22 |
+
|prefix|input_text|pred|
|
23 |
+
|:-- |:--- |:--- |
|
24 |
+
|对下面文本纠错:|少先队员因该为老人让坐。|少先队员应该为老人让座。|
|
25 |
|
26 |
+
在CSC测试集上生成结果纠错准确率高,由于是基于[THUDM/chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b)模型,结果常常能带给人惊喜,不仅能纠错,还带有句子润色和改写功能。
|
27 |
|
28 |
|
29 |
## Usage
|
|
|
53 |
```
|
54 |
|
55 |
```python
|
56 |
+
import os
|
|
|
|
|
57 |
|
58 |
+
import torch
|
59 |
+
from peft import PeftModel
|
60 |
+
from transformers import AutoTokenizer, AutoModel
|
61 |
|
62 |
+
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
|
|
|
|
63 |
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
|
64 |
+
model = AutoModel.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True).half().cuda()
|
65 |
+
model = PeftModel.from_pretrained(model, "shibing624/chatglm3-6b-csc-chinese-lora")
|
66 |
+
|
67 |
+
sents = ['对下面文本纠错\n\n少先队员因该为老人让坐。',
|
68 |
+
'对下面文本纠错\n\n下个星期,我跟我朋唷打算去法国玩儿。']
|
69 |
+
|
70 |
+
|
71 |
+
def get_prompt(user_query):
|
72 |
+
vicuna_prompt = "A chat between a curious user and an artificial intelligence assistant. " \
|
73 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions. " \
|
74 |
+
"USER: {query} ASSISTANT:"
|
75 |
+
return vicuna_prompt.format(query=user_query)
|
76 |
+
|
77 |
|
|
|
|
|
78 |
for s in sents:
|
79 |
+
q = get_prompt(s)
|
80 |
+
input_ids = tokenizer(q).input_ids
|
81 |
+
generation_kwargs = dict(max_new_tokens=128, do_sample=True, temperature=0.8)
|
82 |
+
outputs = model.generate(input_ids=torch.as_tensor([input_ids]).to('cuda'), **generation_kwargs)
|
83 |
+
output_tensor = outputs[0][len(input_ids):]
|
84 |
+
response = tokenizer.decode(output_tensor, skip_special_tokens=True)
|
85 |
print(response)
|
86 |
```
|
87 |
|