ff_li
目录调整
f67d239
raw
history blame contribute delete
No virus
3.37 kB
from abc import ABC, abstractmethod
from typing import Dict, Iterator, List, Optional, Union
from agent.log import logger
from agent.utils.utils import print_traceback
class FnCallNotImplError(NotImplementedError):
    """Signal that a chat model does not implement function calling.

    Subclasses ``NotImplementedError`` so generic handlers still catch it,
    while callers that probe for function-calling support can tell it apart
    from other failures.
    """
class BaseChatModel(ABC):
    """Abstract base class for chat-style language model backends.

    Subclasses must implement ``_chat_stream`` and ``_chat_no_stream``. They
    may also override ``chat_with_functions``; ``support_function_calling``
    probes that method once and caches the outcome.
    """

    def __init__(self):
        # Tri-state cache: None means "not probed yet"; True/False is the
        # result recorded by ``support_function_calling``.
        self._support_fn_call: Optional[bool] = None

    def chat(
        self,
        prompt: Optional[str] = None,
        messages: Optional[List[Dict]] = None,
        stop: Optional[List[str]] = None,
        stream: bool = False,
    ) -> Union[str, Iterator[str]]:
        """Send a prompt or a message list to the model.

        Exactly one of ``prompt`` and ``messages`` must be supplied. A bare
        ``prompt`` is wrapped into a single ``user`` message.

        It is okay to use the same code to handle the output regardless of
        whether ``stream`` is True or False, as follows::

            for chunk in chat_model.chat(..., stream=True/False):
                response += chunk
                yield response

        Args:
            prompt: Text of a single user message; only valid when
                ``messages`` is None.
            messages: Message dicts forwarded unchanged to the backend; only
                valid when ``prompt`` is None.
            stop: Optional stop strings forwarded to the backend.
            stream: If True, return the result of ``_chat_stream``;
                otherwise return the result of ``_chat_no_stream``.

        Returns:
            Whatever the selected backend hook returns: an iterator of
            strings when streaming, a string otherwise.

        Raises:
            AssertionError: If ``messages`` is None and ``prompt`` is not a
                ``str``, or if both ``prompt`` and ``messages`` are given.
        """
        # Raised explicitly instead of via ``assert`` so the validation is
        # not stripped under ``python -O``. The AssertionError type is kept
        # so callers observe the same exception as before.
        if messages is None:
            if not isinstance(prompt, str):
                raise AssertionError(
                    'prompt must be a str when messages is not given.')
            messages = [{'role': 'user', 'content': prompt}]
        elif prompt is not None:
            raise AssertionError(
                'Do not pass prompt and messages at the same time.')

        logger.debug(messages)
        if stream:
            return self._chat_stream(messages, stop=stop)
        return self._chat_no_stream(messages, stop=stop)

    def support_function_calling(self) -> bool:
        """Report whether this model handles function calling.

        On the first call a fixed weather-lookup request is sent through
        ``chat_with_functions``. Support is recorded only if the response
        carries a truthy ``function_call`` entry. The outcome is cached on
        the instance, so later calls do not probe again.

        Returns:
            True if the probe detected function calling, False otherwise
            (including when the probe raised).
        """
        if self._support_fn_call is not None:
            return self._support_fn_call

        functions = [{
            'name': 'get_current_weather',
            'description': 'Get the current weather in a given location.',
            'parameters': {
                'type': 'object',
                'properties': {
                    'location': {
                        'type': 'string',
                        'description':
                            'The city and state, e.g. San Francisco, CA',
                    },
                    'unit': {
                        'type': 'string',
                        'enum': ['celsius', 'fahrenheit'],
                    },
                },
                'required': ['location'],
            },
        }]
        messages = [{
            'role': 'user',
            'content': 'What is the weather like in Boston?',
        }]

        # Record the pessimistic answer before probing so that a failing
        # probe is not repeated on every call.
        self._support_fn_call = False
        try:
            response = self.chat_with_functions(messages=messages,
                                                functions=functions)
            if response.get('function_call', None):
                logger.info('Support of function calling is detected.')
                self._support_fn_call = True
        except FnCallNotImplError:
            # The backend explicitly does not implement function calling.
            pass
        except Exception:  # TODO: more specific
            # Deliberately best-effort: any other probe failure is handed to
            # print_traceback() and treated as "not supported".
            print_traceback()
        return self._support_fn_call

    def chat_with_functions(self,
                            messages: List[Dict],
                            functions: Optional[List[Dict]] = None) -> Dict:
        """Chat with function-calling support; override in capable backends.

        Args:
            messages: Conversation messages.
            functions: Optional function descriptions offered to the model.

        Returns:
            A response dict. ``support_function_calling`` looks for a
            ``function_call`` key in it.

        Raises:
            FnCallNotImplError: Always, in this base implementation.
        """
        raise FnCallNotImplError

    @abstractmethod
    def _chat_stream(
        self,
        messages: List[Dict],
        stop: Optional[List[str]] = None,
    ) -> Iterator[str]:
        """Backend hook used by ``chat`` when ``stream`` is True."""
        raise NotImplementedError

    @abstractmethod
    def _chat_no_stream(
        self,
        messages: List[Dict],
        stop: Optional[List[str]] = None,
    ) -> str:
        """Backend hook used by ``chat`` when ``stream`` is False."""
        raise NotImplementedError