File size: 3,576 Bytes
3cad23b
 
 
 
 
 
 
 
 
 
 
 
 
 
c07f594
 
 
 
 
 
 
 
 
 
 
 
 
3cad23b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c07f594
 
 
 
 
 
 
 
3cad23b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import datetime
import os
from typing import List
import warnings

from toolformers.base import Conversation, Toolformer, Tool
from camel.messages import BaseMessage
from camel.models import ModelFactory
from camel.types import ModelPlatformType, ModelType
from camel.messages import BaseMessage as bm
from camel.agents import ChatAgent
from camel.toolkits.function_tool import FunctionTool
from camel.configs.openai_config import ChatGPTConfig

from utils import register_cost

# Per-token API prices in USD, keyed by model name as produced by
# str(toolformer.model_type); e.g. 2.5e-6 USD/token == $2.50 per million tokens.
# NOTE(review): presumably mirrors OpenAI's published pricing — confirm the
# figures are current before relying on the registered cost totals.
COSTS = {
    'gpt-4o': {
        'prompt_tokens': 2.5e-6,
        'completion_tokens': 10e-6
    },
    'gpt-4o-mini': {
        'prompt_tokens': 0.15e-6,
        'completion_tokens': 0.6e-6
    }
}

class CamelConversation(Conversation):
    """A Conversation backed by a CAMEL ``ChatAgent``.

    Each call to :meth:`chat` forwards the message to the underlying agent
    and, when token-usage data is present in the response, registers the
    estimated API cost under this conversation's cost category.
    """

    def __init__(self, toolformer, agent, category=None):
        """
        Args:
            toolformer: The CamelToolformer that created this conversation
                (its ``model_type`` is used to look up per-token prices).
            agent: The CAMEL ChatAgent that handles the messages.
            category: Optional cost-accounting category forwarded to
                ``register_cost``.
        """
        self.toolformer = toolformer
        self.agent = agent
        self.category = category

    def chat(self, message, role='user', print_output=True):
        """Send a message to the agent and return its textual reply.

        Args:
            message: The message content to send.
            role: Either ``'user'`` or ``'assistant'``; selects how the
                message is wrapped for the agent.
            print_output: When True, also print the reply to stdout.

        Returns:
            The content string of the agent's reply.

        Raises:
            ValueError: If ``role`` is neither 'user' nor 'assistant'.
        """
        if role == 'user':
            formatted_message = BaseMessage.make_user_message('user', message)
        elif role == 'assistant':
            formatted_message = BaseMessage.make_assistant_message('assistant', message)
        else:
            raise ValueError('Role must be either "user" or "assistant".')

        response = self.agent.step(formatted_message)

        usage_data = response.info.get('usage', None)
        if usage_data is not None:
            model_key = str(self.toolformer.model_type)
            if model_key in COSTS:
                total = sum(
                    COSTS[model_key][cost_name] * usage_data[cost_name]
                    for cost_name in ('prompt_tokens', 'completion_tokens')
                )
                register_cost(self.category, total)
            else:
                # Unknown model: skip cost accounting instead of crashing
                # the conversation with a KeyError.
                warnings.warn(
                    f'No cost data available for model "{model_key}"; skipping cost registration.'
                )

        reply = response.msg.content

        if print_output:
            print(reply)

        return reply

class CamelToolformer(Toolformer):
    """Toolformer implementation built on top of the CAMEL framework.

    Stores a model platform/type/config triple and spawns a fresh CAMEL
    ``ChatAgent`` (wrapped in a :class:`CamelConversation`) per conversation.
    """

    def __init__(self, model_platform, model_type, model_config_dict, name=None):
        """Remember the model configuration used to create conversations.

        Args:
            model_platform: CAMEL ModelPlatformType of the backing model.
            model_type: CAMEL ModelType of the backing model.
            model_config_dict: Model configuration passed to ModelFactory.
            name: Optional display name; when omitted, a name is derived
                from the platform and model type.
        """
        self.model_platform = model_platform
        self.model_type = model_type
        self.model_config_dict = model_config_dict
        self._name = name

    @property
    def name(self):
        """Display name: the explicit name if given, else platform_model."""
        if self._name is not None:
            return self._name
        return f'{self.model_platform.value}_{self.model_type.value}'

    def new_conversation(self, prompt, tools: List[Tool], category=None) -> Conversation:
        """Create a new conversation with the given system prompt and tools.

        Args:
            prompt: System prompt installed as the agent's system message.
            tools: Tools exposed to the agent via OpenAI tool schemas.
            category: Optional cost-accounting category for the conversation.

        Returns:
            A CamelConversation wrapping a freshly created ChatAgent.
        """
        wrapped_tools = [
            FunctionTool(tool.call_tool_for_toolformer, openai_tool_schema=tool.as_openai_info())
            for tool in tools
        ]

        model = ModelFactory.create(
            model_platform=self.model_platform,
            model_type=self.model_type,
            model_config_dict=self.model_config_dict
        )

        agent = ChatAgent(
            model=model,
            system_message=bm.make_assistant_message('system', prompt),
            tools=wrapped_tools
        )

        return CamelConversation(self, agent, category)

def make_openai_toolformer(model_type_internal):
    """Create a CamelToolformer for an OpenAI chat model.

    Args:
        model_type_internal: Either 'gpt-4o' or 'gpt-4o-mini'.

    Returns:
        A CamelToolformer configured for the requested OpenAI model, named
        after ``model_type_internal``.

    Raises:
        ValueError: If ``model_type_internal`` is not a supported model name.
    """
    # Map the internal model names onto CAMEL model types; extending support
    # to another model only requires adding an entry here (and to COSTS).
    model_types = {
        'gpt-4o': ModelType.GPT_4O,
        'gpt-4o-mini': ModelType.GPT_4O_MINI,
    }
    if model_type_internal not in model_types:
        raise ValueError('Model type must be either "gpt-4o" or "gpt-4o-mini".')

    return CamelToolformer(
        model_platform=ModelPlatformType.OPENAI,
        model_type=model_types[model_type_internal],
        # Low temperature for more deterministic tool-calling behavior.
        model_config_dict=ChatGPTConfig(temperature=0.2).as_dict(),
        name=model_type_internal
    )