---
license: apache-2.0
language:
- en
library_name: transformers
pipeline_tag: text-generation
tags:
- Hare
datasets:
- cerebras/SlimPajama-627B
- HuggingFaceTB/cosmopedia
arxiv: 2406.11410v1
---

<a id="english"></a>

<p align="center">
<img width="400px" alt="Lite-AI" src="./logo.jpg">
</p>



## Hare-1.1B-Tool
- Hare-1.1B-Tool is a fine-tuned version of [Hare-1.1B-base](https://huggingface.co/LiteAI/Hare-1.1B-base), designed to enable Android system API invocation and tool calling in composite scenarios on mobile devices. For a detailed introduction, please refer to [Hare-1.1B-base](https://huggingface.co/LiteAI/Hare-1.1B-base).

## Model Usage
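The fine-tuned chat format wraps each turn in `<round_start>`/`<round_end>` special tokens, and a generated API call is terminated by `<api_end>` (see the `chat` helper and stopping criteria below). As a quick orientation, here is a minimal sketch of the single-turn prompt layout assembled by that helper; the user request is a hypothetical example:

```python
# Minimal sketch of the single-turn prompt layout built by the chat() helper below.
# The user request is a hypothetical example.
system = "You are a helpful assistant."
user_query = "Turn on Bluetooth and set the screen brightness to 50%."
prompt = (
    "<round_start>system\n{}<round_end>\n".format(system)
    + "<round_start>user\n{}<round_end>\n<round_start>assistant\n".format(user_query)
)
# The model is expected to continue the assistant turn with the API call,
# terminating it with the <api_end> token.
```

The full example below loads the model, defines the `chat` helper, and evaluates it over a directory of test queries.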
```python
import os
import json
import logging

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    GenerationConfig,
    StoppingCriteria,
    StoppingCriteriaList,
)

# Write run logs to a file of your choice.
log_path = 'your_log_path'
logging.basicConfig(filename=log_path, level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Load the fine-tuned model and tokenizer from the Hugging Face Hub (or a local copy).
model_path = "LiteAI/Hare-1.1B-Tool"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

class MyStoppingCriteria(StoppingCriteria):
    """Stop generation as soon as the model emits the <api_end> token."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        last_token = tokenizer.decode(input_ids[0][-1])
        return last_token in ["<api_end>"]


def chat(
    messages,
    model,
    tokenizer,
    generate_config=None,
    max_length=512,
    max_new_tokens=256,
):
    if generate_config is None:
        generate_config = GenerationConfig(
            do_sample=False,
            max_length=max_length,
            max_new_tokens=max_new_tokens,
            eos_token_id=32001,  # end-of-turn token id used by the chat format
        )

    # Split off the system prompt (if any) from the conversation history.
    if messages[0]["role"] == "system":
        system = messages[0]["content"]
        messages = messages[1:]
    else:
        system = "You are a helpful assistant."
    
    # Budget the prompt: reserve space for the system prompt and the current query first.
    n_token = max_length
    system = "<round_start>system\n{}<round_end>\n".format(system)
    system_token = tokenizer.encode(system, add_special_tokens=False)
    n_token -= len(system_token)

    query = messages[-1]["content"]
    query = "<round_start>user\n{}<round_end>\n<round_start>assistant\n".format(query)
    query_token = tokenizer.encode(query, add_special_tokens=False)
    n_token -= len(query_token)
    
    # Walk the earlier history backwards in (user, assistant) pairs and keep as many
    # full rounds as still fit in the remaining token budget.
    messages = messages[:-1]
    conversations = []
    for ids in range(len(messages) - 1, 0, -2):
        user = messages[ids - 1]["content"]
        assistant = messages[ids]["content"]

        turn = "<round_start>user\n{}<round_end>\n<round_start>assistant\n{}<round_end>\n".format(user, assistant)
        turn_token = tokenizer.encode(turn, add_special_tokens=False)

        if n_token - len(turn_token) > 0:
            n_token -= len(turn_token)
            conversations = [turn] + conversations
        else:
            break

    # Assemble the final prompt and move the input tensors to the model device.
    prompt = system + "".join(conversations) + query
    prompt_token = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
    prompt_token = prompt_token.to(model.device)

    # Generate, stopping early once the <api_end> token has been emitted.
    response = model.generate(
        generation_config=generate_config,
        stopping_criteria=StoppingCriteriaList([MyStoppingCriteria()]),
        **prompt_token
    )

    # Keep only the newly generated tokens and strip the chat-format markers.
    output_tokens = response[0].cpu().numpy()[prompt_token.input_ids.size()[1]:]
    output_string = tokenizer.decode(output_tokens, skip_special_tokens=True).replace("<round_end>", "")
    return output_string, prompt

# ======================
#       main
# ======================

# Each file in this directory is expected to contain one JSON object per line,
# with a "human" field (the query) and an "assistant" field (the reference answer).
test_query_path = "your_test_query_dir"
for file in os.listdir(test_query_path):
    file_pth = os.path.join(test_query_path, file)
    print(file_pth)
    logging.info(file_pth)
    with open(file_pth, 'r') as f:
        for line in f:
            data = json.loads(line)
            query = data["human"]
            messages = [
                {"role": "system", "content": "Below is the query from the users, you need make full sense of user's intention based on the content of the sentence, then call the correct function and generate the parameters of the calling function."},
                {"role": "user", "content": query}
            ]
            response, input_prompt = chat(messages=messages, model=model, tokenizer=tokenizer)

            # Log the query, the reference answer, and the model's prediction.
            logging.info(query)
            logging.info(data["assistant"])
            logging.info(response)
```
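
For a single ad-hoc query, the `chat` helper can also be called directly. The snippet below reuses the same system prompt as the batch loop above; the user query is a hypothetical example:

```python
messages = [
    {"role": "system", "content": "Below is the query from the users, you need make full sense of user's intention based on the content of the sentence, then call the correct function and generate the parameters of the calling function."},
    {"role": "user", "content": "Set an alarm for 7:30 tomorrow morning."},  # hypothetical query
]
response, input_prompt = chat(messages=messages, model=model, tokenizer=tokenizer)
print(response)
```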