DhanOS committed
Commit
9d83cd3
1 Parent(s): 0e227cf

Added simple inference script

Files changed (1): deepthought_inference.py +178 -0
deepthought_inference.py ADDED
@@ -0,0 +1,178 @@
+ import logging
+ import os
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # Suppress TensorFlow logging
+ os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"  # Disable oneDNN optimizations
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
+ import warnings
+
+
+ warnings.filterwarnings("ignore", message="A NumPy version >=")
+ logging.basicConfig(level=logging.ERROR)
+ logging.getLogger("transformers").setLevel(logging.ERROR)
+
+
+ # Check if Flash Attention is available
+ try:
+     import flash_attn  # noqa: F401
+     flash_attn_exists = True
+ except ImportError:
+     flash_attn_exists = False
+
+
+ # Define the DeepthoughtModel class
+ class DeepthoughtModel:
+     def __init__(self):
+         self.model_name = "ruliad/deepthought-8b-llama-v0.01-alpha"
+         print(f"Loading model: {self.model_name}")
+
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             self.model_name,
+             add_bos_token=False,
+             trust_remote_code=True,
+             padding_side="left",
+             torch_dtype=torch.bfloat16,
+         )
+
+         self.model = AutoModelForCausalLM.from_pretrained(
+             self.model_name,
+             torch_dtype=torch.bfloat16,
+             device_map="auto",
+             attn_implementation=("flash_attention_2" if flash_attn_exists else None),
+             use_cache=True,
+             trust_remote_code=True,
+         )
+
+     # Helper method to generate the initial prompt
+     def _get_initial_prompt(
+         self, query: str, system_message: str = None
+     ) -> str:
+         '''Helper method to generate the initial prompt format.'''
+         if system_message is None:
+             system_message = '''You are a superintelligent AI system, capable of comprehensive reasoning. When provided with <reasoning>, you must provide your logical reasoning chain to solve the user query. Be verbose with your outputs.'''
+
+         return f'''<|im_start|>system
+ {system_message}<|im_end|>
+
+ <|im_start|>user
+ {query}<|im_end|>
+
+ <|im_start|>reasoning
+ <reasoning>
+ [
+ {{
+ "step": 1,
+ "type": "problem_understanding",
+ "thought": "'''
+
+     # Method to generate reasoning given the prompt
+     def generate_reasoning(self, query: str, system_message: str = None) -> dict:
+         print('Generating reasoning...')
+
+         # Get and print prompt
+         prompt = self._get_initial_prompt(query, system_message)
+         print(prompt, end='')
+
+         # Tokenize the prompt
+         inputs = self.tokenizer(prompt, return_tensors='pt').input_ids.to(self.model.device)
+
+         try:
+
+             # Generate and stream reasoning
+             outputs = self.model.generate(
+                 input_ids=inputs,
+                 max_new_tokens=800,
+                 do_sample=True,
+                 temperature=0.2,
+                 top_k=200,
+                 top_p=1.0,
+                 eos_token_id=self.tokenizer.eos_token_id,
+                 streamer=TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True),
+             )
+
+             # Get the reasoning string
+             generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+             return {
+                 'raw_output': generated_text,
+                 'success': True,
+                 'error': None,
+                 'initial_prompt': prompt,
+             }
+
+         except Exception as e:
+             logging.error(f'Error during generation: {e}')
+             return {
+                 'raw_output': None,
+                 'success': False,
+                 'error': str(e),
+                 'initial_prompt': None,
+             }
+
+     # Method to generate the final output
+     def generate_final_output(self, reasoning_output: dict) -> dict:
+
+         # Get the reasoning text and create the full prompt for the final output
+         reasoning_text = reasoning_output['raw_output'].replace(reasoning_output['initial_prompt'], '')
+         full_prompt = f'''{reasoning_text}<|im_end|>
+
+ <|im_start|>assistant
+ '''
+
+         print('Generating final response...')
+
+         # Tokenize the full prompt
+         inputs = self.tokenizer(full_prompt, return_tensors='pt').input_ids.to(self.model.device)
+
+         try:
+
+             # Generate and stream the final output
+             _ = self.model.generate(
+                 input_ids=inputs,
+                 max_new_tokens=400,
+                 do_sample=True,
+                 temperature=0.1,
+                 top_k=50,
+                 top_p=0.9,
+                 eos_token_id=self.tokenizer.eos_token_id,
+                 streamer=TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
+             )
+
+             return {'success': True, 'error': None}
+
+         except Exception as e:
+             logging.error(f'Error during final generation: {e}')
+             return {'success': False, 'error': str(e)}
+
+
+ def main():
+     model = DeepthoughtModel()
+
+     # Test queries
+     queries = [
+         "We want you to tell us the answer to life, the universe and everything. We'd really like an answer, something simple.",
+     ]
+
+     # Process one query at a time (because we are streaming)
+     for query in queries:
+         print(f'\nProcessing query: {query}')
+         print('='*50)
+
+         # Reasoning
+         reasoning_result = model.generate_reasoning(query)
+         if not reasoning_result['success']:
+             print(f'\nError in reasoning: {reasoning_result["error"]}')
+             print('='*50)
+             continue
+
+         print('-'*50)
+
+         # Final output
+         final_result = model.generate_final_output(reasoning_result)
+         if not final_result['success']:
+             print(f'\nError in final generation: {final_result["error"]}')
+
+         print('='*50)
+
+ if __name__ == '__main__':
+     main()
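
For reference, the two-stage flow in this script can also be driven programmatically rather than through main(). The sketch below is not part of the commit; it assumes deepthought_inference.py is importable from the working directory, that torch, transformers, and (optionally) flash_attn are installed, and that the ruliad/deepthought-8b-llama-v0.01-alpha weights can be downloaded. It simply reuses the DeepthoughtModel class defined above with an example query; both stages stream their output via TextStreamer.

    # Usage sketch (illustrative, not from the commit): reuse DeepthoughtModel for one query.
    from deepthought_inference import DeepthoughtModel

    model = DeepthoughtModel()

    # Stage 1: stream the structured <reasoning> chain for the query.
    reasoning = model.generate_reasoning("Summarize the plot of Hamlet in two sentences.")

    # Stage 2: stream the final assistant answer conditioned on that reasoning.
    if reasoning["success"]:
        model.generate_final_output(reasoning)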