chore: Update TTS dependencies and remove unused imports
kitt/core/model.py  CHANGED  (+49 -41)
@@ -19,6 +19,12 @@ from kitt.skills.common import config
 
 from .validator import validate_function_call_schema
 
+# Model Settings
+TEMPERATURE = 0.5
+REPEAT_PENALTY = 1.1
+TOP_P = 0.9
+TOP_K = 50
+
 
 class FunctionCall(BaseModel):
     arguments: dict
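Reviewer sketch, not part of this commit: the point of the new module-level settings is that both backends now draw from one place, and only the surrounding payload differs (Ollama nests the values under "options", Replicate passes them as top-level input fields). A minimal, self-contained illustration of that shape:

# Sketch only -- constants copied from the hunk above, dict keys taken
# from the two backend hunks later in this diff.
TEMPERATURE, REPEAT_PENALTY, TOP_P, TOP_K = 0.5, 1.1, 0.9, 50

SAMPLING = {
    "temperature": TEMPERATURE,
    "repeat_penalty": REPEAT_PENALTY,
    "top_p": TOP_P,
    "top_k": TOP_K,
}

# Ollama wraps the sampling settings in an "options" mapping ...
ollama_options = {**SAMPLING, "num_predict": 1500, "num_ctx": 8000}
# ... while Replicate takes them as flat input fields.
replicate_input = {**SAMPLING, "max_new_tokens": 1024, "prompt_template": "{prompt}"}

print(ollama_options)
print(replicate_input)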
@@ -240,45 +246,6 @@ def get_prompt(template, history, tools, schema, user_preferences, car_status=No
     return prompt
 
 
-def run_inference_ollama(prompt):
-    data = {
-        "prompt": prompt,
-        # "streaming": False,
-        # "model": "smangrul/llama-3-8b-instruct-function-calling",
-        # "model": "elvee/hermes-2-pro-llama-3:8b-Q5_K_M",
-        # "model": "NousResearch/Hermes-2-Pro-Llama-3-8B",
-        "model": "interstellarninja/hermes-2-pro-llama-3-8b",
-        # "model": "dolphin-llama3:8b",
-        # "model": "dolphin-llama3:70b",
-        "raw": True,
-        "options": {
-            "temperature": 0.7,
-            # "max_tokens": 1500,
-            "num_predict": 1500,
-            # "mirostat": 1,
-            # "mirostat_tau": 2,
-            "repeat_penalty": 1.2,
-            "top_k": 25,
-            "top_p": 0.5,
-            "num_ctx": 8000,
-            # "stop": ["<|im_end|>"]
-            # "num_predict": 1500,
-            # "max_tokens": 1500,
-        },
-    }
-
-    client = Client(host="http://localhost:11434")
-    # out = ollama.generate(**data)
-    out = client.generate(**data)
-    res = out.pop("response")
-    # Report prompt and eval tokens
-    logger.warning(
-        f"Prompt tokens: {out.get('prompt_eval_count')}, Response tokens: {out.get('eval_count')}"
-    )
-    logger.debug(f"Response from Ollama: {res}\nOut:{out}")
-    return res
-
-
 def run_inference_step(
     depth, history, tools, schema_json, user_preferences, backend="ollama"
 ):
@@ -317,10 +284,12 @@ def run_inference_replicate(prompt):
 
     input = {
         "prompt": prompt,
-        "temperature":
+        "temperature": TEMPERATURE,
         "system_prompt": "",
        "max_new_tokens": 1024,
-        "repeat_penalty":
+        "repeat_penalty": REPEAT_PENALTY,
+        "top_p": TOP_P,
+        "top_k": TOP_K,
         "prompt_template": "{prompt}",
     }
 
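Reviewer sketch, not part of this commit: an input dict shaped like the one above is normally handed to the Replicate Python client roughly as below. The model reference is a placeholder (the one run_inference_replicate actually uses is outside this diff), and REPLICATE_API_TOKEN must be set in the environment.

import replicate

TEMPERATURE, REPEAT_PENALTY, TOP_P, TOP_K = 0.5, 1.1, 0.9, 50

payload = {
    "prompt": "Hello",
    "temperature": TEMPERATURE,
    "system_prompt": "",
    "max_new_tokens": 1024,
    "repeat_penalty": REPEAT_PENALTY,
    "top_p": TOP_P,
    "top_k": TOP_K,
    "prompt_template": "{prompt}",
}

# For text models, replicate.run yields the generated text in chunks.
output = replicate.run("owner/hermes-2-pro-llama-3-8b", input=payload)  # placeholder ref
print("".join(output))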
@@ -336,6 +305,45 @@ def run_inference_replicate(prompt):
     return out
 
 
+def run_inference_ollama(prompt):
+    data = {
+        "prompt": prompt,
+        # "streaming": False,
+        # "model": "smangrul/llama-3-8b-instruct-function-calling",
+        # "model": "elvee/hermes-2-pro-llama-3:8b-Q5_K_M",
+        # "model": "NousResearch/Hermes-2-Pro-Llama-3-8B",
+        "model": "interstellarninja/hermes-2-pro-llama-3-8b",
+        # "model": "dolphin-llama3:8b",
+        # "model": "dolphin-llama3:70b",
+        "raw": True,
+        "options": {
+            "temperature": TEMPERATURE,
+            # "max_tokens": 1500,
+            "num_predict": 1500,
+            # "mirostat": 1,
+            # "mirostat_tau": 2,
+            "repeat_penalty": REPEAT_PENALTY,
+            "top_p": TOP_P,
+            "top_k": TOP_K,
+            "num_ctx": 8000,
+            # "stop": ["<|im_end|>"]
+            # "num_predict": 1500,
+            # "max_tokens": 1500,
+        },
+    }
+
+    client = Client(host="http://localhost:11434")
+    # out = ollama.generate(**data)
+    out = client.generate(**data)
+    res = out.pop("response")
+    # Report prompt and eval tokens
+    logger.warning(
+        f"Prompt tokens: {out.get('prompt_eval_count')}, Response tokens: {out.get('eval_count')}"
+    )
+    logger.debug(f"Response from Ollama: {res}\nOut:{out}")
+    return res
+
+
 def run_inference(prompt, backend="ollama"):
     prompt += AI_PREAMBLE
 
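Reviewer sketch, not part of this commit: the relocated run_inference_ollama boils down to the raw generate call below. This assumes a local Ollama server on the default port and that the model has already been pulled (ollama pull interstellarninja/hermes-2-pro-llama-3-8b); the prompt text is only an example, since the real prompt is built by get_prompt.

from ollama import Client

TEMPERATURE, REPEAT_PENALTY, TOP_P, TOP_K = 0.5, 1.1, 0.9, 50

client = Client(host="http://localhost:11434")
out = client.generate(
    model="interstellarninja/hermes-2-pro-llama-3-8b",
    prompt="<|im_start|>user\nWhat can you do?<|im_end|>\n<|im_start|>assistant\n",
    raw=True,  # raw mode: the prompt already carries the chat template
    options={
        "temperature": TEMPERATURE,
        "repeat_penalty": REPEAT_PENALTY,
        "top_p": TOP_P,
        "top_k": TOP_K,
        "num_predict": 1500,
        "num_ctx": 8000,
    },
)
print(out["response"])
print("prompt tokens:", out.get("prompt_eval_count"), "response tokens:", out.get("eval_count"))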