ganchengguang commited on
Commit
d6f3d4a
1 Parent(s): f93dde7

Upload test.py

Browse files
Files changed (1) hide show
  1. test.py +52 -0
test.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
2
+ import torch
3
+
4
+ def generate_response(model, tokenizer, instruction, input_text, temperature, top_p, top_k, repeat_penalty):
5
+ PROMPT = f'''### Instruction:
6
+ {instruction}
7
+ ### Input:
8
+ {input_text}
9
+
10
+ ### Response:'''
11
+
12
+ input_ids = tokenizer.encode(PROMPT, return_tensors='pt')
13
+ max_length = len(input_ids[0]) + 50 # Example, you can set your preferred value
14
+
15
+ # Set generation parameters within given ranges
16
+ gen_parameters = {
17
+ 'temperature': temperature,
18
+ 'top_p': top_p,
19
+ 'top_k': top_k,
20
+ 'repetition_penalty': repeat_penalty,
21
+ 'max_length': max_length,
22
+ 'max_new_tokens': 50 # Example, you can set your preferred value
23
+ }
24
+
25
+ output = model.generate(input_ids, **gen_parameters)
26
+ response = tokenizer.decode(output[0], skip_special_tokens=True)
27
+
28
+ return response[len(PROMPT):] # Removing the prompt part
29
+
30
+
31
+ def main():
32
+ MODEL_NAME = 'Yoko-7B-Japanese-v1 ' # Replace with your model's file path or name
33
+
34
+ # Load pre-trained model and tokenizer
35
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
36
+ tokenizer = LlamaTokenizer.from_pretrained(MODEL_NAME)
37
+
38
+ instruction = '次の問題を回答してください。'
39
+ # instruction = 'Please answer following question.'
40
+ input_text = '東京は何国の都市ですか?'
41
+
42
+ # Example parameter values, you can modify these within the ranges you provided
43
+ temperature = 0.6
44
+ top_p = 0.7
45
+ top_k = 40
46
+ repeat_penalty = 1.1
47
+
48
+ response = generate_response(model, tokenizer, instruction, input_text, temperature, top_p, top_k, repeat_penalty)
49
+ print('response'+response)
50
+
51
+ if __name__ == '__main__':
52
+ main()