import importlib.util |
|
import json |
|
import os |
|
import time |
|
from dataclasses import dataclass |
|
from typing import Dict |
|
|
|
import requests |
|
from huggingface_hub import HfFolder, hf_hub_download, list_spaces |
|
|
|
from ..models.auto import AutoTokenizer |
|
from ..utils import is_offline_mode, is_openai_available, is_torch_available, logging |
|
from .base import TASK_MAPPING, TOOL_CONFIG_FILE, Tool, load_tool, supports_remote |
|
from .prompts import CHAT_MESSAGE_PROMPT, download_prompt |
|
from .python_interpreter import evaluate |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
if is_openai_available(): |
|
import openai |
|
|
|
if is_torch_available(): |
|
from ..generation import StoppingCriteria, StoppingCriteriaList |
|
from ..models.auto import AutoModelForCausalLM |
|
else: |
|
    # Placeholder parent class so `StopSequenceCriteria` below can still be defined without torch.
    StoppingCriteria = object
|
|
|
# Module-level flag ensuring the default tools are only initialized once per session.
_tools_are_initialized = False
|
|
|
|
|
# Python builtins that agent-generated code is allowed to call directly.
BASE_PYTHON_TOOLS = {
|
"print": print, |
|
"range": range, |
|
"float": float, |
|
"int": int, |
|
"bool": bool, |
|
"str": str, |
|
} |
|
|
|
|
|
@dataclass
class PreTool:
    """Lightweight descriptor of a tool that has not been instantiated yet."""

    task: str
|
description: str |
|
repo_id: str |
|
|
|
|
|
HUGGINGFACE_DEFAULT_TOOLS = {} |
|
|
|
|
|
HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB = [ |
|
"image-transformation", |
|
"text-download", |
|
"text-to-image", |
|
"text-to-video", |
|
] |
|
|
|
|
|
def get_remote_tools(organization="huggingface-tools"):
    """Fetches the tool configs of all Spaces under `organization` and returns them as a name-to-`PreTool` dict."""
|
if is_offline_mode(): |
|
logger.info("You are in offline mode, so remote tools are not available.") |
|
return {} |
|
|
|
spaces = list_spaces(author=organization) |
|
tools = {} |
|
for space_info in spaces: |
|
repo_id = space_info.id |
|
resolved_config_file = hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space") |
|
with open(resolved_config_file, encoding="utf-8") as reader: |
|
config = json.load(reader) |
|
|
|
task = repo_id.split("/")[-1] |
|
tools[config["name"]] = PreTool(task=task, description=config["description"], repo_id=repo_id) |
|
|
|
return tools |
|
|
|
|
|
def _setup_default_tools():
    """Populates `HUGGINGFACE_DEFAULT_TOOLS` with the local tools and the curated Hub tools (only runs once)."""
|
global HUGGINGFACE_DEFAULT_TOOLS |
|
global _tools_are_initialized |
|
|
|
if _tools_are_initialized: |
|
return |
|
|
|
main_module = importlib.import_module("transformers") |
|
tools_module = main_module.tools |
|
|
|
remote_tools = get_remote_tools() |
|
for task_name, tool_class_name in TASK_MAPPING.items(): |
|
tool_class = getattr(tools_module, tool_class_name) |
|
description = tool_class.description |
|
HUGGINGFACE_DEFAULT_TOOLS[tool_class.name] = PreTool(task=task_name, description=description, repo_id=None) |
|
|
|
if not is_offline_mode(): |
|
for task_name in HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB: |
|
found = False |
|
for tool_name, tool in remote_tools.items(): |
|
if tool.task == task_name: |
|
HUGGINGFACE_DEFAULT_TOOLS[tool_name] = tool |
|
found = True |
|
break |
|
|
|
if not found: |
|
raise ValueError(f"{task_name} is not implemented on the Hub.") |
|
|
|
_tools_are_initialized = True |
|
|
|
|
|
def resolve_tools(code, toolbox, remote=False, cached_tools=None):
    """Instantiates the tools from `toolbox` that are mentioned in `code`, reusing `cached_tools` when available."""
|
if cached_tools is None: |
|
resolved_tools = BASE_PYTHON_TOOLS.copy() |
|
else: |
|
resolved_tools = cached_tools |
|
for name, tool in toolbox.items(): |
|
if name not in code or name in resolved_tools: |
|
continue |
|
|
|
if isinstance(tool, Tool): |
|
resolved_tools[name] = tool |
|
else: |
|
task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id |
|
_remote = remote and supports_remote(task_or_repo_id) |
|
resolved_tools[name] = load_tool(task_or_repo_id, remote=_remote) |
|
|
|
return resolved_tools |
|
|
|
|
|
def get_tool_creation_code(code, toolbox, remote=False):
    """Builds the `load_tool` lines a user would need in order to run `code` outside of the agent."""
|
code_lines = ["from transformers import load_tool", ""] |
|
for name, tool in toolbox.items(): |
|
if name not in code or isinstance(tool, Tool): |
|
continue |
|
|
|
task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id |
|
line = f'{name} = load_tool("{task_or_repo_id}"' |
|
if remote: |
|
line += ", remote=True" |
|
line += ")" |
|
code_lines.append(line) |
|
|
|
return "\n".join(code_lines) + "\n" |
|
|
|
|
|
def clean_code_for_chat(result):
    """Splits a chat answer from the LLM into its explanation and the code in the first fenced block (or `None`)."""
|
lines = result.split("\n") |
|
idx = 0 |
|
while idx < len(lines) and not lines[idx].lstrip().startswith("```"): |
|
idx += 1 |
|
explanation = "\n".join(lines[:idx]).strip() |
|
if idx == len(lines): |
|
return explanation, None |
|
|
|
idx += 1 |
|
start_idx = idx |
|
    while idx < len(lines) and not lines[idx].lstrip().startswith("```"):
|
idx += 1 |
|
code = "\n".join(lines[start_idx:idx]).strip() |
|
|
|
return explanation, code |
|
|
|
|
|
def clean_code_for_run(result):
    """Splits a `run` answer from the LLM into its explanation and its code, stripping the fences around the code."""
    result = f"I will use the following {result}"
    explanation, code = result.split("Answer:", 1)
|
explanation = explanation.strip() |
|
code = code.strip() |
|
|
|
code_lines = code.split("\n") |
|
if code_lines[0] in ["```", "```py", "```python"]: |
|
code_lines = code_lines[1:] |
|
if code_lines[-1] == "```": |
|
code_lines = code_lines[:-1] |
|
code = "\n".join(code_lines) |
|
|
|
return explanation, code |
|
|
|
|
|
class Agent: |
|
""" |
|
Base class for all agents which contains the main API methods. |
|
|
|
Args: |
|
        chat_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). In that case, the prompt should be in a
            file named `chat_prompt_template.txt` in the repo.
        run_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `run` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). In that case, the prompt should be in a
            file named `run_prompt_template.txt` in the repo.
|
additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*): |
|
Any additional tools to include on top of the default ones. If you pass along a tool with the same name as |
|
one of the default tools, that default tool will be overridden. |
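
    Example (a sketch of passing `additional_tools`; assumes network access to the Hub and to the endpoint below):

    ```py
    from transformers import HfAgent, load_tool

    # Pass an explicitly loaded copy of the text-to-speech tool on top of the defaults.
    tool = load_tool("text-to-speech")
    agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder", additional_tools=[tool])
    ```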
|
""" |
|
|
|
def __init__(self, chat_prompt_template=None, run_prompt_template=None, additional_tools=None): |
|
_setup_default_tools() |
|
|
|
agent_name = self.__class__.__name__ |
|
self.chat_prompt_template = download_prompt(chat_prompt_template, agent_name, mode="chat") |
|
self.run_prompt_template = download_prompt(run_prompt_template, agent_name, mode="run") |
|
self._toolbox = HUGGINGFACE_DEFAULT_TOOLS.copy() |
|
self.log = print |
|
if additional_tools is not None: |
|
if isinstance(additional_tools, (list, tuple)): |
|
additional_tools = {t.name: t for t in additional_tools} |
|
elif not isinstance(additional_tools, dict): |
|
additional_tools = {additional_tools.name: additional_tools} |
|
|
|
replacements = {name: tool for name, tool in additional_tools.items() if name in HUGGINGFACE_DEFAULT_TOOLS} |
|
self._toolbox.update(additional_tools) |
|
if len(replacements) > 1: |
|
names = "\n".join([f"- {n}: {t}" for n, t in replacements.items()]) |
|
logger.warning( |
|
f"The following tools have been replaced by the ones provided in `additional_tools`:\n{names}." |
|
) |
|
elif len(replacements) == 1: |
|
name = list(replacements.keys())[0] |
|
logger.warning(f"{name} has been replaced by {replacements[name]} as provided in `additional_tools`.") |
|
|
|
self.prepare_for_new_chat() |
|
|
|
@property |
|
def toolbox(self) -> Dict[str, Tool]: |
|
"""Get all tool currently available to the agent""" |
|
return self._toolbox |
|
|
|
    def format_prompt(self, task, chat_mode=False):
        """Fills the chat or run prompt template with the tool descriptions and the user's task."""
|
description = "\n".join([f"- {name}: {tool.description}" for name, tool in self.toolbox.items()]) |
|
if chat_mode: |
|
if self.chat_history is None: |
|
prompt = self.chat_prompt_template.replace("<<all_tools>>", description) |
|
else: |
|
prompt = self.chat_history |
|
prompt += CHAT_MESSAGE_PROMPT.replace("<<task>>", task) |
|
else: |
|
prompt = self.run_prompt_template.replace("<<all_tools>>", description) |
|
prompt = prompt.replace("<<prompt>>", task) |
|
return prompt |
|
|
|
def set_stream(self, streamer): |
|
""" |
|
        Set the function used to stream results (which is `print` by default).
|
|
|
Args: |
|
streamer (`callable`): The function to call when streaming results from the LLM. |
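
        Example (a minimal sketch, assuming an `agent` has already been instantiated):

        ```py
        # Collect the agent's logs in a list instead of printing them.
        logs = []
        agent.set_stream(logs.append)
        ```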
|
""" |
|
self.log = streamer |
|
|
|
def chat(self, task, *, return_code=False, remote=False, **kwargs): |
|
""" |
|
        Sends a new request to the agent in a chat. Will use the previous requests in its history.
|
|
|
Args: |
|
task (`str`): The task to perform |
|
return_code (`bool`, *optional*, defaults to `False`): |
|
Whether to just return code and not evaluate it. |
|
remote (`bool`, *optional*, defaults to `False`): |
|
Whether or not to use remote tools (inference endpoints) instead of local ones. |
|
kwargs (additional keyword arguments, *optional*): |
|
Any keyword argument to send to the agent when evaluating the code. |
|
|
|
Example: |
|
|
|
```py |
|
from transformers import HfAgent |
|
|
|
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder") |
|
agent.chat("Draw me a picture of rivers and lakes") |
|
|
|
agent.chat("Transform the picture so that there is a rock in there") |
|
``` |
|
""" |
|
prompt = self.format_prompt(task, chat_mode=True) |
|
result = self.generate_one(prompt, stop=["Human:", "====="]) |
|
self.chat_history = prompt + result.strip() + "\n" |
|
explanation, code = clean_code_for_chat(result) |
|
|
|
self.log(f"==Explanation from the agent==\n{explanation}") |
|
|
|
if code is not None: |
|
self.log(f"\n\n==Code generated by the agent==\n{code}") |
|
if not return_code: |
|
self.log("\n\n==Result==") |
|
self.cached_tools = resolve_tools(code, self.toolbox, remote=remote, cached_tools=self.cached_tools) |
|
self.chat_state.update(kwargs) |
|
return evaluate(code, self.cached_tools, self.chat_state, chat_mode=True) |
|
else: |
|
tool_code = get_tool_creation_code(code, self.toolbox, remote=remote) |
|
return f"{tool_code}\n{code}" |
|
|
|
def prepare_for_new_chat(self): |
|
""" |
|
Clears the history of prior calls to [`~Agent.chat`]. |
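
        Example (a sketch, assuming an `agent` has already been instantiated):

        ```py
        agent.chat("Draw me a picture of rivers and lakes")
        # Start over: the next `chat` call will not see the request above.
        agent.prepare_for_new_chat()
        agent.chat("Draw me a picture of rivers and lakes")
        ```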
|
""" |
|
self.chat_history = None |
|
self.chat_state = {} |
|
self.cached_tools = None |
|
|
|
def run(self, task, *, return_code=False, remote=False, **kwargs): |
|
""" |
|
Sends a request to the agent. |
|
|
|
Args: |
|
task (`str`): The task to perform |
|
return_code (`bool`, *optional*, defaults to `False`): |
|
Whether to just return code and not evaluate it. |
|
remote (`bool`, *optional*, defaults to `False`): |
|
Whether or not to use remote tools (inference endpoints) instead of local ones. |
|
kwargs (additional keyword arguments, *optional*): |
|
Any keyword argument to send to the agent when evaluating the code. |
|
|
|
Example: |
|
|
|
```py |
|
from transformers import HfAgent |
|
|
|
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder") |
|
agent.run("Draw me a picture of rivers and lakes") |
|
``` |
|
""" |
|
prompt = self.format_prompt(task) |
|
result = self.generate_one(prompt, stop=["Task:"]) |
|
explanation, code = clean_code_for_run(result) |
|
|
|
self.log(f"==Explanation from the agent==\n{explanation}") |
|
|
|
self.log(f"\n\n==Code generated by the agent==\n{code}") |
|
if not return_code: |
|
self.log("\n\n==Result==") |
|
self.cached_tools = resolve_tools(code, self.toolbox, remote=remote, cached_tools=self.cached_tools) |
|
return evaluate(code, self.cached_tools, state=kwargs.copy()) |
|
else: |
|
tool_code = get_tool_creation_code(code, self.toolbox, remote=remote) |
|
return f"{tool_code}\n{code}" |
|
|
|
    def generate_one(self, prompt, stop):
        # To implement in child classes: generate one completion of `prompt`, halting at any of the `stop` sequences.
        raise NotImplementedError
|
|
|
    def generate_many(self, prompts, stop):
        # Child classes can override this method to benefit from batched generation.
        return [self.generate_one(prompt, stop) for prompt in prompts]
|
|
|
|
|
class OpenAiAgent(Agent): |
|
""" |
|
    Agent that uses the OpenAI API to generate code.
|
|
|
<Tip warning={true}> |
|
|
|
    The OpenAI models are used in generation mode, so even for the `chat()` API, it's better to use completion models
    like `"text-davinci-003"` over the ChatGPT variants. Proper support for ChatGPT models will come in a future
    version.
|
|
|
</Tip> |
|
|
|
Args: |
|
model (`str`, *optional*, defaults to `"text-davinci-003"`): |
|
The name of the OpenAI model to use. |
|
api_key (`str`, *optional*): |
|
The API key to use. If unset, will look for the environment variable `"OPENAI_API_KEY"`. |
|
        chat_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). In that case, the prompt should be in a
            file named `chat_prompt_template.txt` in the repo.
        run_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `run` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). In that case, the prompt should be in a
            file named `run_prompt_template.txt` in the repo.
|
additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*): |
|
Any additional tools to include on top of the default ones. If you pass along a tool with the same name as |
|
one of the default tools, that default tool will be overridden. |
|
|
|
Example: |
|
|
|
```py |
|
from transformers import OpenAiAgent |
|
|
|
agent = OpenAiAgent(model="text-davinci-003", api_key=xxx) |
|
agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!") |
|
``` |
|
""" |
|
|
|
def __init__( |
|
self, |
|
model="text-davinci-003", |
|
api_key=None, |
|
chat_prompt_template=None, |
|
run_prompt_template=None, |
|
additional_tools=None, |
|
): |
|
if not is_openai_available(): |
|
raise ImportError("Using `OpenAiAgent` requires `openai`: `pip install openai`.") |
|
|
|
if api_key is None: |
|
api_key = os.environ.get("OPENAI_API_KEY", None) |
|
if api_key is None: |
|
raise ValueError( |
|
"You need an openai key to use `OpenAIAgent`. You can get one here: Get one here " |
|
"https://openai.com/api/`. If you have one, set it in your env with `os.environ['OPENAI_API_KEY'] = " |
|
"xxx." |
|
) |
|
else: |
|
openai.api_key = api_key |
|
self.model = model |
|
super().__init__( |
|
chat_prompt_template=chat_prompt_template, |
|
run_prompt_template=run_prompt_template, |
|
additional_tools=additional_tools, |
|
) |
|
|
|
def generate_many(self, prompts, stop): |
|
if "gpt" in self.model: |
|
return [self._chat_generate(prompt, stop) for prompt in prompts] |
|
else: |
|
return self._completion_generate(prompts, stop) |
|
|
|
def generate_one(self, prompt, stop): |
|
if "gpt" in self.model: |
|
return self._chat_generate(prompt, stop) |
|
else: |
|
return self._completion_generate([prompt], stop)[0] |
|
|
|
def _chat_generate(self, prompt, stop): |
|
result = openai.ChatCompletion.create( |
|
model=self.model, |
|
messages=[{"role": "user", "content": prompt}], |
|
temperature=0, |
|
stop=stop, |
|
) |
|
return result["choices"][0]["message"]["content"] |
|
|
|
def _completion_generate(self, prompts, stop): |
|
result = openai.Completion.create( |
|
model=self.model, |
|
prompt=prompts, |
|
temperature=0, |
|
stop=stop, |
|
max_tokens=200, |
|
) |
|
return [answer["text"] for answer in result["choices"]] |
|
|
|
|
|
class AzureOpenAiAgent(Agent): |
|
""" |
|
    Agent that uses Azure OpenAI to generate code. See the [official
    documentation](https://learn.microsoft.com/en-us/azure/cognitive-services/openai/) to learn how to deploy an
    OpenAI model on Azure.
|
|
|
<Tip warning={true}> |
|
|
|
    The OpenAI models are used in generation mode, so even for the `chat()` API, it's better to use completion models
    like `"text-davinci-003"` over the ChatGPT variants. Proper support for ChatGPT models will come in a future
    version.
|
|
|
</Tip> |
|
|
|
Args: |
|
deployment_id (`str`): |
|
The name of the deployed Azure openAI model to use. |
|
api_key (`str`, *optional*): |
|
The API key to use. If unset, will look for the environment variable `"AZURE_OPENAI_API_KEY"`. |
|
resource_name (`str`, *optional*): |
|
The name of your Azure OpenAI Resource. If unset, will look for the environment variable |
|
`"AZURE_OPENAI_RESOURCE_NAME"`. |
|
        api_version (`str`, *optional*, defaults to `"2022-12-01"`):
            The API version to use for this agent.
        is_chat_model (`bool`, *optional*):
            Whether you are using a completion model or a chat model (see note above; chat models won't be as
            efficient). Defaults to checking whether `gpt` is in the `deployment_id` or not.
|
        chat_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). In that case, the prompt should be in a
            file named `chat_prompt_template.txt` in the repo.
        run_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `run` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). In that case, the prompt should be in a
            file named `run_prompt_template.txt` in the repo.
|
additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*): |
|
Any additional tools to include on top of the default ones. If you pass along a tool with the same name as |
|
one of the default tools, that default tool will be overridden. |
|
|
|
Example: |
|
|
|
```py |
|
from transformers import AzureOpenAiAgent |
|
|
|
    agent = AzureOpenAiAgent(deployment_id="Davinci-003", api_key=xxx, resource_name=yyy)
|
agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!") |
|
``` |
|
""" |
|
|
|
def __init__( |
|
self, |
|
deployment_id, |
|
api_key=None, |
|
resource_name=None, |
|
api_version="2022-12-01", |
|
is_chat_model=None, |
|
chat_prompt_template=None, |
|
run_prompt_template=None, |
|
additional_tools=None, |
|
): |
|
if not is_openai_available(): |
|
raise ImportError("Using `OpenAiAgent` requires `openai`: `pip install openai`.") |
|
|
|
self.deployment_id = deployment_id |
|
openai.api_type = "azure" |
|
if api_key is None: |
|
api_key = os.environ.get("AZURE_OPENAI_API_KEY", None) |
|
if api_key is None: |
|
raise ValueError( |
|
"You need an Azure openAI key to use `AzureOpenAIAgent`. If you have one, set it in your env with " |
|
"`os.environ['AZURE_OPENAI_API_KEY'] = xxx." |
|
) |
|
else: |
|
openai.api_key = api_key |
|
if resource_name is None: |
|
resource_name = os.environ.get("AZURE_OPENAI_RESOURCE_NAME", None) |
|
if resource_name is None: |
|
raise ValueError( |
|
"You need a resource_name to use `AzureOpenAIAgent`. If you have one, set it in your env with " |
|
"`os.environ['AZURE_OPENAI_RESOURCE_NAME'] = xxx." |
|
) |
|
else: |
|
openai.api_base = f"https://{resource_name}.openai.azure.com" |
|
openai.api_version = api_version |
|
|
|
if is_chat_model is None: |
|
is_chat_model = "gpt" in deployment_id.lower() |
|
self.is_chat_model = is_chat_model |
|
|
|
super().__init__( |
|
chat_prompt_template=chat_prompt_template, |
|
run_prompt_template=run_prompt_template, |
|
additional_tools=additional_tools, |
|
) |
|
|
|
def generate_many(self, prompts, stop): |
|
if self.is_chat_model: |
|
return [self._chat_generate(prompt, stop) for prompt in prompts] |
|
else: |
|
return self._completion_generate(prompts, stop) |
|
|
|
def generate_one(self, prompt, stop): |
|
if self.is_chat_model: |
|
return self._chat_generate(prompt, stop) |
|
else: |
|
return self._completion_generate([prompt], stop)[0] |
|
|
|
def _chat_generate(self, prompt, stop): |
|
result = openai.ChatCompletion.create( |
|
engine=self.deployment_id, |
|
messages=[{"role": "user", "content": prompt}], |
|
temperature=0, |
|
stop=stop, |
|
) |
|
return result["choices"][0]["message"]["content"] |
|
|
|
def _completion_generate(self, prompts, stop): |
|
result = openai.Completion.create( |
|
engine=self.deployment_id, |
|
prompt=prompts, |
|
temperature=0, |
|
stop=stop, |
|
max_tokens=200, |
|
) |
|
return [answer["text"] for answer in result["choices"]] |
|
|
|
|
|
class HfAgent(Agent): |
|
""" |
|
Agent that uses an inference endpoint to generate code. |
|
|
|
Args: |
|
url_endpoint (`str`): |
|
            The URL of the endpoint to use.
|
token (`str`, *optional*): |
|
The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when |
|
running `huggingface-cli login` (stored in `~/.huggingface`). |
|
        chat_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). In that case, the prompt should be in a
            file named `chat_prompt_template.txt` in the repo.
        run_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `run` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). In that case, the prompt should be in a
            file named `run_prompt_template.txt` in the repo.
|
additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*): |
|
Any additional tools to include on top of the default ones. If you pass along a tool with the same name as |
|
one of the default tools, that default tool will be overridden. |
|
|
|
Example: |
|
|
|
```py |
|
from transformers import HfAgent |
|
|
|
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder") |
|
agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!") |
|
``` |
|
""" |
|
|
|
def __init__( |
|
self, url_endpoint, token=None, chat_prompt_template=None, run_prompt_template=None, additional_tools=None |
|
): |
|
self.url_endpoint = url_endpoint |
|
if token is None: |
|
self.token = f"Bearer {HfFolder().get_token()}" |
|
elif token.startswith("Bearer") or token.startswith("Basic"): |
|
self.token = token |
|
else: |
|
self.token = f"Bearer {token}" |
|
super().__init__( |
|
chat_prompt_template=chat_prompt_template, |
|
run_prompt_template=run_prompt_template, |
|
additional_tools=additional_tools, |
|
) |
|
|
|
def generate_one(self, prompt, stop): |
|
headers = {"Authorization": self.token} |
|
inputs = { |
|
"inputs": prompt, |
|
"parameters": {"max_new_tokens": 200, "return_full_text": False, "stop": stop}, |
|
} |
|
|
|
response = requests.post(self.url_endpoint, json=inputs, headers=headers) |
|
if response.status_code == 429: |
|
logger.info("Getting rate-limited, waiting a tiny bit before trying again.") |
|
time.sleep(1) |
|
            return self.generate_one(prompt, stop)
|
elif response.status_code != 200: |
|
raise ValueError(f"Error {response.status_code}: {response.json()}") |
|
|
|
result = response.json()[0]["generated_text"] |
|
|
|
for stop_seq in stop: |
|
if result.endswith(stop_seq): |
|
return result[: -len(stop_seq)] |
|
return result |
|
|
|
|
|
class LocalAgent(Agent): |
|
""" |
|
Agent that uses a local model and tokenizer to generate code. |
|
|
|
Args: |
|
model ([`PreTrainedModel`]): |
|
The model to use for the agent. |
|
tokenizer ([`PreTrainedTokenizer`]): |
|
The tokenizer to use for the agent. |
|
        chat_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). In that case, the prompt should be in a
            file named `chat_prompt_template.txt` in the repo.
        run_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `run` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). In that case, the prompt should be in a
            file named `run_prompt_template.txt` in the repo.
|
additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*): |
|
Any additional tools to include on top of the default ones. If you pass along a tool with the same name as |
|
one of the default tools, that default tool will be overridden. |
|
|
|
Example: |
|
|
|
```py |
|
import torch |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, LocalAgent |
|
|
|
checkpoint = "bigcode/starcoder" |
|
model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16) |
|
tokenizer = AutoTokenizer.from_pretrained(checkpoint) |
|
|
|
agent = LocalAgent(model, tokenizer) |
|
agent.run("Draw me a picture of rivers and lakes.") |
|
``` |
|
""" |
|
|
|
def __init__(self, model, tokenizer, chat_prompt_template=None, run_prompt_template=None, additional_tools=None): |
|
self.model = model |
|
self.tokenizer = tokenizer |
|
super().__init__( |
|
chat_prompt_template=chat_prompt_template, |
|
run_prompt_template=run_prompt_template, |
|
additional_tools=additional_tools, |
|
) |
|
|
|
@classmethod |
|
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): |
|
""" |
|
Convenience method to build a `LocalAgent` from a pretrained checkpoint. |
|
|
|
Args: |
|
pretrained_model_name_or_path (`str` or `os.PathLike`): |
|
The name of a repo on the Hub or a local path to a folder containing both model and tokenizer. |
|
kwargs (`Dict[str, Any]`, *optional*): |
|
Keyword arguments passed along to [`~PreTrainedModel.from_pretrained`]. |
|
|
|
Example: |
|
|
|
```py |
|
import torch |
|
from transformers import LocalAgent |
|
|
|
agent = LocalAgent.from_pretrained("bigcode/starcoder", device_map="auto", torch_dtype=torch.bfloat16) |
|
agent.run("Draw me a picture of rivers and lakes.") |
|
``` |
|
""" |
|
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs) |
|
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs) |
|
return cls(model, tokenizer) |
|
|
|
@property |
|
def _model_device(self): |
|
        if hasattr(self.model, "hf_device_map"):
            # Model dispatched with `accelerate`: return the device of the first submodule in the device map.
            return list(self.model.hf_device_map.values())[0]
        # Otherwise, all parameters live on the same device; return the device of the first one.
        for param in self.model.parameters():
            return param.device
|
|
|
def generate_one(self, prompt, stop): |
|
encoded_inputs = self.tokenizer(prompt, return_tensors="pt").to(self._model_device) |
|
src_len = encoded_inputs["input_ids"].shape[1] |
|
stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(stop, self.tokenizer)]) |
|
outputs = self.model.generate( |
|
encoded_inputs["input_ids"], max_new_tokens=200, stopping_criteria=stopping_criteria |
|
) |
|
|
|
result = self.tokenizer.decode(outputs[0].tolist()[src_len:]) |
|
|
|
for stop_seq in stop: |
|
if result.endswith(stop_seq): |
|
result = result[: -len(stop_seq)] |
|
return result |
|
|
|
|
|
class StopSequenceCriteria(StoppingCriteria): |
|
""" |
|
This class can be used to stop generation whenever a sequence of tokens is encountered. |
|
|
|
Args: |
|
stop_sequences (`str` or `List[str]`): |
|
The sequence (or list of sequences) on which to stop execution. |
|
tokenizer: |
|
The tokenizer used to decode the model outputs. |
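
    Example (a sketch, assuming `tokenizer`, `model` and `input_ids` are already defined):

    ```py
    from transformers import StoppingCriteriaList

    # Stop generating as soon as the decoded output ends with "Task:".
    stopping_criteria = StoppingCriteriaList([StopSequenceCriteria("Task:", tokenizer)])
    outputs = model.generate(input_ids, stopping_criteria=stopping_criteria)
    ```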
|
""" |
|
|
|
def __init__(self, stop_sequences, tokenizer): |
|
if isinstance(stop_sequences, str): |
|
stop_sequences = [stop_sequences] |
|
self.stop_sequences = stop_sequences |
|
self.tokenizer = tokenizer |
|
|
|
def __call__(self, input_ids, scores, **kwargs) -> bool: |
|
decoded_output = self.tokenizer.decode(input_ids.tolist()[0]) |
|
return any(decoded_output.endswith(stop_sequence) for stop_sequence in self.stop_sequences) |
|
|