Spaces:
Runtime error
Runtime error
# + tags=["hide_inp"] | |
desc = """ | |
### Agent | |
Chain that executes different tools based on model decisions. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/bash.ipynb) | |
(Adapted from LangChain ) | |
""" | |
# - | |
# $ | |
from minichain import Id, prompt, OpenAI, show, transform, Mock, Break | |
from gradio_tools.tools import StableDiffusionTool, ImageCaptioningTool, ImageToMusicTool | |
# class ImageCaptioningTool: | |
# def run(self, inp): | |
# return "This is a picture of a smiling huggingface logo." | |
# description = "Image Captioning" | |
tools = [StableDiffusionTool(), ImageCaptioningTool(), ImageToMusicTool()] | |
def agent(model, query, history): | |
return model(dict(tools=[(str(tool.__class__.__name__), tool.description) | |
for tool in tools], | |
input=query, | |
agent_scratchpad=history | |
)) | |
def tool_parse(out): | |
lines = out.split("\n") | |
if lines[0].split("?")[-1].strip() == "Yes": | |
tool = lines[1].split(":", 1)[-1].strip() | |
command = lines[2].split(":", 1)[-1].strip() | |
return tool, command | |
else: | |
return Break() | |
def tool_use(model, usage): | |
selector, command = usage | |
for i, tool in enumerate(tools): | |
if selector == tool.__class__.__name__: | |
return model(command, tool_num=i) | |
return ("",) | |
def append(history, new, observation): | |
return history + "\n" + new + "Observation: " + observation | |
def run(query): | |
history = "" | |
observations = [] | |
for i in range(3): | |
select_input = agent(query, history) | |
observations.append(tool_use(tool_parse(select_input))) | |
history = append(history, select_input, observations[i]) | |
return observations[-1] | |
# $ | |
gradio = show(run, | |
subprompts=[agent, tool_use] * 3, | |
examples=[ | |
"I would please like a photo of a dog riding a skateboard. " | |
"Please caption this image and create a song for it.", | |
'Use an image generator tool to draw a cat.', | |
'Caption the image https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png from the internet'], | |
out_type="markdown", | |
description=desc, | |
show_advanced=False | |
) | |
if __name__ == "__main__": | |
gradio.queue().launch() | |