sasan committed
Commit 8062dc6 · 1 Parent(s): 0950a4c

chore: Update TTS dependencies and remove unused imports

Files changed (1): kitt/core/model.py (+49 -41)
kitt/core/model.py CHANGED
@@ -19,6 +19,12 @@ from kitt.skills.common import config
 
 from .validator import validate_function_call_schema
 
+# Model Settings
+TEMPERATURE = 0.5
+REPEAT_PENALTY = 1.1
+TOP_P = 0.9
+TOP_K = 50
+
 
 class FunctionCall(BaseModel):
     arguments: dict
@@ -240,45 +246,6 @@ def get_prompt(template, history, tools, schema, user_preferences, car_status=None):
     return prompt
 
 
-def run_inference_ollama(prompt):
-    data = {
-        "prompt": prompt,
-        # "streaming": False,
-        # "model": "smangrul/llama-3-8b-instruct-function-calling",
-        # "model": "elvee/hermes-2-pro-llama-3:8b-Q5_K_M",
-        # "model": "NousResearch/Hermes-2-Pro-Llama-3-8B",
-        "model": "interstellarninja/hermes-2-pro-llama-3-8b",
-        # "model": "dolphin-llama3:8b",
-        # "model": "dolphin-llama3:70b",
-        "raw": True,
-        "options": {
-            "temperature": 0.7,
-            # "max_tokens": 1500,
-            "num_predict": 1500,
-            # "mirostat": 1,
-            # "mirostat_tau": 2,
-            "repeat_penalty": 1.2,
-            "top_k": 25,
-            "top_p": 0.5,
-            "num_ctx": 8000,
-            # "stop": ["<|im_end|>"]
-            # "num_predict": 1500,
-            # "max_tokens": 1500,
-        },
-    }
-
-    client = Client(host="http://localhost:11434")
-    # out = ollama.generate(**data)
-    out = client.generate(**data)
-    res = out.pop("response")
-    # Report prompt and eval tokens
-    logger.warning(
-        f"Prompt tokens: {out.get('prompt_eval_count')}, Response tokens: {out.get('eval_count')}"
-    )
-    logger.debug(f"Response from Ollama: {res}\nOut:{out}")
-    return res
-
-
 def run_inference_step(
     depth, history, tools, schema_json, user_preferences, backend="ollama"
 ):
@@ -317,10 +284,12 @@ def run_inference_replicate(prompt):
 
     input = {
         "prompt": prompt,
-        "temperature": 0.5,
+        "temperature": TEMPERATURE,
         "system_prompt": "",
         "max_new_tokens": 1024,
-        "repeat_penalty": 1.1,
+        "repeat_penalty": REPEAT_PENALTY,
+        "top_p": TOP_P,
+        "top_k": TOP_K,
         "prompt_template": "{prompt}",
     }
 
@@ -336,6 +305,45 @@ def run_inference_replicate(prompt):
     return out
 
 
+def run_inference_ollama(prompt):
+    data = {
+        "prompt": prompt,
+        # "streaming": False,
+        # "model": "smangrul/llama-3-8b-instruct-function-calling",
+        # "model": "elvee/hermes-2-pro-llama-3:8b-Q5_K_M",
+        # "model": "NousResearch/Hermes-2-Pro-Llama-3-8B",
+        "model": "interstellarninja/hermes-2-pro-llama-3-8b",
+        # "model": "dolphin-llama3:8b",
+        # "model": "dolphin-llama3:70b",
+        "raw": True,
+        "options": {
+            "temperature": TEMPERATURE,
+            # "max_tokens": 1500,
+            "num_predict": 1500,
+            # "mirostat": 1,
+            # "mirostat_tau": 2,
+            "repeat_penalty": REPEAT_PENALTY,
+            "top_p": TOP_P,
+            "top_k": TOP_K,
+            "num_ctx": 8000,
+            # "stop": ["<|im_end|>"]
+            # "num_predict": 1500,
+            # "max_tokens": 1500,
+        },
+    }
+
+    client = Client(host="http://localhost:11434")
+    # out = ollama.generate(**data)
+    out = client.generate(**data)
+    res = out.pop("response")
+    # Report prompt and eval tokens
+    logger.warning(
+        f"Prompt tokens: {out.get('prompt_eval_count')}, Response tokens: {out.get('eval_count')}"
+    )
+    logger.debug(f"Response from Ollama: {res}\nOut:{out}")
+    return res
+
+
 def run_inference(prompt, backend="ollama"):
     prompt += AI_PREAMBLE
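
For context, here is a minimal standalone sketch of how the new module-level settings feed the Ollama backend, mirroring run_inference_ollama in this diff. It assumes the `ollama` Python package is installed, a local Ollama server is reachable at http://localhost:11434, and the interstellarninja/hermes-2-pro-llama-3-8b model has already been pulled; the `generate` helper and the sample prompt are illustrative only, not part of kitt/core/model.py, where the prompt is built by get_prompt and token counts are reported through logger.

from ollama import Client

# Centralized sampling settings, matching the constants added at the top of
# kitt/core/model.py in this commit.
TEMPERATURE = 0.5
REPEAT_PENALTY = 1.1
TOP_P = 0.9
TOP_K = 50


def generate(prompt: str) -> str:
    # Local Ollama server, same host as used by run_inference_ollama above.
    client = Client(host="http://localhost:11434")
    out = client.generate(
        model="interstellarninja/hermes-2-pro-llama-3-8b",
        prompt=prompt,
        raw=True,  # raw mode: the caller supplies the full prompt template
        options={
            "temperature": TEMPERATURE,
            "repeat_penalty": REPEAT_PENALTY,
            "top_p": TOP_P,
            "top_k": TOP_K,
            "num_predict": 1500,
            "num_ctx": 8000,
        },
    )
    # The response also carries token accounting under "prompt_eval_count"
    # and "eval_count", which the module logs via logger.warning.
    return out["response"]


if __name__ == "__main__":
    # Illustrative ChatML-style prompt; in kitt the real prompt comes from get_prompt.
    prompt = "<|im_start|>user\nHello, what can you do?<|im_end|>\n<|im_start|>assistant\n"
    print(generate(prompt))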