---
base_model: Qwen/Qwen2-1.5B-Instruct
datasets:
- devanshamin/gem-viggo-function-calling
library_name: peft
license: apache-2.0
pipeline_tag: text-generation
tags:
- trl
- sft
- generated_from_trainer
model-index:
- name: Qwen2-1.5B-Instruct-Function-Calling-v1
  results: []
---

<!-- This model card has been generated automatically according to the information the Trainer had access to. You
should probably proofread and complete it, then remove this comment. -->

# Qwen2-1.5B-Instruct-Function-Calling-v1

This model is a fine-tuned version of [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) on the [devanshamin/gem-viggo-function-calling](https://huggingface.co/datasets/devanshamin/gem-viggo-function-calling) dataset.

## Updated Chat Template
> Note: The template supports multiple tools, but the model was fine-tuned only on examples that contain a single tool.

- The chat template has been added to the [tokenizer_config.json](https://huggingface.co/devanshamin/Qwen2-1.5B-Instruct-Function-Calling-v1/blob/7ee7c020cefdb0101939469de608acc2afa7809e/tokenizer_config.json#L34).
- Supports prompts with and without tools.

```python
chat_template = (
  "{% for message in messages %}"
  "{% if loop.first and messages[0]['role'] != 'system' %}"
  "{% if tools %}"
  "<|im_start|>system\nYou are a helpful assistant with access to the following tools. Use them if required - \n"
  "```json\n{{ tools | tojson }}\n```<|im_end|>\n"
  "{% else %}"
  "<|im_start|>system\nYou are a helpful assistant.\n<|im_end|>\n"
  "{% endif %}"
  "{% endif %}"
  "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
  "{% endfor %}"
  "{% if add_generation_prompt %}"
  "{{ '<|im_start|>assistant\n' }}"
  "{% endif %}"
)
```
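
To make the template's branching easier to follow, here is a minimal plain-Python sketch of the same logic. The `render_chat` helper is hypothetical and is not part of this repository; it uses `json.dumps` as a stand-in for Jinja's `tojson` filter, so whitespace, key order, and escaping in the tool block may differ from what `tokenizer.apply_chat_template` produces.

```python
import json
from typing import Dict, List, Optional

FENCE = "`" * 3  # built indirectly so this snippet can sit inside a Markdown fence


def render_chat(
    messages: List[Dict[str, str]],
    tools: Optional[List[Dict]] = None,
    add_generation_prompt: bool = True,
) -> str:
    """Hypothetical plain-Python mirror of the chat template above."""
    parts = []
    # A default system turn is added only when the caller did not supply one.
    if messages and messages[0]["role"] != "system":
        if tools:
            parts.append(
                "<|im_start|>system\nYou are a helpful assistant with access "
                "to the following tools. Use them if required - \n"
                f"{FENCE}json\n{json.dumps(tools)}\n{FENCE}<|im_end|>\n"
            )
        else:
            parts.append("<|im_start|>system\nYou are a helpful assistant.\n<|im_end|>\n")
    # Every message becomes one <|im_start|>role ... <|im_end|> turn.
    for message in messages:
        parts.append(f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n")
    # The trailing assistant header cues the model to start its reply.
    if add_generation_prompt:
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)


print(render_chat([{"role": "user", "content": "What is the speed of light?"}]))
```

Note that a caller-supplied system message suppresses the default system turn entirely, including the tool list, so tools should be passed only when the first message is not a system message.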

## Basic Usage

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "devanshamin/Qwen2-1.5B-Instruct-Function-Calling-v1"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

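def inference(prompt: str) -> str:
  # Move inputs to wherever device_map="auto" placed the model instead of assuming CUDA
  model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
  # Passing the full encoding also supplies the attention mask
  generated_ids = model.generate(**model_inputs, max_new_tokens=512)
  # Drop the prompt tokens so only the newly generated text is decoded
  generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]
  response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
  return response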

messages = [{"role": "user", "content": "What is the speed of light?"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
response = inference(prompt)
print(response)
```

## Tool Usage

### Basic

```python
import json
from typing import Dict, List, Optional

# Optional[...] keeps the annotation valid on Python versions before 3.10
def get_prompt(user_input: str, tools: Optional[List[Dict]] = None) -> str:
  prompt = 'Extract the information from the following - \n{}'.format(user_input)
  messages = [{"role": "user", "content": prompt}]
  prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    tools=tools
  )
  return prompt

tool = {
  "type": "function",
  "function": {
    "name": "get_company_info",
    "description": "Correctly extracted company information with all the required parameters with correct types",
    "parameters": {
      "properties": {
        "name": {"title": "Name", "type": "string"},
        "investors": {
          "items": {"type": "string"},
          "title": "Investors",
          "type": "array"
        },
        "valuation": {"title": "Valuation", "type": "string"},
        "source": {"title": "Source", "type": "string"}
      },
      "required": ["investors", "name", "source", "valuation"],
      "type": "object"
    }
  }
}
input_text = "Founded in 2021, Pluto raised $4 million across multiple seed funding rounds, valuing the company at $12 million (pre-money), according to PitchBook. The startup was backed by investors including Switch Ventures, Caffeinated Capital and Maxime Seguineau."
prompt = get_prompt(input_text, tools=[tool])
response = inference(prompt)
print(response)
# ```json
# {
#   "name": "get_company_info",
#   "arguments": {
#     "name": "Pluto",
#     "investors": [
#       "Switch Ventures",
#       "Caffeinated Capital",
#       "Maxime Seguineau"
#     ],
#     "valuation": "$12 million",
#     "source": "PitchBook"
#   }
# }
# ```
```

### Advanced
```python
import re
from enum import Enum

from pydantic import BaseModel, Field # pip install pydantic
from instructor.function_calls import openai_schema # pip install instructor

# Define functions using pydantic classes
class PaperCategory(str, Enum):
  TYPE_1_DIABETES = 'Type 1 Diabetes'
  TYPE_2_DIABETES = 'Type 2 Diabetes'

class Classification(BaseModel):
  label: PaperCategory = Field(..., description='Provide the most likely category')
  reason: str = Field(..., description='Give a detailed explanation with quotes from the abstract explaining why the paper is related to the chosen label.')

function_definition = openai_schema(Classification).openai_schema
tool = dict(type='function', function=function_definition)
input_text = "1,25-dihydroxyvitamin D(3) (1,25(OH)(2)D(3)), the biologically active form of vitamin D, is widely recognized as a modulator of the immune system as well as a regulator of mineral metabolism. The objective of this study was to determine the effects of vitamin D status and treatment with 1,25(OH)(2)D(3) on diabetes onset in non-obese diabetic (NOD) mice, a murine model of human type I diabetes. We have found that vitamin D-deficiency increases the incidence of diabetes in female mice from 46% (n=13) to 88% (n=8) and from 0% (n=10) to 44% (n=9) in male mice as of 200 days of age when compared to vitamin D-sufficient animals. Addition of 50 ng of 1,25(OH)(2)D(3)/day to the diet prevented disease onset as of 200 days and caused a significant rise in serum calcium levels, regardless of gender or vitamin D status. Our results indicate that vitamin D status is a determining factor of disease susceptibility and oral administration of 1,25(OH)(2)D(3) prevents diabetes onset in NOD mice through 200 days of age."
prompt = get_prompt(input_text, tools=[tool])
output = inference(prompt)
print(output)
# ```json
# {
#     "name": "Classification", 
#     "arguments": {
#         "label": "Type 1 Diabetes", 
#         "reason": "The study investigated the effect of vitamin D status and treatment with 1,25(OH)(2)D(3) on diabetes onset in non-obese diabetic (NOD) mice. It also concluded that vitamin D deficiency leads to an increase in diabetes incidence and that the addition of 1,25(OH)(2)D(3) can prevent diabetes onset in NOD mice."
#     }
# }
# ```
# Extract the JSON string using regex (re.DOTALL lets '.' span the multi-line JSON body)
json_str = re.search(r'```json\s*(\{.*?\})\s*```', output, re.DOTALL).group(1)
output = Classification(**json.loads(json_str)['arguments'])
print(output)
# Classification(label=<PaperCategory.TYPE_1_DIABETES: 'Type 1 Diabetes'>, reason='The study investigated the effect of vitamin D status and treatment with 1,25(OH)(2)D(3) on diabetes onset in non-obese diabetic (NOD) mice. It also concluded that vitamin D deficiency leads to an increase in diabetes incidence and that the addition of 1,25(OH)(2)D(3) can prevent diabetes onset in NOD mice.')
```
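
The regex extraction above assumes the model always returns a well-formed fenced JSON block. A more defensive variant might look like the sketch below, which uses only the standard library. `parse_tool_call` is a hypothetical helper, not something shipped with the model, and it simply returns `None` when no usable call is found.

```python
import json
import re
from typing import Any, Dict, Optional

FENCE = "`" * 3  # avoids a literal fence inside this Markdown code block
TOOL_CALL_RE = re.compile(FENCE + r"json\s*(\{.*?\})\s*" + FENCE, re.DOTALL)


def parse_tool_call(text: str) -> Optional[Dict[str, Any]]:
    """Return the decoded tool call, or None if the text has no valid one."""
    match = TOOL_CALL_RE.search(text)
    if match is None:
        return None  # the model answered in plain text
    try:
        call = json.loads(match.group(1))
    except json.JSONDecodeError:
        return None  # fenced block present but not valid JSON
    # A usable call needs both the function name and its arguments.
    if not isinstance(call, dict) or "name" not in call or "arguments" not in call:
        return None
    return call


sample = (
    FENCE
    + 'json\n{"name": "Classification", "arguments": {"label": "Type 1 Diabetes"}}\n'
    + FENCE
)
print(parse_tool_call(sample))
print(parse_tool_call("I cannot help with that."))
```

Returning `None` instead of raising lets the caller decide whether to retry generation or fall back to treating the response as ordinary text.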

## Training procedure

### Training hyperparameters

The following hyperparameters were used during training:
- learning_rate: 0.0001
- train_batch_size: 4
- eval_batch_size: 4
- seed: 42
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: cosine
- lr_scheduler_warmup_steps: 10
- training_steps: 200
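
As a rough illustration of how these settings interact, the sketch below computes the learning rate at a given step. It assumes the usual cosine schedule with linear warmup (a single half-cosine decay to zero); the card does not state the exact scheduler implementation, so treat `lr_at` as a hypothetical approximation.

```python
import math

BASE_LR = 1e-4      # learning_rate from the list above
WARMUP_STEPS = 10   # lr_scheduler_warmup_steps
TOTAL_STEPS = 200   # training_steps


def lr_at(step: int) -> float:
    """Approximate learning rate at `step` (assumed linear warmup + cosine decay)."""
    if step < WARMUP_STEPS:
        # Linear ramp from 0 up to BASE_LR over the warmup steps
        return BASE_LR * step / max(1, WARMUP_STEPS)
    # Fraction of the post-warmup schedule that has elapsed, in [0, 1]
    progress = (step - WARMUP_STEPS) / max(1, TOTAL_STEPS - WARMUP_STEPS)
    # Half-cosine decay from BASE_LR down to 0
    return BASE_LR * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


for step in (0, 5, 10, 105, 200):
    print(step, f"{lr_at(step):.2e}")
```

Under these assumptions the rate peaks at step 10, falls to half its peak at step 105, and reaches zero at step 200.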

### Training results

| Training Loss | Epoch  | Step | Validation Loss |
|:-------------:|:------:|:----:|:---------------:|
| 0.4004        | 0.0101 | 20   | 0.4852          |
| 0.3624        | 0.0201 | 40   | 0.3221          |
| 0.2855        | 0.0302 | 60   | 0.2818          |
| 0.2652        | 0.0402 | 80   | 0.2592          |
| 0.2214        | 0.0503 | 100  | 0.2463          |
| 0.2471        | 0.0603 | 120  | 0.2358          |
| 0.2122        | 0.0704 | 140  | 0.2310          |
| 0.2048        | 0.0804 | 160  | 0.2275          |
| 0.2406        | 0.0905 | 180  | 0.2251          |
| 0.2445        | 0.1006 | 200  | 0.2248          |


### Framework versions

```text
peft==0.11.1
transformers==4.42.3
torch==2.3.1+cu121
datasets==2.20.0
tokenizers==0.19.1
```