Upload inference.py
inference.py · ADDED · +40 -0
@@ -0,0 +1,40 @@
import torch
import transformers
from transformers import AutoTokenizer, GPTJForCausalLM


def get_sent(sent: str) -> str:
    # `tokenizer` and `model` are the module-level objects created in the __main__ block below.
    # Wrap the user utterance in the special tokens the fine-tuned model expects:
    # '[BOS]' + utterance + '[EOS]', with a fresh '[BOS]' marking where the reply starts.
    input_text = '[BOS]' + sent + '[EOS][BOS]'
    input_length = len(tokenizer.encode(input_text))
    max_length = 786
    input_ids = torch.tensor(tokenizer.encode(input_text)).unsqueeze(0).to('cuda')
    # Beam search (5 beams) combined with sampling; max_length counts the prompt tokens as
    # well, repeated 4-grams are blocked, and the search stops early once enough finished
    # candidates exist.
    output_sentence = model.generate(
        input_ids,
        do_sample=True,
        max_length=int(max_length),
        num_return_sequences=1,
        no_repeat_ngram_size=4,
        num_beams=5,
        early_stopping=True
    )
    # Drop the prompt tokens and decode only the newly generated part of the sequence.
    generated_sequence = output_sentence[0].tolist()[input_length:]
    decoded_sent = tokenizer.decode(generated_sequence, skip_special_tokens=False).strip()
    return decoded_sent

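# Note: get_sent() assumes '[BOS]' and '[EOS]' are registered with the tokenizer. The
# upstream KoGPT usage example passes them explicitly when loading it; the call below is
# a minimal sketch of that pattern (an illustrative assumption, not a line of this file):
#
#     tokenizer = AutoTokenizer.from_pretrained(
#         'kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b',
#         bos_token='[BOS]', eos_token='[EOS]', unk_token='[UNK]',
#         pad_token='[PAD]', mask_token='[MASK]',
#     )
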
if __name__ == "__main__":
    # The tokenizer comes from the upstream KoGPT release; the fine-tuned GPT-J weights
    # are loaded from this repository ('./') in float16 and moved to the GPU.
    tokenizer = AutoTokenizer.from_pretrained('kakaobrain/kogpt', revision="KoGPT6B-ryan1.5b")
    model = GPTJForCausalLM.from_pretrained('./', torch_dtype=torch.float16)
    model.cuda()

    # Sample user utterance (Korean): the writer describes two years of depression and
    # burnout, says they ended up quitting their job again, and asks whether they should
    # simply rest for a while.
    text = """
    번아웃이 온것을 인정하기 싫어요.
    2년 전부터 우울증에 시달렸습니다. 2년 동안 저를 알아가려고 애쓰고 동시에 저의 미래를 위해 끊임없이 자격증 공부와 학업을 병행해 왔어요.
    하지만 불안정한 심리 상태로 인하여 제 꿈과 수많은 기회들을 잡지 못하고 도망쳐버렸습니다. 어째서일까요 왜 저는 남들보다 버티는 힘이 약한걸까요.
    다시 힘을내서 도전도 해보고 여러 일을 해보며 모든 에너지를 쏟아부었더니 어느날 한글이 (…) 간단한 결정을 내리는 것이 어려워 결국 직장을 또 그만두게 되었습니다. 나약하다고 하기에는 보통사람들이 버티기 어려워하는 직종을 하고 있거든요.
    너무 과도하게 힘을 내서 열심히 한 탓일까요? 이제는 정말 배터리가 1%도 남아있지 않은 것 같아요. 근데 쉬려니까 불안해지고 부모님께 미안해서 (…) 자리 잘 잡아서 잘 지냈을 텐데요.. 앞으로 저는 그냥 휴식을 취하는 것이 맞을까요? 저는 어떤 상태에 있는 걸까요?
    """
    # Collapse the triple-quoted sample into a single line, generate a reply, substitute
    # one word of the generated text, and print it behind a speaker label.
    text = text.replace('\n', '')
    result = get_sent(text)
    result = result.replace('μ¬μ°λ', 'λ§μΉ΄λ')
    print('λ‘λν:', result)
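Since get_sent() decodes with skip_special_tokens=False, the returned reply can still carry an '[EOS]' marker and anything sampled after it. A minimal post-processing sketch; clean_reply is a hypothetical helper, not part of the uploaded file:

def clean_reply(reply: str) -> str:
    # Keep only the text generated before the first '[EOS]' marker, if one is present.
    return reply.split('[EOS]')[0].strip()

# e.g. result = clean_reply(get_sent(text))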