kfkas committed
Commit 1122542 • 1 Parent(s): 4f20837

Update README.md

Files changed (1)
  1. README.md +46 -0
README.md CHANGED
@@ -53,6 +53,52 @@ Llama-2-Ko-7b-Chat은 [beomi/llama-2-ko-7b 40B](https://huggingface.co/beomi/lla
 <img src=https://github.com/taemin6697/Paper_Review/assets/96530685/b9a697a2-ef06-4b1c-97e1-e72b20d9a8b5 style="max-width: 700px; width: 100%" />
 ---
 
+ ### Inference
+
+ ```python
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+ def gen(x, model, tokenizer, device):
+     # Wrap the user input in the Korean instruction/response format the model expects
+     # ("Below is an instruction that describes a task. Write a response that appropriately completes the request.")
+     prompt = (
+         f"아래는 작업을 설명하는 명령어입니다. 요청을 적절히 완료하는 응답을 작성하세요.\n\n### 명령어:\n{x}\n\n### 응답:"
+     )
+     len_prompt = len(prompt)
+     gened = model.generate(
+         **tokenizer(prompt, return_tensors="pt", return_token_type_ids=False).to(
+             device
+         ),
+         max_new_tokens=1024,
+         early_stopping=True,
+         do_sample=True,
+         top_k=10,
+         top_p=0.92,
+         no_repeat_ngram_size=3,
+         eos_token_id=2,
+         repetition_penalty=1.2,
+         num_beams=3,
+     )
+     # Strip the echoed prompt so only the generated response is returned
+     return tokenizer.decode(gened[0])[len_prompt:]
+
+
+ def LLM_infer(input):
+     device = (
+         torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
+     )
+     model_id = "kfkas/legal-llama-2-ko-7b-Chat"
+     # Load the model in fp16 on GPU 0 (device_map={"": 0} assumes a CUDA device is available)
+     model = AutoModelForCausalLM.from_pretrained(
+         model_id, device_map={"": 0}, torch_dtype=torch.float16, low_cpu_mem_usage=True
+     )
+     tokenizer = AutoTokenizer.from_pretrained(model_id)
+     model.eval()
+     model.config.use_cache = True
+     tokenizer.pad_token = tokenizer.eos_token
+     output = gen(input, model=model, tokenizer=tokenizer, device=device)
+
+     return output
+
+
+ if __name__ == "__main__":
+     text = LLM_infer("살인죄를 알려줘")  # "Tell me about the crime of murder"
+     print(text)
+ ```
+
 ## Note for oobabooga/text-generation-webui
 
 Remove the `ValueError` raised in the `load_tokenizer` function (around line 109) in `modules/models.py`.
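
For orientation only, here is a rough sketch of the kind of edit that note describes. The actual body of `load_tokenizer` differs between text-generation-webui versions, so the structure, condition, and message below are placeholders rather than the project's real code; locate the `raise ValueError(...)` near line 109 of your copy of `modules/models.py` and delete or comment it out.

```python
# modules/models.py -- load_tokenizer(), near line 109 in the version this note targets.
# Hypothetical sketch only: the names, condition, and message below are placeholders.

def load_tokenizer(model_name, model):
    tokenizer = None
    # ... version-specific tokenizer loading logic elided ...

    if tokenizer is None:
        # The note above says to remove the ValueError raised around here
        # (e.g. by commenting it out) so that loading can continue with the
        # tokenizer bundled with this model:
        # raise ValueError("Failed to load the tokenizer.")
        pass

    return tokenizer
```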