khaimai committed on
Commit
d5b461d
·
verified ·
1 Parent(s): c262843

Upload 2 files

Browse files
Files changed (2) hide show
  1. modeling_functionary.py +109 -0
  2. tokenization_functionary.py +520 -0
modeling_functionary.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright (c) 2024, MeetKai Inc. All rights reserved.
3
+ """PyTorch LLaMA model."""
4
+
5
+ import json
6
+ from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union
7
+
8
+ import torch
9
+ import torch.utils.checkpoint
10
+
11
+ from transformers.generation.configuration_utils import GenerationConfig
12
+ from transformers.generation.logits_process import LogitsProcessorList
13
+ from transformers.generation.stopping_criteria import StoppingCriteriaList
14
+ from transformers.generation.utils import (
15
+ GenerateBeamDecoderOnlyOutput,
16
+ GenerateBeamEncoderDecoderOutput,
17
+ GenerateDecoderOnlyOutput,
18
+ GenerateEncoderDecoderOutput
19
+ )
20
+ from transformers.models.llama.modeling_llama import LlamaForCausalLM
21
+ from transformers.utils import logging
22
+
23
+
24
+ if TYPE_CHECKING:
25
+ from transformers.modeling_utils import PreTrainedModel
26
+ from transformers.generation.streamers import BaseStreamer
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
# Union aliases over the generation output classes imported from
# `transformers.generation.utils`. `GenerateOutput` is the one referenced in
# the return annotation of `FunctionaryForCausalLM.generate_tool_use`.
GenerateNonBeamOutput = Union[
    GenerateDecoderOnlyOutput,
    GenerateEncoderDecoderOutput,
]
GenerateBeamOutput = Union[
    GenerateBeamDecoderOnlyOutput,
    GenerateBeamEncoderDecoderOutput,
]
GenerateOutput = Union[GenerateNonBeamOutput, GenerateBeamOutput]
33
+
34
+
35
class FunctionaryForCausalLM(LlamaForCausalLM):
    """`LlamaForCausalLM` subclass adding `generate_tool_use` for structured tool-call output."""

    def generate_tool_use(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Run `generate` and re-encode each completion as an assistant-message JSON string.

        For every sequence, the newly generated tokens are decoded, split into an
        optional text part and zero or more tool calls, serialised with
        `json.dumps(..., indent=4)` as
        `{"role": "assistant", "content": ..., "tool_calls": ...}` and tokenised
        again. The returned rows are the original prompt ids followed by the ids
        of that JSON string, right-padded with `tokenizer.eos_token_id`.

        Args:
            inputs: Forwarded to `generate`; also used as the prompt ids when
                `input_ids` is not given in `kwargs`.
            generation_config, logits_processor, stopping_criteria,
            prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer,
            negative_prompt_ids, negative_prompt_attention_mask: Forwarded to
                `generate` unchanged.
            **kwargs: Must contain `tokenizer` (removed before calling
                `generate`). May contain `input_ids`, which is forwarded to
                `generate` and used to find where each prompt ends.

        Returns:
            A 2-D tensor with one row per prompt, on CPU.

        Raises:
            ValueError: If no tokenizer is supplied, or if neither `input_ids`
                nor `inputs` provides the prompt ids.
        """
        # Pull the tokenizer out first: it is needed to parse the raw output
        # and must not be forwarded to `generate`.
        tokenizer = kwargs.pop("tokenizer", None)
        if tokenizer is None:
            raise ValueError(
                "`generate_tool_use` requires a `tokenizer` keyword argument to parse the generated output."
            )

        # Resolve the prompt ids before the (expensive) generation so a
        # missing prompt fails fast instead of after decoding.
        prompt_ids = kwargs.get("input_ids")
        if prompt_ids is None:
            prompt_ids = inputs
        if prompt_ids is None:
            raise ValueError(
                "`generate_tool_use` needs the prompt ids, passed either as `input_ids` or as `inputs`."
            )

        results = self.generate(
            inputs=inputs,
            generation_config=generation_config,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            synced_gpus=synced_gpus,
            assistant_model=assistant_model,
            streamer=streamer,
            negative_prompt_ids=negative_prompt_ids,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            **kwargs,
        )
        # With `return_dict_in_generate=True` the token ids live on `.sequences`;
        # a plain tensor has no such attribute and is used as-is.
        sequences = getattr(results, "sequences", results)

        correct_results = []
        # NOTE(review): rows are paired positionally with the prompts, so this
        # presumably expects one returned sequence per prompt - confirm before
        # using `num_return_sequences > 1`.
        for prompt, generated in zip(prompt_ids, sequences):
            prompt_len = len(prompt)
            raw_output_str = tokenizer.decode(generated[prompt_len:].cpu())
            message = self._parse_tool_use_output(raw_output_str, tokenizer.pad_token)
            message_ids = tokenizer(json.dumps(message, indent=4), add_special_tokens=False)["input_ids"]
            prompt_part = generated[:prompt_len].cpu()
            correct_results.append(
                torch.cat((prompt_part, torch.tensor(message_ids, dtype=prompt_part.dtype)))
            )

        # Right-pad every row to the longest one so they can be stacked.
        max_len = max(row.shape[0] for row in correct_results)
        padded_rows = [
            torch.nn.functional.pad(row, (0, max_len - row.shape[0]), value=tokenizer.eos_token_id)
            for row in correct_results
        ]
        return torch.stack(padded_rows)

    @staticmethod
    def _parse_tool_use_output(raw_output_str: str, pad_token: Optional[str]) -> dict:
        """Split one decoded completion into text content and tool calls.

        Args:
            raw_output_str: Decoded text of the newly generated tokens.
            pad_token: Padding token text to strip from each chunk, or `None`.

        Returns:
            `{"role": "assistant", "content": ..., "tool_calls": ...}` where
            `content` is the text before the first function-call token (or
            `None`) and `tool_calls` is a list of `{"name", "arguments"}` dicts
            (or `None` when there are none).
        """
        function_call_token = "<|reserved_special_token_249|>"
        eot_token = "<|eot_id|>"

        message = {"role": "assistant", "content": None, "tool_calls": None}
        tool_calls = []
        for i, chunk in enumerate(raw_output_str.split(function_call_token)):
            if pad_token:
                chunk = chunk.replace(pad_token, "")
            if not chunk:
                # Empty leading chunk (output starts with a tool call) or a
                # chunk that consisted only of padding.
                continue

            if i == 0:
                # A non-empty first chunk is text that precedes any tool call.
                if chunk.endswith(eot_token):
                    message["content"] = chunk.strip()[: -len(eot_token)]
                else:
                    message["content"] = chunk
                continue

            body = chunk[: -len(eot_token)] if chunk.endswith(eot_token) else chunk
            # A tool call is "<name>\n{<json arguments>". A truncated call may
            # lack the arguments; keep the name instead of crashing the batch.
            name, separator, rest = body.partition("\n{")
            tool_calls.append({"name": name, "arguments": "{" + rest if separator else ""})

        if tool_calls:
            message["tool_calls"] = tool_calls
        return message
tokenization_functionary.py ADDED
@@ -0,0 +1,520 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, MeetKai Inc. All rights reserved.
2
+
3
+ from copy import deepcopy
4
+ import json
5
+ from typing import Any, Dict, List, Literal, Optional, Union
6
+
7
+ import jsonref
8
+ from pydantic import BaseModel, Field, model_validator
9
+ from typing_extensions import Self
10
+
11
+ from transformers.tokenization_utils_base import BatchEncoding
12
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
13
+ from transformers.utils import TensorType, logging
14
+
15
+
16
+ logger = logging.get_logger(__name__)
17
# Second system message used when no "code_interpreter" tool is requested.
SYSTEM_PROMPT = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions. "
    "The assistant calls functions with appropriate input when necessary"
)
# Second system message used instead when a "code_interpreter" tool is requested.
CODE_INTERPRETER_SYSTEM_PROMPT = (
    "When you send a message containing Python code to python, it will be executed in a "
    "stateful Jupyter notebook environment. python will respond with the output of the "
    "execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to "
    "save and persist user files."
)
19
+
20
class Function(BaseModel):
    # Description of one callable function as supplied in a tool definition.
    # `name` is required; `description` defaults to an empty string and
    # `parameters` (a JSON-schema-like dict) may be omitted.
    name: str
    description: Optional[str] = ""
    parameters: Optional[dict] = None
24
+
25
+
26
class Tool(BaseModel):
    # One entry of the `tools` list: either a function definition or the
    # built-in code interpreter.
    type: Literal["function", "code_interpreter"]
    function: Optional[Function] = None

    @model_validator(mode="after")
    def check_type_function_matches(self) -> Self:
        """Check that `function` is present exactly when `type` is "function".

        Raises:
            ValueError: If `function` is missing for a "function" tool or
                supplied for a "code_interpreter" tool. Explicit raises are
                used rather than `assert` so the check still runs under
                `python -O`.
        """
        if self.type == "function":
            if self.function is None:
                raise ValueError('"function" must contain function description when `"type": "function"`')
        elif self.function is not None:
            raise ValueError('"function" must not be provided when `"type": "code_interpreter"`')
        return self
37
+
38
+
39
def convert_data_type(param_type: str) -> str:
    """Map a schema type name to its TypeScript counterpart.

    Args:
        param_type (str): type name taken from the parameter schema

    Returns:
        str: "number" for "integer" and "float"; any other value is returned unchanged
    """
    return "number" if param_type in ("integer", "float") else param_type
51
+
52
+
53
def get_param_type(param: Dict) -> str:
    """Get the TypeScript type string of a parameter.

    Args:
        param (Dict): param dict in properties

    Returns:
        str: the value of "type" (list values are joined with " | "); otherwise
        the distinct types found under "oneOf", joined with " | " in
        first-seen order; "any" when neither yields a type. The result is
        passed through `convert_data_type`.
    """
    param_type = "any"
    if "type" in param:
        raw_param_type = param["type"]
        if isinstance(raw_param_type, list):
            param_type = " | ".join(raw_param_type)
        else:
            param_type = raw_param_type
    elif "oneOf" in param:
        # In many cases the json schema contains "oneOf" instead of "type".
        # Deduplicate with dict.fromkeys rather than set(): set ordering of
        # strings varies between interpreter runs, which made the generated
        # prompt non-deterministic.
        one_of_types = list(
            dict.fromkeys(
                convert_data_type(item["type"]) for item in param["oneOf"] if "type" in item
            )
        )
        # Keep "any" when no oneOf entry declares a type, instead of emitting
        # an empty type string.
        if one_of_types:
            param_type = " | ".join(one_of_types)
    return convert_data_type(param_type)
79
+
80
+
81
def get_format_param(param: Dict) -> Optional[str]:
    """Look up the "format" of a parameter, directly or inside its "oneOf" entries.

    Args:
        param (Dict): param dict in properties

    Returns:
        Optional[str]: param["format"] when present; otherwise the formats of
        the "oneOf" entries joined with " or "; None when no format is found
    """
    if "format" in param:
        return param["format"]
    if "oneOf" not in param:
        return None
    found = [option["format"] for option in param["oneOf"] if "format" in option]
    return " or ".join(found) if found else None
100
+
101
+
102
def get_param_info(param: Dict) -> Optional[str]:
    """Build the one-line "// ..." comment describing a parameter.

    The comment collects, in this order: the description (a trailing "." is
    added when missing), the default value (quoted when the param type is
    "string"), the format, and the maximum / minimum / maxLength / minLength
    constraints.

    Args:
        param (Dict): param dict in properties

    Returns:
        Optional[str]: the comment with newlines replaced by spaces, or None
        when the parameter carries none of this information
    """
    declared_type = param.get("type", "any")
    pieces = []

    if "description" in param:
        description = param["description"]
        pieces.append(description if description.endswith(".") else description + ".")

    if "default" in param:
        default = param["default"]
        if declared_type == "string":
            default = f'"{default}"'  # string defaults are shown in quotes
        pieces.append(f"Default={default}.")

    value_format = get_format_param(param)
    if value_format is not None:
        pieces.append("Format=" + value_format)

    constraint_labels = (
        ("maximum", "Maximum"),
        ("minimum", "Minimum"),
        ("maxLength", "Maximum length"),
        ("minLength", "Minimum length"),
    )
    pieces.extend(
        f"{label}=" + str(param[key]) for key, label in constraint_labels if key in param
    )

    if not pieces:
        return None
    return ("// " + " ".join(pieces)).replace("\n", " ")
143
+
144
+
145
def append_new_param_info(
    info_list: List[str],
    param_declaration: str,
    comment_info: Optional[str],
    examples_info: List,
    depth: int,
):
    """Append one parameter declaration, preceded by its comment lines, to info_list.

    When a comment is present the output is: comment, example lines, then the
    declaration. Without a comment only the declaration is appended.

    Args:
        info_list (List[str]): list of lines, extended in place
        param_declaration (str): param: type
        comment_info (Optional[str]): information of comment
        examples_info (List): information of examples given
        depth (int): level of nested param
    """
    # NOTE(review): the indent unit appears as a single space in this view of
    # the file; whitespace may have been collapsed - confirm against the original.
    indent = " " * depth if depth >= 1 else ""
    if comment_info is not None:
        info_list.append(f"{indent}{comment_info}")
        info_list.extend(f"{indent}{example}" for example in examples_info)
    info_list.append(f"{indent}{param_declaration}")
175
+
176
+
177
def get_examples_info(param_name: str, examples: List) -> List:
    """Render the examples of a parameter as comment lines.

    Args:
        param_name (str): name of the parameter, used in the header line
        examples (List): example values; dicts and lists are JSON-encoded,
            anything else goes through str()

    Returns:
        List: a "// Example <param_name>:" header followed by one "// ..." line
        per example, with newlines inside an example escaped as "\\n"
    """
    lines = [f"// Example {param_name}:"]
    for example in examples:
        if isinstance(example, (dict, list)):
            text = json.dumps(example, ensure_ascii=False)
        else:
            text = str(example)
        lines.append("// " + text.replace("\n", "\\n"))
    return lines
196
+
197
+
198
def get_enum_option_str(enum_options: List) -> str:
    """Join enum options into a TypeScript union.

    Args:
        enum_options (List): list of options

    Returns:
        str: options separated by " | "; options that are exactly of type str
        are wrapped in double quotes, all others are rendered with str()
    """
    rendered = []
    for option in enum_options:
        if type(option) is str:
            rendered.append(f'"{option}"')
        else:
            rendered.append(str(option))
    return " | ".join(rendered)
209
+
210
+
211
def get_array_typescript(
    param_name: Optional[str], param_dic: dict, depth: int = 0
) -> str:
    """Recursively build the TypeScript declaration of an array parameter.

    Args:
        param_name (Optional[str]): name of param; None when rendering the item
            type of an enclosing array
        param_dic (dict): schema of the array; its "items" entry is inspected
        depth (int, optional): nested level. Defaults to 0.

    Returns:
        str: typescript of array. Only the named, non-enum, non-nested form
        ends with a trailing comma.
    """
    # NOTE(review): the indent unit appears as a single space in this view of
    # the file; whitespace may have been collapsed - confirm against the original.
    indent = " " * depth if depth >= 1 else ""
    items = param_dic.get("items", {})
    named = param_name is not None

    # No item schema at all: the element type is unknown.
    if len(items) == 0:
        return f"{indent}{param_name}: []" if named else "[]"

    item_type = get_param_type(items)

    if item_type == "object":
        children = get_parameter_typescript(
            items.get("properties", {}), items.get("required", []), depth + 1
        )
        opening = f"{indent}{param_name}: {{" if named else f"{indent}{{"
        return "\n".join([opening, *children, f"{indent}}}[]"])

    if item_type == "array":
        # Array of arrays: render the inner array anonymously, then wrap it.
        nested = get_array_typescript(None, items, depth + 1)
        if not named:
            return f"{nested}[]"
        return f"{indent}{param_name}: {nested.strip()}[]"

    if "enum" in items:
        options = get_enum_option_str(items["enum"])
        return f"{indent}{param_name}: ({options})[]" if named else f"({options})[]"

    if not named:
        return f"{item_type}[]"
    return f"{indent}{param_name}: {item_type}[],"
268
+
269
+
270
def get_parameter_typescript(properties, required_params, depth=0) -> List[str]:
    """Recursively render parameters as TypeScript-style lines for the prompt.

    Each parameter contributes an optional comment line, optional example
    lines and a declaration; object parameters additionally contribute the
    lines of their nested properties.

    Args:
        properties (dict): properties in parameters
        required_params (list): names of required parameters; when this is a
            list, names not in it are marked with "?"
        depth (int, optional): the depth of params (nested level). Defaults to 0.

    Returns:
        List[str]: list of lines containing information about all parameters
    """
    # NOTE(review): the indent unit appears as a single space in this view of
    # the file; whitespace may have been collapsed - confirm against the original.
    indent = " " * depth if depth >= 1 else ""
    lines = []

    def header_lines(comment, examples):
        # Comment (if any) followed by the example lines, all at this depth.
        header = [] if comment is None else [f"{indent}{comment}"]
        header.extend(f"{indent}{example}" for example in examples)
        return header

    for name, schema in properties.items():
        # Sometimes properties have "required" field as a list of string,
        # even though it is not supposed to be under properties. Skip it.
        if not isinstance(schema, dict):
            continue

        comment = get_param_info(schema)
        examples = (
            get_examples_info(name, schema["examples"]) if "examples" in schema else []
        )

        declaration = f"{name}"
        if isinstance(required_params, list) and name not in required_params:
            declaration += "?"

        kind = get_param_type(schema)

        if kind == "object":
            children = get_parameter_typescript(
                schema.get("properties", {}), schema.get("required", []), depth + 1
            )
            lines.extend(header_lines(comment, examples))
            lines.append(f"{indent}{declaration}" + ": {")
            lines.extend(children)
            lines.append(f"{indent}" + "},")
            continue

        if kind == "array":
            if "type" not in schema.get("items", {}):
                # Element type unknown: declare a bare array.
                append_new_param_info(lines, declaration + ": [],", comment, examples, depth)
                continue
            array_line = get_array_typescript(declaration, schema, depth)
            if not array_line.endswith(","):
                array_line += ","
            lines.extend(header_lines(comment, examples))
            lines.append(array_line)
            continue

        if "enum" in schema:
            kind = get_enum_option_str(schema["enum"])
        if "nullable" in schema and schema["nullable"] is True:
            kind += " | null"
        append_new_param_info(lines, declaration + f": {kind},", comment, examples, depth)

    return lines
351
+
352
def generate_schema_from_functions(
    functions: List[Function], namespace="functions"
) -> str:
    """Convert functions schema to a schema that language models can understand.

    Args:
        functions (List[Function]): `Function` objects or equivalent dicts;
            entries without a name are skipped
        namespace (str, optional): name of the emitted namespace. Defaults to "functions".

    Returns:
        str: a TypeScript-like namespace with one `type <name> = ...` entry per
        function, each preceded by a "// <description>" line
    """
    parts = [
        "// Supported function definitions that should be called when necessary.\n",
        f"namespace {namespace} {{\n\n",
    ]

    for function in functions:
        # Convert a Function object to dict, if necessary
        if not isinstance(function, dict):
            function = function.model_dump()
        function_name = function.get("name", None)
        if function_name is None:
            continue

        # `description` is Optional, so it can be an explicit None; render
        # that as an empty comment rather than the literal text "None".
        description = function.get("description") or ""
        parts.append(f"// {description}\n")
        parts.append(f"type {function_name}")

        parameters = function.get("parameters", None)
        if parameters is not None and parameters.get("properties") is not None:
            # Resolve "$ref" entries, then copy so the rendering below works
            # on a copy of the resolved schema.
            parameters = deepcopy(jsonref.JsonRef.replace_refs(parameters))
            parts.append(" = (_: {\n")
            tp_lines = get_parameter_typescript(
                parameters.get("properties"),
                parameters.get("required", []),
                0,
            )
            parts.append("\n".join(tp_lines))
            parts.append("\n}) => any;\n\n")
        else:
            # Doesn't have any parameters
            parts.append(" = () => any;\n\n")

    parts.append(f"}} // namespace {namespace}")
    return "".join(parts)
393
+
394
class FunctionaryTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer whose chat template rendering prepends the tool schema and a system prompt."""

    def apply_chat_template(
        self,
        conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]], str],
        tools: Optional[List[Dict[str, Any]]] = None,
        chat_template: Optional[str] = None,
        add_generation_prompt: bool = False,
        tokenize: bool = True,
        padding: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_dict: bool = False,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:
        """Render (and optionally tokenize) a conversation with tool definitions.

        Two system messages are placed in front of every conversation: the
        schema generated from the "function" tools, then `SYSTEM_PROMPT` (or
        `CODE_INTERPRETER_SYSTEM_PROMPT` when a "code_interpreter" tool is
        present). The caller's `conversation` is not modified.

        Args:
            conversation: One conversation (list of message dicts or an object
                with a `messages` attribute) or a list/tuple of conversations.
            tools: Tool definitions validated with `Tool`; `None` is treated as
                an empty list.
            chat_template: A template string, or the name of one of the
                tokenizer's templates when it holds a dict of templates.
            add_generation_prompt: Passed to the template.
            tokenize: When False, return the rendered text instead of token ids.
            padding, truncation, max_length, return_tensors: Passed to the
                tokenizer call.
            return_dict: Return the full tokenizer output instead of only
                "input_ids".
            tokenizer_kwargs: Extra keyword arguments for the tokenizer call.
            **kwargs: Extra variables for the template; they override special
                tokens of the same name.

        Returns:
            The rendered string(s) when `tokenize` is False; otherwise the
            "input_ids", or the whole tokenizer output if `return_dict` is True.

        Raises:
            ValueError: If `return_dict` is used with `tokenize=False`, if a
                multi-template tokenizer has no usable default, or if no chat
                template can be resolved at all.
        """
        if return_dict and not tokenize:
            raise ValueError(
                "`return_dict=True` is incompatible with `tokenize=False`, because there is no dict "
                "of tokenizer outputs to return."
            )

        if tokenizer_kwargs is None:
            tokenizer_kwargs = {}

        using_default_template = False
        own_template = self.chat_template
        # The class-level default template only matters when the tokenizer has
        # no template of its own. It is read with getattr because
        # `default_chat_template` may not exist on the installed Transformers.
        default_template = None
        if own_template is None:
            default_template = getattr(self, "default_chat_template", None)

        # First, handle the cases when the model has a dict of multiple templates
        if isinstance(own_template, dict) or isinstance(default_template, dict):
            using_default_dict = own_template is None
            template_dict = default_template if using_default_dict else own_template
            if chat_template is not None and chat_template in template_dict:
                # The user can pass the name of a template to the chat template argument instead of an entire template
                chat_template = template_dict[chat_template]
                using_default_template = using_default_dict
            elif chat_template is None and "default" in template_dict:
                chat_template = template_dict["default"]
                using_default_template = using_default_dict
            elif chat_template is None:
                raise ValueError(
                    "This model has multiple chat templates with no default specified! Please either pass a chat "
                    "template or the name of the template you wish to use to the `chat_template` argument. Available "
                    f"template names are {sorted(template_dict.keys())}."
                )
        elif chat_template is None:
            # These are the cases when the model has a single template
            # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template`
            if own_template is not None:
                chat_template = own_template
            else:
                chat_template = default_template
                using_default_template = True

        if chat_template is None:
            raise ValueError(
                "No chat template is available: pass `chat_template` or set `tokenizer.chat_template`."
            )

        if using_default_template:
            logger.warning_once(
                "No chat template is set for this tokenizer, falling back to a default class-level template. This is "
                "very error-prone, because models are often trained with templates different from the class default! "
                "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
                "point any code depending on them will stop working. We recommend setting a valid chat template before "
                "then to ensure that this model continues working without issues."
            )

        # Prepare tools/functions into schema
        functions_pydantic_to_render = []
        has_code_interpreter = False
        for tool in tools or []:
            tool_pydantic = Tool.model_validate(tool)
            if tool_pydantic.type == "function":
                functions_pydantic_to_render.append(tool_pydantic.function)
            else:
                has_code_interpreter = True
        system_prompt_to_use = CODE_INTERPRETER_SYSTEM_PROMPT if has_code_interpreter else SYSTEM_PROMPT
        leading_messages = [
            {"role": "system", "content": generate_schema_from_functions(functions_pydantic_to_render)},
            {"role": "system", "content": system_prompt_to_use},
        ]

        # Compilation function uses a cache to avoid recompiling the same template
        compiled_template = self._compile_jinja_template(chat_template)

        # Decide on batching BEFORE adding the system messages: inserting
        # them into the outer list first would make its first element a dict
        # and hide a batched input.
        if (
            isinstance(conversation, (list, tuple))
            and len(conversation) > 0
            and (isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages"))
        ):
            conversations = conversation
            is_batched = True
        else:
            conversations = [conversation]
            is_batched = False

        rendered = []
        template_kwargs = {**self.special_tokens_map, **kwargs}  # kwargs overwrite special tokens if both are present
        for chat in conversations:
            if hasattr(chat, "messages"):
                # Indicates it's a Conversation object
                chat = chat.messages
            # Build a new list so the caller's conversation is left untouched
            # and repeated calls do not accumulate system messages.
            messages = leading_messages + list(chat)
            rendered_chat = compiled_template.render(
                messages=messages, add_generation_prompt=add_generation_prompt, **template_kwargs
            )
            rendered.append(rendered_chat)

        if not is_batched:
            rendered = rendered[0]

        if not tokenize:
            return rendered

        out = self(
            rendered,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            add_special_tokens=False,
            return_tensors=return_tensors,
            **tokenizer_kwargs,
        )
        return out if return_dict else out["input_ids"]