Spaces:
Running
Running
add file
Browse files- .gitignore +4 -0
- app.py +690 -0
- assets/assistant.png +0 -0
- assets/human.png +0 -0
- builtin_plan.json +15 -0
- cllm/agents/__init__.py +2 -0
- cllm/agents/base.py +173 -0
- cllm/agents/builtin/__init__.py +3 -0
- cllm/agents/builtin/plans.py +634 -0
- cllm/agents/builtin/tools.py +1512 -0
- cllm/agents/container.py +98 -0
- cllm/agents/tog/__init__.py +2 -0
- cllm/agents/tog/compiler.py +62 -0
- cllm/agents/tog/controller.py +157 -0
- cllm/agents/tog/interpretor.py +262 -0
- cllm/agents/tog/planner.py +156 -0
- cllm/agents/tog/responser.py +66 -0
- cllm/services/audio/__init__.py +0 -0
- cllm/services/audio/api.py +140 -0
- cllm/services/general/__init__.py +0 -0
- cllm/services/general/api.py +65 -0
- cllm/services/image_editing/__init__.py +0 -0
- cllm/services/image_editing/api.py +277 -0
- cllm/services/image_generation/__init__.py +0 -0
- cllm/services/image_generation/api.py +96 -0
- cllm/services/image_inpainting/__init__.py +0 -0
- cllm/services/image_inpainting/api.py +76 -0
- cllm/services/image_perception/__init__.py +0 -0
- cllm/services/image_perception/api.py +202 -0
- cllm/services/image_processing/__init__.py +0 -0
- cllm/services/image_processing/api.py +63 -0
- cllm/services/nlp/__init__.py +0 -0
- cllm/services/nlp/api.py +163 -0
- cllm/services/nlp/llms/__init__.py +2 -0
- cllm/services/nlp/llms/chat_models.py +219 -0
- cllm/services/nlp/llms/memory/__init__.py +1 -0
- cllm/services/nlp/llms/memory/message_memory.py +131 -0
- cllm/services/nlp/llms/memory/utils.py +52 -0
- cllm/services/tog/__init__.py +2 -0
- cllm/services/tog/api.py +40 -0
- cllm/services/utils.py +50 -0
- cllm/services/video/__init__.py +0 -0
- cllm/services/video/api.py +135 -0
- cllm/services/vqa/__init__.py +0 -0
- cllm/services/vqa/api.py +28 -0
- cllm/utils.py +79 -0
- requirements.txt +14 -0
.gitignore
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
run.sh
|
3 |
+
client_resources/
|
4 |
+
cllm.log
|
app.py
ADDED
@@ -0,0 +1,690 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import os.path as osp
|
5 |
+
from functools import partial
|
6 |
+
from pydoc import locate
|
7 |
+
import shutil
|
8 |
+
import json
|
9 |
+
from traceback import print_exc
|
10 |
+
import uuid
|
11 |
+
from pathlib import Path
|
12 |
+
from collections import OrderedDict
|
13 |
+
import numpy as np
|
14 |
+
from PIL import Image
|
15 |
+
|
16 |
+
import whisper
|
17 |
+
import fire
|
18 |
+
import gradio as gr
|
19 |
+
import gradio.themes.base as ThemeBase
|
20 |
+
from gradio.themes.utils import colors, fonts, sizes
|
21 |
+
import os
|
22 |
+
import sys
|
23 |
+
|
24 |
+
sys.path.append(os.getcwd())
|
25 |
+
from cllm.agents.builtin import plans
|
26 |
+
from cllm.services.general.api import remote_logging
|
27 |
+
from cllm.agents import container, FILE_EXT
|
28 |
+
from cllm.utils import get_real_path, plain2md, md2plain
|
29 |
+
import openai
|
30 |
+
|
31 |
+
openai.api_base = os.environ.get("OPENAI_API_BASE", None)
|
32 |
+
openai.api_key = os.environ.get("OPENAI_API_KEY", None)
|
33 |
+
|
34 |
+
|
35 |
+
logging.basicConfig(
|
36 |
+
filename="cllm.log",
|
37 |
+
level=logging.INFO,
|
38 |
+
format="%(asctime)s %(levelname)-8s %(message)s",
|
39 |
+
)
|
40 |
+
|
41 |
+
logger = logging.getLogger(__name__)
|
42 |
+
|
43 |
+
RESOURCE_ROOT = os.environ.get("CLIENT_ROOT", "./client_resources")
|
44 |
+
|
45 |
+
|
46 |
+
def is_image(file_path):
|
47 |
+
ext = FILE_EXT["image"]
|
48 |
+
_, extension = os.path.splitext(file_path)
|
49 |
+
return extension[1:] in ext
|
50 |
+
|
51 |
+
|
52 |
+
def is_video(file_path):
|
53 |
+
ext = FILE_EXT["video"]
|
54 |
+
_, extension = os.path.splitext(file_path)
|
55 |
+
return extension[1:] in ext
|
56 |
+
|
57 |
+
|
58 |
+
def is_audio(file_path):
|
59 |
+
ext = FILE_EXT["audio"]
|
60 |
+
_, extension = os.path.splitext(file_path)
|
61 |
+
return extension[1:] in ext
|
62 |
+
|
63 |
+
|
64 |
+
def get_file_type(file_path):
|
65 |
+
if is_image(file_path):
|
66 |
+
if "mask" in file_path:
|
67 |
+
return "mask"
|
68 |
+
return "image"
|
69 |
+
elif is_video(file_path):
|
70 |
+
return "video"
|
71 |
+
elif is_audio(file_path):
|
72 |
+
return "audio"
|
73 |
+
raise ValueError("Invalid file type")
|
74 |
+
|
75 |
+
|
76 |
+
def convert_dict_to_frame(data):
|
77 |
+
import pandas
|
78 |
+
|
79 |
+
outputs = []
|
80 |
+
for k, v in data.items():
|
81 |
+
output = {"Resource": k}
|
82 |
+
if not isinstance(v, str):
|
83 |
+
output["Type"] = str(v.__class__)
|
84 |
+
else:
|
85 |
+
output["Type"] = v
|
86 |
+
outputs.append(output)
|
87 |
+
if len(outputs) == 0:
|
88 |
+
return None
|
89 |
+
return pandas.DataFrame(outputs)
|
90 |
+
|
91 |
+
|
92 |
+
class Seafoam(ThemeBase.Base):
|
93 |
+
def __init__(
|
94 |
+
self,
|
95 |
+
*,
|
96 |
+
primary_hue=colors.emerald,
|
97 |
+
secondary_hue=colors.blue,
|
98 |
+
neutral_hue=colors.gray,
|
99 |
+
spacing_size=sizes.spacing_md,
|
100 |
+
radius_size=sizes.radius_md,
|
101 |
+
text_size=sizes.text_sm,
|
102 |
+
):
|
103 |
+
super().__init__(
|
104 |
+
primary_hue=primary_hue,
|
105 |
+
secondary_hue=secondary_hue,
|
106 |
+
neutral_hue=neutral_hue,
|
107 |
+
spacing_size=spacing_size,
|
108 |
+
radius_size=radius_size,
|
109 |
+
text_size=text_size,
|
110 |
+
)
|
111 |
+
super().set(
|
112 |
+
body_background_fill_dark="#111111",
|
113 |
+
button_primary_background_fill="*primary_300",
|
114 |
+
button_primary_background_fill_hover="*primary_200",
|
115 |
+
button_primary_text_color="black",
|
116 |
+
button_secondary_background_fill="*secondary_300",
|
117 |
+
button_secondary_background_fill_hover="*secondary_200",
|
118 |
+
border_color_primary="#0BB9BF",
|
119 |
+
slider_color="*secondary_300",
|
120 |
+
slider_color_dark="*secondary_600",
|
121 |
+
block_title_text_weight="600",
|
122 |
+
block_border_width="3px",
|
123 |
+
block_shadow="*shadow_drop_lg",
|
124 |
+
button_shadow="*shadow_drop_lg",
|
125 |
+
button_large_padding="10px",
|
126 |
+
)
|
127 |
+
|
128 |
+
|
129 |
+
class InteractionLoop:
|
130 |
+
def __init__(
|
131 |
+
self,
|
132 |
+
controller="cllm.agents.code.Controller",
|
133 |
+
):
|
134 |
+
self.stream = True
|
135 |
+
Controller = locate(controller)
|
136 |
+
self.controller = Controller(stream=self.stream, interpretor_kwargs=dict())
|
137 |
+
self.whisper = whisper.load_model("base")
|
138 |
+
|
139 |
+
def _gen_new_name(self, r_type, ext="png"):
|
140 |
+
this_new_uuid = str(uuid.uuid4())[:6]
|
141 |
+
new_file_name = f"{this_new_uuid}_{r_type}.{ext}"
|
142 |
+
return new_file_name
|
143 |
+
|
144 |
+
def init_state(self):
|
145 |
+
user_state = OrderedDict()
|
146 |
+
user_state["resources"] = OrderedDict()
|
147 |
+
user_state["history_msgs"] = []
|
148 |
+
resources = OrderedDict()
|
149 |
+
for item in sorted(os.listdir("./assets/resources")):
|
150 |
+
if item.startswith("."):
|
151 |
+
continue
|
152 |
+
shutil.copy(
|
153 |
+
osp.join("./assets/resources", item),
|
154 |
+
osp.join(RESOURCE_ROOT, item),
|
155 |
+
)
|
156 |
+
resources[item] = get_file_type(item)
|
157 |
+
# return user_state, user_state["resources"]
|
158 |
+
return user_state, resources
|
159 |
+
|
160 |
+
def add_file(self, user_state, history, file):
|
161 |
+
if user_state.get("resources", None) is None:
|
162 |
+
user_state["resources"] = OrderedDict()
|
163 |
+
|
164 |
+
if file is None:
|
165 |
+
return user_state, None, history, None
|
166 |
+
# filename = os.path.basename(file.name)
|
167 |
+
file = Path(file)
|
168 |
+
ext = file.suffix[1:]
|
169 |
+
if ext in FILE_EXT["image"]:
|
170 |
+
ext = "png"
|
171 |
+
r_type = get_file_type(file.name)
|
172 |
+
new_filename = self._gen_new_name(r_type, ext)
|
173 |
+
saved_path = get_real_path(new_filename)
|
174 |
+
if ext in FILE_EXT["image"]:
|
175 |
+
Image.open(file).convert("RGB").save(saved_path, "png")
|
176 |
+
user_state["input_image"] = new_filename
|
177 |
+
else:
|
178 |
+
shutil.copy(file, saved_path)
|
179 |
+
logger.info(f"add file: {saved_path}")
|
180 |
+
user_state["resources"][new_filename] = r_type
|
181 |
+
for key, val in user_state["resources"].items():
|
182 |
+
if key == "prompt_points":
|
183 |
+
user_state["resources"].pop(key)
|
184 |
+
break
|
185 |
+
history, _ = self.add_text(history, (saved_path,), role="human", append=False)
|
186 |
+
history, _ = self.add_text(
|
187 |
+
history, f"Recieved file {new_filename}", role="assistant", append=False
|
188 |
+
)
|
189 |
+
memory = convert_dict_to_frame(user_state["resources"])
|
190 |
+
image_name = None
|
191 |
+
if Path(saved_path).suffix[1:] in FILE_EXT["image"]:
|
192 |
+
image_name = saved_path
|
193 |
+
return user_state, image_name, history, memory
|
194 |
+
|
195 |
+
def add_msg(self, history, text, audio, role="assistant", append=False):
|
196 |
+
if text is not None and text.strip() != "":
|
197 |
+
return self.add_text(history, text, role=role, append=append)
|
198 |
+
elif audio is not None:
|
199 |
+
return self.add_audio(history, audio, role=role, append=append)
|
200 |
+
return history, ""
|
201 |
+
|
202 |
+
def add_text(self, history, text, role="assistant", append=False):
|
203 |
+
if history is None:
|
204 |
+
return history, ""
|
205 |
+
assert role in ["human", "assistant"]
|
206 |
+
idx = 0
|
207 |
+
if len(history) == 0 or role == "human":
|
208 |
+
history.append([None, None])
|
209 |
+
if role == "assistant":
|
210 |
+
idx = 1
|
211 |
+
if not append and history[-1][1] is not None:
|
212 |
+
history.append([None, None])
|
213 |
+
|
214 |
+
if append:
|
215 |
+
history[-1][idx] = (
|
216 |
+
text if history[-1][idx] is None else history[-1][idx] + text
|
217 |
+
)
|
218 |
+
else:
|
219 |
+
history[-1][idx] = text
|
220 |
+
if isinstance(text, str):
|
221 |
+
logger.info(f"add text: {md2plain(text)}")
|
222 |
+
|
223 |
+
return history, ""
|
224 |
+
|
225 |
+
def add_audio(self, history, audio, role="assistant", append=False):
|
226 |
+
assert role in ["human", "assistant"]
|
227 |
+
result = self.whisper.transcribe(audio)
|
228 |
+
text = result["text"]
|
229 |
+
logger.info(f"add audio: {text}")
|
230 |
+
return self.add_text(history, text, role=role, append=append)
|
231 |
+
|
232 |
+
def plan(self, user_state, input_image, history, history_plan):
|
233 |
+
logger.info(f"Task plan...")
|
234 |
+
if user_state.get("resources", None) is None:
|
235 |
+
user_state["resources"] = OrderedDict()
|
236 |
+
|
237 |
+
request = history[-1][0]
|
238 |
+
user_state["request"] = request
|
239 |
+
if isinstance(request, str) and request.startswith("$"):
|
240 |
+
solution = f'show$("{request[1:]}")'
|
241 |
+
else:
|
242 |
+
solution = self.controller.plan(request, state=user_state)
|
243 |
+
print(f"request: {request}")
|
244 |
+
if solution == self.controller.SHORTCUT:
|
245 |
+
# md_text = "**Using builtin shortcut solution.**"
|
246 |
+
history, _ = self.add_text(
|
247 |
+
history, solution, role="assistant", append=False
|
248 |
+
)
|
249 |
+
user_state["solution"] = solution
|
250 |
+
user_state["history_msgs"] = history
|
251 |
+
yield user_state, input_image, history, [solution]
|
252 |
+
elif isinstance(solution, str) and solution.startswith("show$"):
|
253 |
+
user_state["solution"] = solution
|
254 |
+
yield user_state, input_image, history, solution
|
255 |
+
else:
|
256 |
+
output_text = (
|
257 |
+
"The whole process will take some time, please be patient.<br><br>"
|
258 |
+
)
|
259 |
+
history, _ = self.add_text(
|
260 |
+
history, output_text, role="assistant", append=True
|
261 |
+
)
|
262 |
+
yield user_state, input_image, history, history_plan
|
263 |
+
task_decomposition = next(solution)
|
264 |
+
if task_decomposition in [None, [], ""]:
|
265 |
+
output = "Error: unrecognized resource(s) in task decomposition."
|
266 |
+
task_decomposition = "[]"
|
267 |
+
else:
|
268 |
+
output = task_decomposition
|
269 |
+
|
270 |
+
output = f"**Task Decomposition:**\n{output}"
|
271 |
+
output = plain2md(output)
|
272 |
+
history, _ = self.add_text(history, output, role="assistant", append=True)
|
273 |
+
user_state["task_decomposition"] = json.loads(task_decomposition)
|
274 |
+
yield user_state, input_image, history, history_plan
|
275 |
+
|
276 |
+
history, _ = self.add_text(
|
277 |
+
history,
|
278 |
+
plain2md("\n\n**Thoughs-on-Graph:**\n"),
|
279 |
+
role="assistant",
|
280 |
+
append=True,
|
281 |
+
)
|
282 |
+
yield user_state, input_image, history, history_plan
|
283 |
+
solution_str = next(solution)
|
284 |
+
logger.info(f"Thoughs-on-Graph: \n{solution_str}")
|
285 |
+
if solution_str in [None, [], ""]:
|
286 |
+
output = "Empty solution possibly due to some internal errors."
|
287 |
+
solution_str = "[]"
|
288 |
+
else:
|
289 |
+
output = solution_str
|
290 |
+
|
291 |
+
output_md = plain2md(output)
|
292 |
+
history, _ = self.add_text(
|
293 |
+
history, output_md, role="assistant", append=True
|
294 |
+
)
|
295 |
+
solution = json.loads(solution_str)
|
296 |
+
user_state["solution"] = solution
|
297 |
+
user_state["history_msgs"] = history
|
298 |
+
yield user_state, input_image, history, solution
|
299 |
+
|
300 |
+
def execute(self, user_state, input_image, history, history_plan):
|
301 |
+
resources_state = user_state.get("resources", OrderedDict())
|
302 |
+
solution = user_state.get("solution", None)
|
303 |
+
if not solution:
|
304 |
+
yield user_state, input_image, history, history_plan
|
305 |
+
return
|
306 |
+
logger.info(f"Tool execution...")
|
307 |
+
if isinstance(solution, str) and solution.startswith("show$"):
|
308 |
+
key = solution[7:-2]
|
309 |
+
r_type = resources_state.get(key)
|
310 |
+
if r_type is None:
|
311 |
+
resource = f"{key} not found"
|
312 |
+
resource = container.auto_type("None", r_type, key)
|
313 |
+
history, _ = self.add_text(
|
314 |
+
history, (resource.to_chatbot(),), role="assistant"
|
315 |
+
)
|
316 |
+
user_state["history_msgs"] = history
|
317 |
+
yield user_state, input_image, history, history_plan
|
318 |
+
return
|
319 |
+
elif solution:
|
320 |
+
results = self.controller.execute(solution, state=user_state)
|
321 |
+
if not results:
|
322 |
+
yield user_state, input_image, history, history_plan
|
323 |
+
return
|
324 |
+
|
325 |
+
user_state["outputs"] = []
|
326 |
+
for result_per_step, executed_solutions, wrapped_outputs in results:
|
327 |
+
tool_name = json.dumps(result_per_step[0], ensure_ascii=False)
|
328 |
+
args = json.dumps(result_per_step[1], ensure_ascii=False)
|
329 |
+
if isinstance(result_per_step[2], Exception):
|
330 |
+
ret = f"Internal error: {result_per_step[2]}"
|
331 |
+
else:
|
332 |
+
ret = json.dumps(result_per_step[2], ensure_ascii=False)
|
333 |
+
history, _ = self.add_text(
|
334 |
+
history,
|
335 |
+
f"Call **{tool_name}:**<br> **Args**: {plain2md(args)}<br> **Ret**: {plain2md(ret)}",
|
336 |
+
role="assistant",
|
337 |
+
)
|
338 |
+
user_state["history_msgs"] = history
|
339 |
+
user_state["executed_solutions"] = executed_solutions
|
340 |
+
yield user_state, input_image, history, history_plan
|
341 |
+
for _, output in enumerate(wrapped_outputs):
|
342 |
+
if output is None or output.value is None:
|
343 |
+
continue
|
344 |
+
if isinstance(output, container.File):
|
345 |
+
history, _ = self.add_text(
|
346 |
+
history,
|
347 |
+
f"Here is {output.filename}:",
|
348 |
+
role="assistant",
|
349 |
+
)
|
350 |
+
history, _ = self.add_text(
|
351 |
+
history, (output.to_chatbot(),), role="assistant"
|
352 |
+
)
|
353 |
+
user_state["outputs"].extend(wrapped_outputs)
|
354 |
+
user_state["history_msgs"] = history
|
355 |
+
yield user_state, input_image, history, history_plan
|
356 |
+
|
357 |
+
else:
|
358 |
+
yield user_state, input_image, history, history_plan
|
359 |
+
|
360 |
+
def reply(self, user_state, history):
|
361 |
+
logger.info(f"Make response...")
|
362 |
+
executed_solution = user_state.get("executed_solutions", None)
|
363 |
+
resources_state = user_state.get("resources", OrderedDict())
|
364 |
+
solution = user_state.get("solution", None)
|
365 |
+
memory = convert_dict_to_frame(resources_state)
|
366 |
+
if isinstance(solution, str) and solution.startswith("show$"):
|
367 |
+
return user_state, history, memory
|
368 |
+
|
369 |
+
outputs = user_state.get("outputs", None)
|
370 |
+
response, user_state = self.controller.reply(
|
371 |
+
executed_solution, outputs, user_state
|
372 |
+
)
|
373 |
+
# prompt_mask_out = None
|
374 |
+
for i, output in enumerate(response):
|
375 |
+
if isinstance(output, container.File):
|
376 |
+
history, _ = self.add_text(history, f"Here is [{output.filename}]: ")
|
377 |
+
history, _ = self.add_text(history, (output.to_chatbot(),))
|
378 |
+
elif i == 0:
|
379 |
+
history, _ = self.add_text(history, output.to_chatbot())
|
380 |
+
|
381 |
+
user_state["history_msgs"] = history
|
382 |
+
return user_state, history, memory
|
383 |
+
|
384 |
+
def vote(self, user_state, history, data: gr.LikeData):
|
385 |
+
data_value = data.value
|
386 |
+
if isinstance(data_value, dict):
|
387 |
+
data_value = json.dumps(data_value)
|
388 |
+
|
389 |
+
if data.liked:
|
390 |
+
print("You upvoted this response: ", data_value)
|
391 |
+
logger.info("You upvoted this response: " + data_value)
|
392 |
+
else:
|
393 |
+
print("You downvoted this response: ", data_value)
|
394 |
+
logger.info("You downvoted this response: " + data_value)
|
395 |
+
|
396 |
+
remote_logging(
|
397 |
+
user_state.get("history_msgs", []),
|
398 |
+
user_state.get("task_decomposition", ""),
|
399 |
+
user_state.get("solution", []),
|
400 |
+
data_value,
|
401 |
+
data.liked,
|
402 |
+
)
|
403 |
+
|
404 |
+
msg = f"Thanks for your feedback! You feedback will contribute a lot to improving our ControlLLM."
|
405 |
+
history, _ = self.add_text(history, msg)
|
406 |
+
user_state["history_msgs"] = history
|
407 |
+
return user_state, history
|
408 |
+
|
409 |
+
def save_point(self, user_state, history, data: gr.SelectData):
|
410 |
+
if isinstance(data, gr.LikeData):
|
411 |
+
return self.vote(user_state, history, data)
|
412 |
+
|
413 |
+
if not isinstance(data, gr.SelectData):
|
414 |
+
return user_state, history
|
415 |
+
|
416 |
+
resource_state = user_state.get("resources")
|
417 |
+
input_image = user_state.get("input_image", None)
|
418 |
+
if input_image is None:
|
419 |
+
history, _ = self.add_text(history, "Please upload an image at first.")
|
420 |
+
history, _ = self.add_text(history, plans.BUILTIN_SEG_BY_POINTS, "human")
|
421 |
+
user_state["history_msg"] = history
|
422 |
+
return user_state, history
|
423 |
+
|
424 |
+
resource_state.pop(input_image, None)
|
425 |
+
resource_state[input_image] = "image"
|
426 |
+
|
427 |
+
history = history + [[plans.BUILTIN_SEG_BY_POINTS, None]]
|
428 |
+
points = []
|
429 |
+
if isinstance(points, str):
|
430 |
+
points = json.loads(points)
|
431 |
+
|
432 |
+
points.append(data.index)
|
433 |
+
resource_state[json.dumps(points)] = "prompt_points"
|
434 |
+
user_state["resources"] = resource_state
|
435 |
+
return user_state, history
|
436 |
+
|
437 |
+
|
438 |
+
def on_switch_input(state_input, text, audio, disable=False):
|
439 |
+
if state_input == "audio" or disable:
|
440 |
+
return "text", gr.update(visible=True), gr.update(visible=False)
|
441 |
+
return "audio", gr.update(visible=False), gr.update(visible=True)
|
442 |
+
|
443 |
+
|
444 |
+
def on_mask_submit(history):
|
445 |
+
history = history + [(plans.BUILTIN_SEG_BY_MASK, None)]
|
446 |
+
return history
|
447 |
+
|
448 |
+
|
449 |
+
def app(controller="cllm.agents.tog.Controller", https=False, **kwargs):
|
450 |
+
loop = InteractionLoop(controller=controller)
|
451 |
+
init_state, builtin_resources = loop.init_state()
|
452 |
+
css = """
|
453 |
+
code {
|
454 |
+
font-size: var(--text-sm);
|
455 |
+
white-space: pre-wrap; /* Since CSS 2.1 */
|
456 |
+
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
|
457 |
+
white-space: -pre-wrap; /* Opera 4-6 */
|
458 |
+
white-space: -o-pre-wrap; /* Opera 7 */
|
459 |
+
word-wrap: break-word; /* Internet Explorer 5.5+ */
|
460 |
+
}
|
461 |
+
"""
|
462 |
+
with gr.Blocks(theme=Seafoam(), css=css) as demo:
|
463 |
+
gr.HTML(
|
464 |
+
"""
|
465 |
+
<div align='center'> <h1>ControlLLM </h1> </div>
|
466 |
+
<p align="center"> A framework for multi-modal interaction which is able to control LLMs over invoking tools more accurately. </p>
|
467 |
+
<p align="center"><a href="https://github.com/OpenGVLab/ControlLLM"><b>GitHub</b></a>
|
468 |
+
<a href="https://arxiv.org/abs/2311.11797"><b>ArXiv</b></a></p>
|
469 |
+
""",
|
470 |
+
)
|
471 |
+
|
472 |
+
state_input = gr.State("text")
|
473 |
+
user_state = gr.State(copy.deepcopy(init_state))
|
474 |
+
with gr.Row():
|
475 |
+
with gr.Column(scale=6):
|
476 |
+
with gr.Tabs():
|
477 |
+
with gr.Tab("Chat"):
|
478 |
+
chatbot = gr.Chatbot(
|
479 |
+
[],
|
480 |
+
elem_id="chatbot",
|
481 |
+
avatar_images=[
|
482 |
+
"assets/human.png",
|
483 |
+
"assets/assistant.png",
|
484 |
+
],
|
485 |
+
show_copy_button=True,
|
486 |
+
height=550,
|
487 |
+
)
|
488 |
+
|
489 |
+
with gr.Row():
|
490 |
+
with gr.Column(scale=12):
|
491 |
+
text = gr.Textbox(
|
492 |
+
show_label=False,
|
493 |
+
placeholder="Enter text and press enter, or upload an image.",
|
494 |
+
container=False,
|
495 |
+
)
|
496 |
+
audio = gr.Audio(
|
497 |
+
sources="microphone", type="filepath", visible=False
|
498 |
+
)
|
499 |
+
with gr.Column(scale=2, min_width=80):
|
500 |
+
submit = gr.Button("Submit", variant="primary")
|
501 |
+
with gr.Column(scale=1, min_width=40):
|
502 |
+
record = gr.Button("🎙️")
|
503 |
+
with gr.Column(scale=1, min_width=40):
|
504 |
+
upload_btn = gr.UploadButton(
|
505 |
+
"📁",
|
506 |
+
file_types=[
|
507 |
+
"image",
|
508 |
+
"video",
|
509 |
+
"audio",
|
510 |
+
".pdf",
|
511 |
+
],
|
512 |
+
)
|
513 |
+
|
514 |
+
gr.Examples(
|
515 |
+
[
|
516 |
+
"Who are you?",
|
517 |
+
"How is about weather in Beijing",
|
518 |
+
"Describe the given image.",
|
519 |
+
"find the woman wearing the red skirt in the image",
|
520 |
+
"Generate a video that shows Pikachu surfing in waves.",
|
521 |
+
"How many horses are there in the image?",
|
522 |
+
"Can you erase the dog in the given image?",
|
523 |
+
"Remove the object based on the given mask.",
|
524 |
+
"Can you make a video of a serene lake with vibrant green grass and trees all around? And then create a webpage using HTML to showcase this video?",
|
525 |
+
"Generate an image that shows a beautiful landscape with a calm lake reflecting the blue sky and white clouds. Then generate a video to introduce this image.",
|
526 |
+
"replace the masked object with a cute yellow dog",
|
527 |
+
"replace the sheep with a cute dog in the image",
|
528 |
+
"Recognize the action in the video",
|
529 |
+
"Generate an image where a astronaut is riding a horse",
|
530 |
+
"Please generate a piece of music from the given image",
|
531 |
+
"Please give me an image that shows an astronaut riding a horse on mars.",
|
532 |
+
"What’s the weather situation in Berlin? Can you generate a new image that represents the weather in there?",
|
533 |
+
"Can you recognize the text from the image and tell me how much is Eggs Florentine?",
|
534 |
+
"Generate a piece of music for this video and dub this video with generated music",
|
535 |
+
"Generate a new image based on depth map from input image",
|
536 |
+
"Remove the cats from the image_1.png, image_2.png, image_3.png",
|
537 |
+
"I need the banana removed from the c4c40e_image.png, 9e867c_image.png, 9e13sc_image.png",
|
538 |
+
"I would be so happy if you could create a new image using the scribble from input image. The new image should be a tropical island with a dog. Write a detailed description of the given image. and highlight the dog in image",
|
539 |
+
"Please generate a piece of music and a new video from the input image",
|
540 |
+
"generate a new image conditioned on the segmentation from input image and the new image shows that a gorgeous lady is dancing",
|
541 |
+
"generate a new image with a different background but maintaining the same composition as input image",
|
542 |
+
"Generate a new image that shows an insect robot preparing a delicious meal. Then give me a video based on new image. Finally, dub the video with suitable background music.",
|
543 |
+
"Translate the text into speech: I have a dream that one day this nation will rise up and live out the true meaning of its creed: We hold these truths to be self-evident that all men are created equal.I have a dream that one day on the red hills of Georgia the sons of former slaves and the sons of former slave owners will be able to sit down together at the table of brotherhood. I have a dream that one day even the state of Mississippi, a state sweltering with the heat of injustice, sweltering with the heat of oppression, will be transformed into an oasis of freedom and justice.",
|
544 |
+
],
|
545 |
+
inputs=[text],
|
546 |
+
)
|
547 |
+
gr.Examples(
|
548 |
+
list(plans.BUILTIN_PLANS.keys()),
|
549 |
+
inputs=[text],
|
550 |
+
label="Builtin Examples",
|
551 |
+
)
|
552 |
+
|
553 |
+
with gr.Column(scale=5):
|
554 |
+
with gr.Tabs():
|
555 |
+
with gr.Tab("Mask Input"):
|
556 |
+
image_mask = gr.components.Image(
|
557 |
+
sources="upload",
|
558 |
+
interactive=True,
|
559 |
+
type="filepath",
|
560 |
+
)
|
561 |
+
# with gr.Row():
|
562 |
+
# mask_submit_btn = gr.Button("Segment", variant="primary")
|
563 |
+
with gr.Row():
|
564 |
+
image_submit_btn = gr.Button("Upload", variant="primary")
|
565 |
+
|
566 |
+
with gr.Tab("Plan"):
|
567 |
+
planbot = gr.JSON(elem_classes="json")
|
568 |
+
|
569 |
+
with gr.Tab("Memory"):
|
570 |
+
memory_table = gr.DataFrame(
|
571 |
+
# value=convert_dict_to_frame(builtin_resources),
|
572 |
+
label="Memory",
|
573 |
+
headers=["Resource", "Type"],
|
574 |
+
row_count=5,
|
575 |
+
wrap=True,
|
576 |
+
)
|
577 |
+
gr.Examples(
|
578 |
+
[
|
579 |
+
osp.join("./assets/resources", item)
|
580 |
+
for item in builtin_resources.keys()
|
581 |
+
if item.endswith(".png")
|
582 |
+
],
|
583 |
+
inputs=[image_mask],
|
584 |
+
label="File Examples",
|
585 |
+
)
|
586 |
+
|
587 |
+
chatbot.like(
|
588 |
+
loop.vote,
|
589 |
+
[
|
590 |
+
user_state,
|
591 |
+
chatbot,
|
592 |
+
],
|
593 |
+
[
|
594 |
+
user_state,
|
595 |
+
chatbot,
|
596 |
+
],
|
597 |
+
)
|
598 |
+
reply_inputs = [user_state, image_mask, chatbot, planbot]
|
599 |
+
reply_outputs = [
|
600 |
+
user_state,
|
601 |
+
# image_mask,
|
602 |
+
chatbot,
|
603 |
+
memory_table,
|
604 |
+
# planbot,
|
605 |
+
]
|
606 |
+
|
607 |
+
add_text = [
|
608 |
+
partial(loop.add_text, role="human"),
|
609 |
+
[chatbot, text],
|
610 |
+
[chatbot, text],
|
611 |
+
]
|
612 |
+
|
613 |
+
text.submit(*add_text).then(loop.plan, reply_inputs, reply_inputs).then(
|
614 |
+
loop.execute, reply_inputs, reply_inputs
|
615 |
+
).then(loop.reply, [user_state, chatbot], reply_outputs)
|
616 |
+
|
617 |
+
add_msg = [
|
618 |
+
partial(loop.add_msg, role="human"),
|
619 |
+
[chatbot, text, audio],
|
620 |
+
[chatbot, text],
|
621 |
+
]
|
622 |
+
|
623 |
+
submit.click(*add_msg).then(
|
624 |
+
partial(on_switch_input, disable=True),
|
625 |
+
[state_input, text, audio],
|
626 |
+
[state_input, text, audio],
|
627 |
+
).then(loop.plan, reply_inputs, reply_inputs).then(
|
628 |
+
loop.execute, reply_inputs, reply_inputs
|
629 |
+
).then(
|
630 |
+
loop.reply, [user_state, chatbot], reply_outputs
|
631 |
+
)
|
632 |
+
|
633 |
+
upload_btn.upload(
|
634 |
+
loop.add_file,
|
635 |
+
inputs=[user_state, chatbot, upload_btn],
|
636 |
+
outputs=[user_state, image_mask, chatbot, memory_table],
|
637 |
+
)
|
638 |
+
record.click(
|
639 |
+
on_switch_input,
|
640 |
+
[state_input, text, audio],
|
641 |
+
[state_input, text, audio],
|
642 |
+
)
|
643 |
+
|
644 |
+
image_mask.select(
|
645 |
+
loop.save_point, [user_state, chatbot], [user_state, chatbot]
|
646 |
+
).then(loop.plan, reply_inputs, reply_inputs).then(
|
647 |
+
loop.execute, reply_inputs, reply_inputs
|
648 |
+
).then(
|
649 |
+
loop.reply, [user_state, chatbot], reply_outputs
|
650 |
+
)
|
651 |
+
|
652 |
+
image_mask.upload(
|
653 |
+
loop.add_file,
|
654 |
+
inputs=[user_state, chatbot, image_mask],
|
655 |
+
outputs=[user_state, image_mask, chatbot, memory_table],
|
656 |
+
)
|
657 |
+
image_submit_btn.click(
|
658 |
+
loop.add_file,
|
659 |
+
inputs=[user_state, chatbot, image_mask],
|
660 |
+
outputs=[user_state, image_mask, chatbot, memory_table],
|
661 |
+
)
|
662 |
+
|
663 |
+
if https:
|
664 |
+
demo.queue().launch(
|
665 |
+
server_name="0.0.0.0",
|
666 |
+
# ssl_certfile="./certificate/cert.pem",
|
667 |
+
# ssl_keyfile="./certificate/key.pem",
|
668 |
+
ssl_verify=False,
|
669 |
+
show_api=False,
|
670 |
+
allowed_paths=[
|
671 |
+
"assets/human.png",
|
672 |
+
"assets/assistant.png",
|
673 |
+
],
|
674 |
+
**kwargs,
|
675 |
+
)
|
676 |
+
else:
|
677 |
+
demo.queue().launch(
|
678 |
+
server_name="0.0.0.0",
|
679 |
+
show_api=False,
|
680 |
+
allowed_paths=[
|
681 |
+
"assets/human.png",
|
682 |
+
"assets/assistant.png",
|
683 |
+
],
|
684 |
+
**kwargs,
|
685 |
+
)
|
686 |
+
|
687 |
+
|
688 |
+
if __name__ == "__main__":
|
689 |
+
os.makedirs(RESOURCE_ROOT, exist_ok=True)
|
690 |
+
app(controller="cllm.agents.tog.Controller", server_port=10024)
|
assets/assistant.png
ADDED
assets/human.png
ADDED
builtin_plan.json
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"you know what I want": [
|
3 |
+
[
|
4 |
+
{
|
5 |
+
"tool_name": "text_to_image",
|
6 |
+
"inputs": {
|
7 |
+
"text": "a dog"
|
8 |
+
},
|
9 |
+
"outputs": [
|
10 |
+
"image"
|
11 |
+
]
|
12 |
+
}
|
13 |
+
]
|
14 |
+
]
|
15 |
+
}
|
cllm/agents/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .base import Tool, Action
|
2 |
+
from .container import *
|
cllm/agents/base.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from enum import Enum
|
3 |
+
from typing import Callable, List
|
4 |
+
import json
|
5 |
+
from pathlib import Path
|
6 |
+
from collections import OrderedDict
|
7 |
+
|
8 |
+
|
9 |
+
@dataclass
|
10 |
+
class Action:
|
11 |
+
"""The action represent an assignment.
|
12 |
+
`output = tool_name(**inputs)`
|
13 |
+
|
14 |
+
Examples:
|
15 |
+
>>> mask = segmentation_by_mask(image=image, prompt_mask=prompt_mask)
|
16 |
+
>>> image = image_inpainting(image=image, mask=mask)
|
17 |
+
"""
|
18 |
+
|
19 |
+
tool_name: str = (None,)
|
20 |
+
inputs: dict = (None,)
|
21 |
+
outputs: List[str] = (None,)
|
22 |
+
|
23 |
+
def __str__(self) -> str:
|
24 |
+
args = ", ".join([f"{k}={v}" for k, v in self.inputs.items()])
|
25 |
+
return "{} = {}(".format(", ".join(self.outputs), self.tool_name) + args + ")"
|
26 |
+
|
27 |
+
def dict(self):
|
28 |
+
args = {str(k): str(v) for k, v in self.inputs.items()}
|
29 |
+
# args = {str(item["name"]): str(item["value"]) for item in self.inputs}
|
30 |
+
rets = [o if isinstance(o, str) else str(o) for o in self.outputs]
|
31 |
+
return {
|
32 |
+
"tool": self.tool_name,
|
33 |
+
"inputs": args,
|
34 |
+
"outputs": rets,
|
35 |
+
}
|
36 |
+
|
37 |
+
|
38 |
+
class DataType(Enum):
|
39 |
+
TEXT = "text"
|
40 |
+
TAGS = "tags"
|
41 |
+
TITLE = "title"
|
42 |
+
# HTML = "text.html"
|
43 |
+
HTML = "html"
|
44 |
+
LOCATION = "location"
|
45 |
+
WEATHER = "weather"
|
46 |
+
TIME = "time"
|
47 |
+
|
48 |
+
IMAGE = "image"
|
49 |
+
VIDEO = "video"
|
50 |
+
AUDIO = "audio"
|
51 |
+
ANY = "any"
|
52 |
+
NONE = "none"
|
53 |
+
|
54 |
+
SEGMENTATION = "image.segmentation"
|
55 |
+
EDGE = "image.edge"
|
56 |
+
LINE = "image.line"
|
57 |
+
HED = "image.hed"
|
58 |
+
CANNY = "image.canny"
|
59 |
+
SCRIBBLE = "image.scribble"
|
60 |
+
POSE = "image.pose"
|
61 |
+
DEPTH = "image.depth"
|
62 |
+
NORMAL = "image.normal"
|
63 |
+
|
64 |
+
MASK = "image.mask" # SAM mask
|
65 |
+
POINT = "point"
|
66 |
+
BBOX = "bbox" # {'label': 'dog', 'box': [1,2,3,4], 'score': 0.9}
|
67 |
+
CATEGORY = "category"
|
68 |
+
|
69 |
+
LIST = "list"
|
70 |
+
|
71 |
+
def __str__(self):
|
72 |
+
return self.value
|
73 |
+
|
74 |
+
def __eq__(self, other):
|
75 |
+
if isinstance(other, str):
|
76 |
+
return self.value == other
|
77 |
+
elif isinstance(other, self.__class__):
|
78 |
+
return self.value == other.value
|
79 |
+
else:
|
80 |
+
return False
|
81 |
+
|
82 |
+
|
83 |
+
@dataclass
|
84 |
+
class Resource:
|
85 |
+
name: str
|
86 |
+
type: DataType
|
87 |
+
value: None
|
88 |
+
# description: str = None
|
89 |
+
|
90 |
+
def dict(self):
|
91 |
+
return {
|
92 |
+
"name": self.name,
|
93 |
+
"type": str(self.type),
|
94 |
+
"value": str(self.value),
|
95 |
+
# "description": self.description,
|
96 |
+
}
|
97 |
+
|
98 |
+
|
99 |
+
@dataclass
|
100 |
+
class Tool:
|
101 |
+
class Domain(Enum):
|
102 |
+
IMAGE_PERCEPTION = "image-perception"
|
103 |
+
IMAGE_GENERATION = "image-generation"
|
104 |
+
IMAGE_EDITING = "image-editing"
|
105 |
+
IMAGE_PROCESSING = "image-processing"
|
106 |
+
AUDIO_PERCEPTION = "audio-perception"
|
107 |
+
AUDIO_GENERATION = "audio-generation"
|
108 |
+
VIDEO_PERCEPTION = "video-perception"
|
109 |
+
VIDEO_GENERATION = "video-generation"
|
110 |
+
VIDEO_PROCESSING = "video-processing"
|
111 |
+
VIDEO_EDITING = "video-editing"
|
112 |
+
VIDEO_CUTTING = "video-cutting"
|
113 |
+
NATURAL_LANGUAGE_PROCESSING = "natural-language-processing"
|
114 |
+
CODE_GENERATION = "code-generation"
|
115 |
+
VISUAL_QUESTION_ANSWERING = "visual-question-answering"
|
116 |
+
QUESTION_ANSWERING = "question-answering"
|
117 |
+
GENERAL = "general"
|
118 |
+
|
119 |
+
def __str__(self):
|
120 |
+
return self.value
|
121 |
+
|
122 |
+
@dataclass
|
123 |
+
class Argument:
|
124 |
+
name: str
|
125 |
+
type: DataType
|
126 |
+
description: str
|
127 |
+
|
128 |
+
def dict(self):
|
129 |
+
return {
|
130 |
+
"name": self.name,
|
131 |
+
"type": str(self.type),
|
132 |
+
"description": self.description,
|
133 |
+
}
|
134 |
+
|
135 |
+
name: str
|
136 |
+
description: str
|
137 |
+
domain: Domain
|
138 |
+
model: Callable
|
139 |
+
|
140 |
+
usages: List[str] = field(default_factory=lambda: [])
|
141 |
+
args: List[Argument] = field(default_factory=lambda: [])
|
142 |
+
returns: List[Argument] = field(default_factory=lambda: [])
|
143 |
+
|
144 |
+
def dict(self):
|
145 |
+
return {
|
146 |
+
"name": self.name,
|
147 |
+
"description": self.description,
|
148 |
+
"domain": str(self.domain),
|
149 |
+
"args": [a.dict() for a in self.args],
|
150 |
+
"returns": [r.dict() for r in self.returns],
|
151 |
+
}
|
152 |
+
|
153 |
+
|
154 |
+
NON_FILE_TYPES = [
|
155 |
+
DataType.TAGS,
|
156 |
+
DataType.TEXT,
|
157 |
+
DataType.TITLE,
|
158 |
+
DataType.BBOX,
|
159 |
+
DataType.CATEGORY,
|
160 |
+
DataType.LIST,
|
161 |
+
DataType.LOCATION,
|
162 |
+
DataType.POINT,
|
163 |
+
DataType.WEATHER,
|
164 |
+
DataType.TIME,
|
165 |
+
]
|
166 |
+
|
167 |
+
|
168 |
+
if __name__ == "__main__":
|
169 |
+
s = [
|
170 |
+
[Action("a", {"aa": [Path("/a/d/e/t.txt")]}, [Path("/a/aa.txt")])],
|
171 |
+
Action("b", {"bb": "bbb"}, ["bbb"]),
|
172 |
+
]
|
173 |
+
print(json.dumps(s, indent=4, default=lambda o: o.dict()))
|
cllm/agents/builtin/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from . import plans
|
2 |
+
from .plans import BUILTIN_PLANS, load_builtin_plans
|
3 |
+
from .tools import TOOLS
|
cllm/agents/builtin/plans.py
ADDED
@@ -0,0 +1,634 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
sys.path.append(os.getcwd())
|
5 |
+
|
6 |
+
from cllm.agents.base import Action
|
7 |
+
|
8 |
+
BUILTIN_SEG_BY_POINTS = "Segment the given image based on the prompt points."
|
9 |
+
BUILTIN_SEG_BY_MASK = "Segment the given image based on the prompt mask."
|
10 |
+
# BUILTIN_REMOVE_BY_MASK = "Remove the object based on the given mask."
|
11 |
+
BUILTIN_IMAGE_TO_EDGE = "Generate the edge from the given image."
|
12 |
+
BUILTIN_GENERATE_SIMILAR_IMAGE = "Generate a new image similar to the input image"
|
13 |
+
# BUILTIN_GENERATE_SIMILAR_IMAGE2 = "Generate a similar image from the given image 2"
|
14 |
+
# BUILTIN_GENERATE_SIMILAR_IMAGE3 = "Image to image. 3"
|
15 |
+
BUILTIN_GENERATE_SIMILAR_IMAGE4 = "Generate a new image similar to image 4"
|
16 |
+
BUILTIN_GENERATE_IMAGE_HED = "Generate a new image based on HED result from input image"
|
17 |
+
BUILTIN_GENERATE_IMAGE_DEPTH = (
|
18 |
+
"Generate a new image based on depth map from input image"
|
19 |
+
)
|
20 |
+
BUILTIN_GENERATE_IMAGE_OCR = "Please extract the text from the image"
|
21 |
+
BUILTIN_TEXT_EDGE_TO_IMAGE = "Generate an image based on the given edge map."
|
22 |
+
BUILTIN_GENERATE_IMAGE = "Generate a new image that shows a woman is skiing"
|
23 |
+
BUILTIN_IMAGE_TO_VIDEO = "Generate a video from the image"
|
24 |
+
BUILTIN_COUNT_OBJECTS = "Provide me with the count of bears in the input image"
|
25 |
+
BUILTIN_VIDEO_TO_WEBPAGE = "Generate a web page for input video"
|
26 |
+
BUILTIN_TEXT_TO_MUSIC = "Please generate a piece of music based on given prompt. Here is the prompt: An 80s driving pop song with heavy drums and synth pads in the background"
|
27 |
+
BUILTIN_IMAGE_ERASING1 = "Erase the wine glass from the photo"
|
28 |
+
BUILTIN_IMAGE_ERASING2 = "Erase the cats in the photo"
|
29 |
+
BUILTIN_IMAGE_CROPPING = "Crop the cats from the photo"
|
30 |
+
BUILTIN_IMAGE_SEG = "give me the mask of elephant."
|
31 |
+
BUILTIN_IMAGE_HIGHLIGHT = "highlight the elephant."
|
32 |
+
BUILTIN_TEXT_SPEECH = "translate text into speech"
|
33 |
+
BUILTIN_DUBBING = "dub this video with the given audio"
|
34 |
+
BUILTIN_COUNT_OBJECTS2 = "Count the horse in the image."
|
35 |
+
BUILTIN_IMAGE_TO_VIDEO2 = "Generate an image that shows a serene and beautiful landscape with a calm lake reflecting the blue sky and white clouds. Then generate a video to introduce this image."
|
36 |
+
BUILTIN_IMAGE_TO_VIDEO3 = "Create a visual and auditory representation of a peaceful and scenic landscape. The image should depict a serene and beautiful landscape with a calm lake reflecting the blue sky. The music should match the image. Finally, combine the image and the music into a video that showcases the beauty of nature."
|
37 |
+
BUILTIN_VIDEO_CLS = "Recognize the action in the video"
|
38 |
+
BUILTIN_VIDEO_CLS = "Recognize the action in the video"
|
39 |
+
BUILTIN_AUDIO_CLS = "Recognize the event in this audio"
|
40 |
+
BUILTIN_IMAGE2MUSIC = "Generate a piece of music for this image"
|
41 |
+
BUILTIN_VIDEO2MUSIC = (
|
42 |
+
"Generate a piece of music for this video and dub the video with generated music"
|
43 |
+
)
|
44 |
+
|
45 |
+
BUILTIN_PLANS = {
|
46 |
+
# BUILTIN_REMOVE_BY_MASK: [
|
47 |
+
# [
|
48 |
+
# Action(
|
49 |
+
# tool_name="image_inpainting",
|
50 |
+
# inputs={"image": "image", "mask": "image.mask"},
|
51 |
+
# outputs=["<GENERATED>-0"],
|
52 |
+
# )
|
53 |
+
# ]
|
54 |
+
# ],
|
55 |
+
BUILTIN_IMAGE_TO_EDGE: [
|
56 |
+
[
|
57 |
+
Action(
|
58 |
+
tool_name="image_to_edge",
|
59 |
+
inputs={"image": "image"},
|
60 |
+
outputs=["<GENERATED>-0"],
|
61 |
+
)
|
62 |
+
]
|
63 |
+
],
|
64 |
+
BUILTIN_TEXT_EDGE_TO_IMAGE: [
|
65 |
+
[
|
66 |
+
Action(
|
67 |
+
tool_name="image_captioning",
|
68 |
+
inputs={"image": "image"},
|
69 |
+
outputs=["<TOOL-GENERATED>-prompt"],
|
70 |
+
),
|
71 |
+
Action(
|
72 |
+
tool_name="edge_text_to_image",
|
73 |
+
inputs={
|
74 |
+
"edge": "image.edge",
|
75 |
+
"text": "<TOOL-GENERATED>-prompt",
|
76 |
+
},
|
77 |
+
outputs=["<GENERATED>-0"],
|
78 |
+
),
|
79 |
+
]
|
80 |
+
],
|
81 |
+
BUILTIN_GENERATE_SIMILAR_IMAGE: [
|
82 |
+
[
|
83 |
+
Action(
|
84 |
+
tool_name="image_to_edge",
|
85 |
+
inputs={"image": "image"},
|
86 |
+
outputs=["<TOOL-GENERATED>-edge"],
|
87 |
+
),
|
88 |
+
Action(
|
89 |
+
tool_name="image_captioning",
|
90 |
+
inputs={"image": "image"},
|
91 |
+
outputs=["<TOOL-GENERATED>-prompt"],
|
92 |
+
),
|
93 |
+
Action(
|
94 |
+
tool_name="edge_text_to_image",
|
95 |
+
inputs={
|
96 |
+
"edge": "<TOOL-GENERATED>-edge",
|
97 |
+
"text": "<TOOL-GENERATED>-prompt",
|
98 |
+
},
|
99 |
+
outputs=["<GENERATED>-0"],
|
100 |
+
),
|
101 |
+
]
|
102 |
+
],
|
103 |
+
# BUILTIN_GENERATE_SIMILAR_IMAGE2: [
|
104 |
+
# [
|
105 |
+
# Action(
|
106 |
+
# tool_name="image_captioning",
|
107 |
+
# inputs={"image": "image"},
|
108 |
+
# outputs=["<TOOL-GENERATED>-prompt"],
|
109 |
+
# ),
|
110 |
+
# Action(
|
111 |
+
# tool_name="text_to_image",
|
112 |
+
# inputs={"text": "<TOOL-GENERATED>-prompt"},
|
113 |
+
# outputs=["<GENERATED>-0"],
|
114 |
+
# ),
|
115 |
+
# ]
|
116 |
+
# ],
|
117 |
+
# BUILTIN_GENERATE_SIMILAR_IMAGE3: [
|
118 |
+
# [
|
119 |
+
# Action(
|
120 |
+
# tool_name="image_to_image",
|
121 |
+
# inputs={"image": "image"},
|
122 |
+
# outputs=["<GENERATED>-0"],
|
123 |
+
# ),
|
124 |
+
# ]
|
125 |
+
# ],
|
126 |
+
BUILTIN_GENERATE_IMAGE_HED: [
|
127 |
+
[
|
128 |
+
Action(
|
129 |
+
tool_name="image_to_hed",
|
130 |
+
inputs={"image": "image"},
|
131 |
+
outputs=["<TOOL-GENERATED>-image_to_hed-hed-0"],
|
132 |
+
),
|
133 |
+
Action(
|
134 |
+
tool_name="hed_text_to_image",
|
135 |
+
inputs={
|
136 |
+
"text": "beautiful mountains and sunset",
|
137 |
+
"hed": "<TOOL-GENERATED>-image_to_hed-hed-0",
|
138 |
+
},
|
139 |
+
outputs=["<GENERATED>-0"],
|
140 |
+
),
|
141 |
+
]
|
142 |
+
],
|
143 |
+
BUILTIN_GENERATE_IMAGE_DEPTH: [
|
144 |
+
[
|
145 |
+
Action(
|
146 |
+
tool_name="image_captioning",
|
147 |
+
inputs={
|
148 |
+
"image": "image",
|
149 |
+
},
|
150 |
+
outputs=["<TOOL-GENERATED>-image_captioning-text-0"],
|
151 |
+
),
|
152 |
+
Action(
|
153 |
+
tool_name="image_to_depth",
|
154 |
+
inputs={"image": "image"},
|
155 |
+
outputs=["<TOOL-GENERATED>-image_to_depth-depth-0"],
|
156 |
+
),
|
157 |
+
Action(
|
158 |
+
tool_name="depth_text_to_image",
|
159 |
+
inputs={
|
160 |
+
"text": "<TOOL-GENERATED>-image_captioning-text-0",
|
161 |
+
"depth": "<TOOL-GENERATED>-image_to_depth-depth-0",
|
162 |
+
},
|
163 |
+
outputs=["<GENERATED>-0"],
|
164 |
+
),
|
165 |
+
]
|
166 |
+
],
|
167 |
+
BUILTIN_GENERATE_IMAGE_OCR: [
|
168 |
+
[
|
169 |
+
Action(
|
170 |
+
tool_name="optical_character_recognition",
|
171 |
+
inputs={"image": "image"},
|
172 |
+
outputs=["<GENERATED>-0"],
|
173 |
+
)
|
174 |
+
]
|
175 |
+
],
|
176 |
+
BUILTIN_COUNT_OBJECTS: [
|
177 |
+
[
|
178 |
+
Action(
|
179 |
+
tool_name="object_detection",
|
180 |
+
inputs={"image": "image"},
|
181 |
+
outputs=["<TOOL-GENERATED>-object_detection-bbox-0"],
|
182 |
+
),
|
183 |
+
Action(
|
184 |
+
tool_name="select_bbox",
|
185 |
+
inputs={
|
186 |
+
"bbox_list": "<TOOL-GENERATED>-object_detection-bbox-0",
|
187 |
+
"condition": "bear",
|
188 |
+
},
|
189 |
+
outputs=["<TOOL-GENERATED>-select_bbox-bbox-0"],
|
190 |
+
),
|
191 |
+
Action(
|
192 |
+
tool_name="count_objects",
|
193 |
+
inputs={"bbox_list": "<TOOL-GENERATED>-select_bbox-bbox-0"},
|
194 |
+
outputs=["<GENERATED>-0"],
|
195 |
+
),
|
196 |
+
],
|
197 |
+
[
|
198 |
+
Action(
|
199 |
+
tool_name="image_question_answering",
|
200 |
+
inputs={
|
201 |
+
"text": "Provide me with the count of bears in the input image",
|
202 |
+
"image": "image",
|
203 |
+
},
|
204 |
+
outputs=["<GENERATED>-1"],
|
205 |
+
)
|
206 |
+
],
|
207 |
+
],
|
208 |
+
BUILTIN_VIDEO_TO_WEBPAGE: [
|
209 |
+
[
|
210 |
+
Action(
|
211 |
+
tool_name="video_captioning",
|
212 |
+
inputs={"video": "video"},
|
213 |
+
outputs=["<TOOL-GENERATED>-text-0"],
|
214 |
+
),
|
215 |
+
Action(
|
216 |
+
tool_name="text_to_music",
|
217 |
+
inputs={"text": "<TOOL-GENERATED>-text-0"},
|
218 |
+
outputs=["<TOOL-GENERATED>-text_to_music-audio-0"],
|
219 |
+
),
|
220 |
+
Action(
|
221 |
+
tool_name="dub_video",
|
222 |
+
inputs={
|
223 |
+
"video": "video",
|
224 |
+
"audio": "<TOOL-GENERATED>-text_to_music-audio-0",
|
225 |
+
},
|
226 |
+
outputs=["<TOOL-GENERATED>-dub_video-video-0"],
|
227 |
+
),
|
228 |
+
Action(
|
229 |
+
tool_name="title_generation",
|
230 |
+
inputs={"text": "<TOOL-GENERATED>-text-0"},
|
231 |
+
outputs=["<TOOL-GENERATED>-text-1"],
|
232 |
+
),
|
233 |
+
Action(
|
234 |
+
tool_name="text_to_tags",
|
235 |
+
inputs={"text": "<TOOL-GENERATED>-text-0"},
|
236 |
+
outputs=["<TOOL-GENERATED>-tags-0"],
|
237 |
+
),
|
238 |
+
Action(
|
239 |
+
tool_name="video_to_webpage",
|
240 |
+
inputs={
|
241 |
+
"video": "<TOOL-GENERATED>-dub_video-video-0",
|
242 |
+
"title": "<TOOL-GENERATED>-text-1",
|
243 |
+
"tags": "<TOOL-GENERATED>-tags-0",
|
244 |
+
"description": "<TOOL-GENERATED>-text-0",
|
245 |
+
},
|
246 |
+
outputs=["<GENERATED>-0"],
|
247 |
+
),
|
248 |
+
]
|
249 |
+
],
|
250 |
+
BUILTIN_TEXT_TO_MUSIC: [
|
251 |
+
[
|
252 |
+
Action(
|
253 |
+
tool_name="text_to_music",
|
254 |
+
inputs={
|
255 |
+
"text": "An 80s driving pop song with heavy drums and synth pads in the background"
|
256 |
+
},
|
257 |
+
outputs=["<GENERATED>-audio-0"],
|
258 |
+
)
|
259 |
+
]
|
260 |
+
],
|
261 |
+
BUILTIN_IMAGE_ERASING1: [
|
262 |
+
[
|
263 |
+
Action(
|
264 |
+
tool_name="image_instance_segmentation",
|
265 |
+
inputs={"image": "image"},
|
266 |
+
outputs=["<TOOL-GENERATED>-image_instance_segmentation-mask-0"],
|
267 |
+
),
|
268 |
+
Action(
|
269 |
+
tool_name="select_mask",
|
270 |
+
inputs={
|
271 |
+
"mask_list": "<TOOL-GENERATED>-image_instance_segmentation-mask-0",
|
272 |
+
"condition": "wine glass",
|
273 |
+
},
|
274 |
+
outputs=["<TOOL-GENERATED>-select_mask-mask-1"],
|
275 |
+
),
|
276 |
+
Action(
|
277 |
+
tool_name="image_inpainting",
|
278 |
+
inputs={
|
279 |
+
"image": "image",
|
280 |
+
"mask": "<TOOL-GENERATED>-select_mask-mask-0",
|
281 |
+
},
|
282 |
+
outputs=["<GENERATED>-0"],
|
283 |
+
),
|
284 |
+
]
|
285 |
+
],
|
286 |
+
BUILTIN_IMAGE_ERASING2: [
|
287 |
+
[
|
288 |
+
Action(
|
289 |
+
tool_name="image_instance_segmentation",
|
290 |
+
inputs={"image": "image"},
|
291 |
+
outputs=["<TOOL-GENERATED>-image_instance_segmentation-mask-0"],
|
292 |
+
),
|
293 |
+
Action(
|
294 |
+
tool_name="select_mask",
|
295 |
+
inputs={
|
296 |
+
"mask_list": "<TOOL-GENERATED>-image_instance_segmentation-mask-0",
|
297 |
+
"condition": "cat",
|
298 |
+
},
|
299 |
+
outputs=["<TOOL-GENERATED>-select_mask-mask-0"],
|
300 |
+
),
|
301 |
+
Action(
|
302 |
+
tool_name="image_inpainting",
|
303 |
+
inputs={
|
304 |
+
"image": "image",
|
305 |
+
"mask": "<TOOL-GENERATED>-select_mask-mask-0",
|
306 |
+
},
|
307 |
+
outputs=["<GENERATED>-0"],
|
308 |
+
),
|
309 |
+
]
|
310 |
+
],
|
311 |
+
BUILTIN_IMAGE_CROPPING: [
|
312 |
+
[
|
313 |
+
Action(
|
314 |
+
tool_name="object_detection",
|
315 |
+
inputs={"image": "image"},
|
316 |
+
outputs=["<TOOL-GENERATED>-object_detection-bbox-0"],
|
317 |
+
),
|
318 |
+
Action(
|
319 |
+
tool_name="select_bbox",
|
320 |
+
inputs={
|
321 |
+
"bbox_list": "<TOOL-GENERATED>-object_detection-bbox-0",
|
322 |
+
"condition": "cat",
|
323 |
+
},
|
324 |
+
outputs=["<TOOL-GENERATED>-select_bbox-bbox-0"],
|
325 |
+
),
|
326 |
+
Action(
|
327 |
+
tool_name="image_cropping",
|
328 |
+
inputs={
|
329 |
+
"image": "image",
|
330 |
+
"object": "<TOOL-GENERATED>-select_bbox-bbox-0",
|
331 |
+
},
|
332 |
+
outputs=["<GENERATED>-0"],
|
333 |
+
),
|
334 |
+
]
|
335 |
+
],
|
336 |
+
BUILTIN_IMAGE_SEG: [
|
337 |
+
[
|
338 |
+
Action(
|
339 |
+
tool_name="image_instance_segmentation",
|
340 |
+
inputs={"image": "image"},
|
341 |
+
outputs=["<TOOL-GENERATED>-image_instance_segmentation-mask-0"],
|
342 |
+
),
|
343 |
+
Action(
|
344 |
+
tool_name="select_mask",
|
345 |
+
inputs={
|
346 |
+
"mask_list": "<TOOL-GENERATED>-image_instance_segmentation-mask-0",
|
347 |
+
"condition": "elephant",
|
348 |
+
},
|
349 |
+
outputs=["<GENERATED>-0"],
|
350 |
+
),
|
351 |
+
]
|
352 |
+
],
|
353 |
+
BUILTIN_IMAGE_HIGHLIGHT: [
|
354 |
+
[
|
355 |
+
Action(
|
356 |
+
tool_name="object_detection",
|
357 |
+
inputs={"image": "image"},
|
358 |
+
outputs=["<TOOL-GENERATED>-object_detection-bbox-0"],
|
359 |
+
),
|
360 |
+
Action(
|
361 |
+
tool_name="select_bbox",
|
362 |
+
inputs={
|
363 |
+
"bbox_list": "<TOOL-GENERATED>-object_detection-bbox-0",
|
364 |
+
"condition": "elephant",
|
365 |
+
},
|
366 |
+
outputs=["<TOOL-GENERATED>-select_bbox-bbox-0"],
|
367 |
+
),
|
368 |
+
Action(
|
369 |
+
tool_name="highlight_object_on_image",
|
370 |
+
inputs={
|
371 |
+
"image": "image",
|
372 |
+
"bbox": "<TOOL-GENERATED>-select_bbox-bbox-0",
|
373 |
+
},
|
374 |
+
outputs=["<GENERATED>-0"],
|
375 |
+
),
|
376 |
+
]
|
377 |
+
],
|
378 |
+
BUILTIN_TEXT_SPEECH: [
|
379 |
+
[
|
380 |
+
Action(
|
381 |
+
tool_name="text_to_speech",
|
382 |
+
inputs={
|
383 |
+
"text": "Hope is the thing with feathers That perches in the soul, And sings the tune without the words, And never stops at all"
|
384 |
+
},
|
385 |
+
outputs=["<GENERATED>-0"],
|
386 |
+
)
|
387 |
+
]
|
388 |
+
],
|
389 |
+
BUILTIN_DUBBING: [
|
390 |
+
[
|
391 |
+
Action(
|
392 |
+
tool_name="dub_video",
|
393 |
+
inputs={"video": "video", "audio": "audio"},
|
394 |
+
outputs=["<GENERATED>-0"],
|
395 |
+
)
|
396 |
+
]
|
397 |
+
],
|
398 |
+
BUILTIN_GENERATE_SIMILAR_IMAGE4: [
|
399 |
+
[
|
400 |
+
Action(
|
401 |
+
tool_name="segment_anything",
|
402 |
+
inputs={"image": "image"},
|
403 |
+
outputs=["<TOOL-GENERATED>-seg"],
|
404 |
+
),
|
405 |
+
Action(
|
406 |
+
tool_name="image_captioning",
|
407 |
+
inputs={"image": "image"},
|
408 |
+
outputs=["<TOOL-GENERATED>-prompt"],
|
409 |
+
),
|
410 |
+
Action(
|
411 |
+
tool_name="segmentation_text_to_image",
|
412 |
+
inputs={
|
413 |
+
"segmentation": "<TOOL-GENERATED>-seg",
|
414 |
+
"text": "<TOOL-GENERATED>-prompt",
|
415 |
+
},
|
416 |
+
outputs=["<GENERATED>-0"],
|
417 |
+
),
|
418 |
+
]
|
419 |
+
],
|
420 |
+
BUILTIN_GENERATE_IMAGE: [
|
421 |
+
[
|
422 |
+
Action(
|
423 |
+
tool_name="text_to_image",
|
424 |
+
inputs={"text": "a woman is skiing"},
|
425 |
+
outputs=["<GENERATED>-0"],
|
426 |
+
)
|
427 |
+
]
|
428 |
+
],
|
429 |
+
BUILTIN_IMAGE_TO_VIDEO: [
|
430 |
+
[
|
431 |
+
Action(
|
432 |
+
tool_name="image_to_video",
|
433 |
+
inputs={"image": "image"},
|
434 |
+
outputs=["<GENERATED>-0"],
|
435 |
+
)
|
436 |
+
]
|
437 |
+
],
|
438 |
+
BUILTIN_COUNT_OBJECTS2: [
|
439 |
+
[
|
440 |
+
Action(
|
441 |
+
tool_name="object_detection",
|
442 |
+
inputs={"image": "image"},
|
443 |
+
outputs=["<TOOL-GENERATED>-object_detection-bbox-0"],
|
444 |
+
),
|
445 |
+
Action(
|
446 |
+
tool_name="select_bbox",
|
447 |
+
inputs={
|
448 |
+
"bbox_list": "<TOOL-GENERATED>-object_detection-bbox-0",
|
449 |
+
"condition": "horse",
|
450 |
+
},
|
451 |
+
outputs=["<TOOL-GENERATED>-select_bbox-bbox-0"],
|
452 |
+
),
|
453 |
+
Action(
|
454 |
+
tool_name="count_objects",
|
455 |
+
inputs={"bbox_list": "<TOOL-GENERATED>-select_bbox-bbox-0"},
|
456 |
+
outputs=["<GENERATED>-0"],
|
457 |
+
),
|
458 |
+
],
|
459 |
+
[
|
460 |
+
Action(
|
461 |
+
tool_name="image_question_answering",
|
462 |
+
inputs={
|
463 |
+
"text": "Provide me with the count of horses in the input image",
|
464 |
+
"image": "image",
|
465 |
+
},
|
466 |
+
outputs=["<GENERATED>-1"],
|
467 |
+
)
|
468 |
+
],
|
469 |
+
],
|
470 |
+
BUILTIN_IMAGE_TO_VIDEO2: [
|
471 |
+
[
|
472 |
+
Action(
|
473 |
+
tool_name="text_to_image",
|
474 |
+
inputs={
|
475 |
+
"text": "A serene and beautiful landscape with a calm lake reflecting the blue sky and white clouds."
|
476 |
+
},
|
477 |
+
outputs=["<GENERATED>-0"],
|
478 |
+
),
|
479 |
+
],
|
480 |
+
[
|
481 |
+
Action(
|
482 |
+
tool_name="image_captioning",
|
483 |
+
inputs={"image": "<GENERATED>-0"},
|
484 |
+
outputs=["<TOOL-GENERATED>-text-0"],
|
485 |
+
),
|
486 |
+
Action(
|
487 |
+
tool_name="text_to_speech",
|
488 |
+
inputs={"text": "<TOOL-GENERATED>-text-0"},
|
489 |
+
outputs=["<TOOL-GENERATED>-text_to_speech-audio-0"],
|
490 |
+
),
|
491 |
+
Action(
|
492 |
+
tool_name="image_audio_to_video",
|
493 |
+
inputs={
|
494 |
+
"image": "<GENERATED>-0",
|
495 |
+
"audio": "<TOOL-GENERATED>-text_to_speech-audio-0",
|
496 |
+
},
|
497 |
+
outputs=["<GENERATED>-1"],
|
498 |
+
),
|
499 |
+
],
|
500 |
+
],
|
501 |
+
BUILTIN_IMAGE_TO_VIDEO3: [
|
502 |
+
[
|
503 |
+
Action(
|
504 |
+
tool_name="text_to_image",
|
505 |
+
inputs={
|
506 |
+
"text": "A serene and beautiful landscape with a calm lake reflecting the blue sky."
|
507 |
+
},
|
508 |
+
outputs=["<GENERATED>-0"],
|
509 |
+
),
|
510 |
+
],
|
511 |
+
[
|
512 |
+
Action(
|
513 |
+
tool_name="image_captioning",
|
514 |
+
inputs={"image": "<GENERATED>-0"},
|
515 |
+
outputs=["<TOOL-GENERATED>-text-0"],
|
516 |
+
),
|
517 |
+
Action(
|
518 |
+
tool_name="text_to_music",
|
519 |
+
inputs={"text": "<TOOL-GENERATED>-text-0"},
|
520 |
+
outputs=["<GENERATED>-1"],
|
521 |
+
),
|
522 |
+
],
|
523 |
+
[
|
524 |
+
Action(
|
525 |
+
tool_name="image_to_video",
|
526 |
+
inputs={
|
527 |
+
"image": "<GENERATED>-0",
|
528 |
+
},
|
529 |
+
outputs=["<TOOL-GENERATED>-image_to_video-video-0"],
|
530 |
+
),
|
531 |
+
Action(
|
532 |
+
tool_name="dub_video",
|
533 |
+
inputs={
|
534 |
+
"video": "<TOOL-GENERATED>-image_to_video-video-0",
|
535 |
+
"audio": "<GENERATED>-1",
|
536 |
+
},
|
537 |
+
outputs=["<GENERATED>-2"],
|
538 |
+
),
|
539 |
+
],
|
540 |
+
],
|
541 |
+
BUILTIN_VIDEO_CLS: [
|
542 |
+
[
|
543 |
+
Action(
|
544 |
+
tool_name="video_classification",
|
545 |
+
inputs={"video": "video"},
|
546 |
+
outputs=["<GENERATED>-0"],
|
547 |
+
)
|
548 |
+
]
|
549 |
+
],
|
550 |
+
BUILTIN_AUDIO_CLS: [
|
551 |
+
[
|
552 |
+
Action(
|
553 |
+
tool_name="audio_classification",
|
554 |
+
inputs={"audio": "audio"},
|
555 |
+
outputs=["<GENERATED>-0"],
|
556 |
+
)
|
557 |
+
]
|
558 |
+
],
|
559 |
+
BUILTIN_IMAGE2MUSIC: [
|
560 |
+
[
|
561 |
+
Action(
|
562 |
+
tool_name="image_captioning",
|
563 |
+
inputs={"image": "image"},
|
564 |
+
outputs=["<TOOL-GENERATED>-text-0"],
|
565 |
+
),
|
566 |
+
Action(
|
567 |
+
tool_name="text_to_music",
|
568 |
+
inputs={"text": "<TOOL-GENERATED>-text-0"},
|
569 |
+
outputs=["<GENERATED>-0"],
|
570 |
+
),
|
571 |
+
]
|
572 |
+
],
|
573 |
+
BUILTIN_VIDEO2MUSIC: [
|
574 |
+
[
|
575 |
+
Action(
|
576 |
+
tool_name="video_captioning",
|
577 |
+
inputs={"video": "video"},
|
578 |
+
outputs=["<TOOL-GENERATED>-text-0"],
|
579 |
+
),
|
580 |
+
Action(
|
581 |
+
tool_name="text_to_music",
|
582 |
+
inputs={"text": "<TOOL-GENERATED>-text-0"},
|
583 |
+
outputs=["<GENERATED>-0"],
|
584 |
+
),
|
585 |
+
],
|
586 |
+
[
|
587 |
+
Action(
|
588 |
+
tool_name="dub_video",
|
589 |
+
inputs={
|
590 |
+
"video": "video",
|
591 |
+
"audio": "<GENERATED>-0",
|
592 |
+
},
|
593 |
+
outputs=["<GENERATED>-1"],
|
594 |
+
),
|
595 |
+
],
|
596 |
+
],
|
597 |
+
BUILTIN_SEG_BY_POINTS: [
|
598 |
+
[
|
599 |
+
Action(
|
600 |
+
tool_name="image_segmentation_by_points",
|
601 |
+
inputs={"image": "image", "prompt_points": "prompt_points"},
|
602 |
+
outputs=["<GENERATED>-0"],
|
603 |
+
)
|
604 |
+
]
|
605 |
+
],
|
606 |
+
# BUILTIN_SEG_BY_MASK: [
|
607 |
+
# [
|
608 |
+
# Action(
|
609 |
+
# tool_name='image_segmentation_by_mask',
|
610 |
+
# inputs={'image': 'image', 'prompt_mask': 'prompt_mask'},
|
611 |
+
# outputs=['<GENERATED>-0'],
|
612 |
+
# )
|
613 |
+
# ]
|
614 |
+
# ],
|
615 |
+
}
|
616 |
+
|
617 |
+
|
618 |
+
def load_builtin_plans(path):
|
619 |
+
import json
|
620 |
+
|
621 |
+
plans = json.load(open(path, "r"))
|
622 |
+
processed_plan = {}
|
623 |
+
for query, actions in plans.items():
|
624 |
+
actions2 = []
|
625 |
+
for ac in actions[0]:
|
626 |
+
actions2.append(
|
627 |
+
Action(
|
628 |
+
tool_name=ac["tool_name"],
|
629 |
+
inputs=ac["inputs"],
|
630 |
+
outputs=ac["outputs"],
|
631 |
+
),
|
632 |
+
)
|
633 |
+
processed_plan[query] = [actions2]
|
634 |
+
return processed_plan
|
cllm/agents/builtin/tools.py
ADDED
@@ -0,0 +1,1512 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
|
4 |
+
sys.path.append(os.getcwd())
|
5 |
+
from cllm.services.image_editing.api import (
|
6 |
+
inpainting_ldm,
|
7 |
+
inpainting_ldm_general,
|
8 |
+
partial_image_editing,
|
9 |
+
instruct_pix2pix,
|
10 |
+
image_cropping,
|
11 |
+
image_matting,
|
12 |
+
draw_bbox_on_image,
|
13 |
+
)
|
14 |
+
from cllm.services.image_generation.api import (
|
15 |
+
text2image,
|
16 |
+
image2image,
|
17 |
+
cannytext2image,
|
18 |
+
linetext2image,
|
19 |
+
hedtext2image,
|
20 |
+
scribbletext2image,
|
21 |
+
posetext2image,
|
22 |
+
segtext2image,
|
23 |
+
depthtext2image,
|
24 |
+
normaltext2image,
|
25 |
+
)
|
26 |
+
|
27 |
+
from cllm.services.image_processing.api import (
|
28 |
+
image2canny,
|
29 |
+
image2line,
|
30 |
+
image2hed,
|
31 |
+
image2scribble,
|
32 |
+
image2pose,
|
33 |
+
image2depth,
|
34 |
+
image2normal,
|
35 |
+
)
|
36 |
+
from cllm.services.image_perception.api import (
|
37 |
+
object_detection,
|
38 |
+
image_classification,
|
39 |
+
ocr,
|
40 |
+
segment_objects,
|
41 |
+
visual_grounding,
|
42 |
+
image_captioning,
|
43 |
+
segment_all,
|
44 |
+
seg_by_mask,
|
45 |
+
seg_by_points,
|
46 |
+
)
|
47 |
+
from cllm.services.video.api import (
|
48 |
+
video_classification,
|
49 |
+
video_captioning,
|
50 |
+
image_audio_to_video,
|
51 |
+
video_to_webpage,
|
52 |
+
dub_video,
|
53 |
+
image_to_video,
|
54 |
+
text_to_video,
|
55 |
+
)
|
56 |
+
from cllm.services.audio.api import (
|
57 |
+
text_to_music,
|
58 |
+
text_to_speech,
|
59 |
+
audio_classification,
|
60 |
+
)
|
61 |
+
|
62 |
+
# from cllm.services.sam.api import (
|
63 |
+
# segment_by_mask,
|
64 |
+
# segment_by_points,
|
65 |
+
# set_image,
|
66 |
+
# segment_all,
|
67 |
+
# )
|
68 |
+
from cllm.services.general.api import (
|
69 |
+
select,
|
70 |
+
count,
|
71 |
+
remote_logging,
|
72 |
+
)
|
73 |
+
from cllm.services.nlp.api import (
|
74 |
+
text_to_text_generation,
|
75 |
+
title_generation,
|
76 |
+
text_to_tags,
|
77 |
+
question_answering_with_context,
|
78 |
+
openai_chat_model,
|
79 |
+
summarization,
|
80 |
+
extract_location,
|
81 |
+
sentiment_analysis,
|
82 |
+
get_weather,
|
83 |
+
summarize_weather_condition,
|
84 |
+
get_time,
|
85 |
+
)
|
86 |
+
from cllm.services.vqa.api import image_qa
|
87 |
+
from cllm.agents.base import Tool, DataType
|
88 |
+
|
89 |
+
|
90 |
+
QUESTION_ANSWERING_TOOLS = [
|
91 |
+
Tool(
|
92 |
+
name="image_question_answering",
|
93 |
+
description="answers a question about an image",
|
94 |
+
domain=Tool.Domain.VISUAL_QUESTION_ANSWERING,
|
95 |
+
args=[
|
96 |
+
Tool.Argument(
|
97 |
+
name="image",
|
98 |
+
type=DataType.IMAGE,
|
99 |
+
description="the image containing the information",
|
100 |
+
),
|
101 |
+
Tool.Argument(
|
102 |
+
name="text",
|
103 |
+
type=DataType.TEXT,
|
104 |
+
description="the question about the image",
|
105 |
+
),
|
106 |
+
],
|
107 |
+
returns=[
|
108 |
+
Tool.Argument(
|
109 |
+
name="response",
|
110 |
+
type=DataType.TEXT,
|
111 |
+
description="output response",
|
112 |
+
)
|
113 |
+
],
|
114 |
+
model=image_qa,
|
115 |
+
),
|
116 |
+
Tool(
|
117 |
+
name="get_weather",
|
118 |
+
description="Query the weather conditions by given location. For example: what is the weather in Beijing? how cold is in New York? etc.",
|
119 |
+
domain=Tool.Domain.QUESTION_ANSWERING,
|
120 |
+
args=[
|
121 |
+
Tool.Argument(
|
122 |
+
name="location",
|
123 |
+
type=DataType.LOCATION,
|
124 |
+
description="the location where the weather is to be queried",
|
125 |
+
),
|
126 |
+
],
|
127 |
+
returns=[
|
128 |
+
Tool.Argument(
|
129 |
+
name="result",
|
130 |
+
# type=DataType.WEATHER,
|
131 |
+
type=DataType.WEATHER,
|
132 |
+
description="weather information",
|
133 |
+
)
|
134 |
+
],
|
135 |
+
model=get_weather,
|
136 |
+
),
|
137 |
+
Tool(
|
138 |
+
name="get_time",
|
139 |
+
description="get current date",
|
140 |
+
domain=Tool.Domain.QUESTION_ANSWERING,
|
141 |
+
args=[
|
142 |
+
# Tool.Argument(
|
143 |
+
# name="location",
|
144 |
+
# type=DataType.LOCATION,
|
145 |
+
# description="location where the time is to be queried",
|
146 |
+
# ),
|
147 |
+
Tool.Argument(
|
148 |
+
name="text",
|
149 |
+
type=DataType.TEXT,
|
150 |
+
description="input text",
|
151 |
+
),
|
152 |
+
],
|
153 |
+
returns=[
|
154 |
+
Tool.Argument(
|
155 |
+
name="response",
|
156 |
+
type=DataType.TIME,
|
157 |
+
description="output time",
|
158 |
+
)
|
159 |
+
],
|
160 |
+
model=get_time,
|
161 |
+
),
|
162 |
+
# Tool(
|
163 |
+
# name="calculator",
|
164 |
+
# description="It can solve mathematics problems and support various mathematical expressions: from basic arithmetic to more complex expressions.",
|
165 |
+
# domain=Tool.Domain.QUESTION_ANSWERING,
|
166 |
+
# args=[
|
167 |
+
# Tool.Argument(
|
168 |
+
# name="text",
|
169 |
+
# type=DataType.TEXT,
|
170 |
+
# description="input instructions",
|
171 |
+
# ),
|
172 |
+
# ],
|
173 |
+
# returns=[
|
174 |
+
# Tool.Argument(
|
175 |
+
# name="result",
|
176 |
+
# type=DataType.TEXT,
|
177 |
+
# description="result about weather",
|
178 |
+
# )
|
179 |
+
# ],
|
180 |
+
# model=None,
|
181 |
+
# ),
|
182 |
+
]
|
183 |
+
|
184 |
+
IMAGE_CAPTIONING_TOOLS = [
|
185 |
+
Tool(
|
186 |
+
name="image_captioning",
|
187 |
+
description='Generate a caption or description for the image. It can generate a detailed description that can be used for image perception and image generation. For example: a) you can use this tool when you want to know what is it in the image"; and b) when you want to generate a new image similar or resemble to input.png, you can use `image_captioning` to obtain the description about image input.png.',
|
188 |
+
domain=Tool.Domain.IMAGE_PERCEPTION,
|
189 |
+
args=[
|
190 |
+
Tool.Argument(
|
191 |
+
name="image",
|
192 |
+
type=DataType.IMAGE,
|
193 |
+
description="the image to be captioned",
|
194 |
+
),
|
195 |
+
],
|
196 |
+
returns=[
|
197 |
+
Tool.Argument(
|
198 |
+
name="text",
|
199 |
+
type=DataType.TEXT,
|
200 |
+
description="the description for the input image",
|
201 |
+
)
|
202 |
+
],
|
203 |
+
model=image_captioning,
|
204 |
+
),
|
205 |
+
]
|
206 |
+
|
207 |
+
IMAGE_EDITING_TOOLS = [
|
208 |
+
Tool(
|
209 |
+
name="partial_image_editing",
|
210 |
+
description="Given the mask denoting the region to edit, Edit the given image at local region. Useful when you want to replace an object via a mask image. "
|
211 |
+
"like: replace the masked object with a dog. ",
|
212 |
+
domain=Tool.Domain.IMAGE_EDITING,
|
213 |
+
args=[
|
214 |
+
Tool.Argument(
|
215 |
+
name="image",
|
216 |
+
type=DataType.IMAGE,
|
217 |
+
description="the image to be edited",
|
218 |
+
),
|
219 |
+
Tool.Argument(
|
220 |
+
name="mask",
|
221 |
+
type=DataType.MASK,
|
222 |
+
description="the mask image representing the editing position",
|
223 |
+
),
|
224 |
+
Tool.Argument(
|
225 |
+
name="prompt",
|
226 |
+
type=DataType.TEXT,
|
227 |
+
description="the prompt specified the edition",
|
228 |
+
),
|
229 |
+
],
|
230 |
+
returns=[
|
231 |
+
Tool.Argument(
|
232 |
+
name="image",
|
233 |
+
type=DataType.IMAGE,
|
234 |
+
description="the edited image",
|
235 |
+
)
|
236 |
+
],
|
237 |
+
model=partial_image_editing,
|
238 |
+
),
|
239 |
+
Tool(
|
240 |
+
name="text_image_editing",
|
241 |
+
description="Edit the given image based on the text prompt.",
|
242 |
+
domain=Tool.Domain.IMAGE_EDITING,
|
243 |
+
args=[
|
244 |
+
Tool.Argument(
|
245 |
+
name="image",
|
246 |
+
type=DataType.IMAGE,
|
247 |
+
description="the image to be edited",
|
248 |
+
),
|
249 |
+
Tool.Argument(
|
250 |
+
name="text",
|
251 |
+
type=DataType.TEXT,
|
252 |
+
description="the prompt specified the edition",
|
253 |
+
),
|
254 |
+
],
|
255 |
+
returns=[
|
256 |
+
Tool.Argument(
|
257 |
+
name="image",
|
258 |
+
type=DataType.IMAGE,
|
259 |
+
description="the edited image",
|
260 |
+
)
|
261 |
+
],
|
262 |
+
model=instruct_pix2pix,
|
263 |
+
),
|
264 |
+
Tool(
|
265 |
+
name="image_inpainting",
|
266 |
+
description="inpaint the region of the image based on the given mask. For example: remove the dog in the image, erase the spoon in given image, etc.",
|
267 |
+
domain=Tool.Domain.IMAGE_EDITING,
|
268 |
+
usages=["remove some objects"],
|
269 |
+
args=[
|
270 |
+
Tool.Argument(
|
271 |
+
name="image",
|
272 |
+
type=DataType.IMAGE,
|
273 |
+
description="the image to be inpainted",
|
274 |
+
),
|
275 |
+
Tool.Argument(
|
276 |
+
name="mask",
|
277 |
+
type=DataType.MASK,
|
278 |
+
description="the segmentation mask for the inpainting region",
|
279 |
+
),
|
280 |
+
],
|
281 |
+
returns=[
|
282 |
+
Tool.Argument(
|
283 |
+
name="image",
|
284 |
+
type=DataType.IMAGE,
|
285 |
+
description="the processed image",
|
286 |
+
)
|
287 |
+
],
|
288 |
+
model=inpainting_ldm_general,
|
289 |
+
),
|
290 |
+
Tool(
|
291 |
+
name="highlight_object_on_image",
|
292 |
+
description="This tool is usually used after `object_detection` `visual_grounding` and `select_bbox`. Useful when you want to: 1) highlight the region of interest on the image; 2) know where the object is. For example: highlight the elephant from image, locate the dog in the image, find the spoon in given image, detect if the object is present in the image, etc.",
|
293 |
+
domain=Tool.Domain.IMAGE_EDITING,
|
294 |
+
usages=["highlight the region of interest on the image"],
|
295 |
+
args=[
|
296 |
+
Tool.Argument(
|
297 |
+
name="image",
|
298 |
+
type=DataType.IMAGE,
|
299 |
+
description="the image to be processed",
|
300 |
+
),
|
301 |
+
Tool.Argument(
|
302 |
+
name="bbox",
|
303 |
+
type=DataType.BBOX,
|
304 |
+
description="the bounding boxes that need to be drawn on the image",
|
305 |
+
),
|
306 |
+
],
|
307 |
+
returns=[
|
308 |
+
Tool.Argument(
|
309 |
+
name="result",
|
310 |
+
type=DataType.IMAGE,
|
311 |
+
description="the new image on which the tool highlight the the region of interest by bounding boxes",
|
312 |
+
)
|
313 |
+
],
|
314 |
+
model=draw_bbox_on_image,
|
315 |
+
),
|
316 |
+
Tool(
|
317 |
+
name="image_cropping",
|
318 |
+
description="Crop the image based on the given bounding box. Useful when you want to crop the dog in the image, crop the spoon in given image, etc.",
|
319 |
+
domain=Tool.Domain.IMAGE_EDITING,
|
320 |
+
args=[
|
321 |
+
Tool.Argument(
|
322 |
+
name="image",
|
323 |
+
type=DataType.IMAGE,
|
324 |
+
description="the image to be processed",
|
325 |
+
),
|
326 |
+
Tool.Argument(
|
327 |
+
name="object",
|
328 |
+
type=DataType.BBOX,
|
329 |
+
description="the detected object",
|
330 |
+
),
|
331 |
+
],
|
332 |
+
returns=[
|
333 |
+
Tool.Argument(
|
334 |
+
name="image",
|
335 |
+
type=DataType.IMAGE,
|
336 |
+
description="the cropped image",
|
337 |
+
)
|
338 |
+
],
|
339 |
+
model=image_cropping,
|
340 |
+
),
|
341 |
+
# Tool(
|
342 |
+
# name="mask_image",
|
343 |
+
# description="Mask the background from the image based on the given mask. For example: mask anything except the dog in the image, extract the spoon from given image without any inpainting, etc.",
|
344 |
+
# domain=Tool.Domain.IMAGE_EDITING,
|
345 |
+
# args=[
|
346 |
+
# Tool.Argument(
|
347 |
+
# name="image",
|
348 |
+
# type=DataType.IMAGE,
|
349 |
+
# description="the image to be processed",
|
350 |
+
# ),
|
351 |
+
# Tool.Argument(
|
352 |
+
# name="mask",
|
353 |
+
# type=DataType.MASK,
|
354 |
+
# description="the mask of the matted region",
|
355 |
+
# ),
|
356 |
+
# ],
|
357 |
+
# returns=[
|
358 |
+
# Tool.Argument(
|
359 |
+
# name="image",
|
360 |
+
# type=DataType.IMAGE,
|
361 |
+
# description="the matted image",
|
362 |
+
# )
|
363 |
+
# ],
|
364 |
+
# model=image_matting,
|
365 |
+
# ),
|
366 |
+
]
|
367 |
+
|
368 |
+
IMAGE_GENERATION_TOOLS = [
|
369 |
+
Tool(
|
370 |
+
name="text_to_image",
|
371 |
+
description="generate an image based on the given description.",
|
372 |
+
domain=Tool.Domain.IMAGE_GENERATION,
|
373 |
+
args=[
|
374 |
+
Tool.Argument(
|
375 |
+
name="text",
|
376 |
+
type=DataType.TEXT,
|
377 |
+
description="the text describing the image",
|
378 |
+
),
|
379 |
+
],
|
380 |
+
returns=[
|
381 |
+
Tool.Argument(
|
382 |
+
name="image",
|
383 |
+
type=DataType.IMAGE,
|
384 |
+
description="the generated image",
|
385 |
+
)
|
386 |
+
],
|
387 |
+
model=text2image,
|
388 |
+
),
|
389 |
+
Tool(
|
390 |
+
name="image_to_image",
|
391 |
+
description="generate an new image based on the given image.",
|
392 |
+
domain=Tool.Domain.IMAGE_GENERATION,
|
393 |
+
args=[
|
394 |
+
Tool.Argument(
|
395 |
+
name="image",
|
396 |
+
type=DataType.IMAGE,
|
397 |
+
description="the given image",
|
398 |
+
),
|
399 |
+
],
|
400 |
+
returns=[
|
401 |
+
Tool.Argument(
|
402 |
+
name="image",
|
403 |
+
type=DataType.IMAGE,
|
404 |
+
description="the generated image",
|
405 |
+
)
|
406 |
+
],
|
407 |
+
model=image2image,
|
408 |
+
),
|
409 |
+
Tool(
|
410 |
+
name="line_text_to_image",
|
411 |
+
description="generate an image based on the given description and line map.",
|
412 |
+
domain=Tool.Domain.IMAGE_GENERATION,
|
413 |
+
args=[
|
414 |
+
Tool.Argument(
|
415 |
+
name="text",
|
416 |
+
type=DataType.TEXT,
|
417 |
+
description="the text describing the image",
|
418 |
+
),
|
419 |
+
Tool.Argument(
|
420 |
+
name="line",
|
421 |
+
type=DataType.LINE,
|
422 |
+
description="the line map outlining the line of the image",
|
423 |
+
),
|
424 |
+
],
|
425 |
+
returns=[
|
426 |
+
Tool.Argument(
|
427 |
+
name="image",
|
428 |
+
type=DataType.IMAGE,
|
429 |
+
description="the generated image",
|
430 |
+
)
|
431 |
+
],
|
432 |
+
model=linetext2image,
|
433 |
+
),
|
434 |
+
Tool(
|
435 |
+
name="hed_text_to_image",
|
436 |
+
description="generate an image based on the given description and HED map (holistically-nested edge detection).",
|
437 |
+
domain=Tool.Domain.IMAGE_GENERATION,
|
438 |
+
args=[
|
439 |
+
Tool.Argument(
|
440 |
+
name="text",
|
441 |
+
type=DataType.TEXT,
|
442 |
+
description="the text describing the image",
|
443 |
+
),
|
444 |
+
Tool.Argument(
|
445 |
+
name="hed",
|
446 |
+
type=DataType.HED,
|
447 |
+
description="the HED map outlining the edge of the image",
|
448 |
+
),
|
449 |
+
],
|
450 |
+
returns=[
|
451 |
+
Tool.Argument(
|
452 |
+
name="image",
|
453 |
+
type=DataType.IMAGE,
|
454 |
+
description="the generated image",
|
455 |
+
)
|
456 |
+
],
|
457 |
+
model=hedtext2image,
|
458 |
+
),
|
459 |
+
Tool(
|
460 |
+
name="scribble_text_to_image",
|
461 |
+
description="generate an image based on the given description and the scribble.",
|
462 |
+
domain=Tool.Domain.IMAGE_GENERATION,
|
463 |
+
args=[
|
464 |
+
Tool.Argument(
|
465 |
+
name="text",
|
466 |
+
type=DataType.TEXT,
|
467 |
+
description="the text describing the image",
|
468 |
+
),
|
469 |
+
Tool.Argument(
|
470 |
+
name="scribble",
|
471 |
+
type=DataType.SCRIBBLE,
|
472 |
+
description="the scribble outlining the image",
|
473 |
+
),
|
474 |
+
],
|
475 |
+
returns=[
|
476 |
+
Tool.Argument(
|
477 |
+
name="image",
|
478 |
+
type=DataType.IMAGE,
|
479 |
+
description="the generated image",
|
480 |
+
)
|
481 |
+
],
|
482 |
+
model=scribbletext2image,
|
483 |
+
),
|
484 |
+
Tool(
|
485 |
+
name="pose_text_to_image",
|
486 |
+
description="generate an image based on the given description and the pose.",
|
487 |
+
domain=Tool.Domain.IMAGE_GENERATION,
|
488 |
+
args=[
|
489 |
+
Tool.Argument(
|
490 |
+
name="text",
|
491 |
+
type=DataType.TEXT,
|
492 |
+
description="the text describing the image",
|
493 |
+
),
|
494 |
+
Tool.Argument(
|
495 |
+
name="pose",
|
496 |
+
type=DataType.POSE,
|
497 |
+
description="the pose of the human in the image",
|
498 |
+
),
|
499 |
+
],
|
500 |
+
returns=[
|
501 |
+
Tool.Argument(
|
502 |
+
name="image",
|
503 |
+
type=DataType.IMAGE,
|
504 |
+
description="the generated image",
|
505 |
+
)
|
506 |
+
],
|
507 |
+
model=posetext2image,
|
508 |
+
),
|
509 |
+
Tool(
|
510 |
+
name="segmentation_text_to_image",
|
511 |
+
description="generate an image based on the given description and segmentation mask.",
|
512 |
+
domain=Tool.Domain.IMAGE_GENERATION,
|
513 |
+
args=[
|
514 |
+
Tool.Argument(
|
515 |
+
name="text",
|
516 |
+
type=DataType.TEXT,
|
517 |
+
description="the text describing the image",
|
518 |
+
),
|
519 |
+
Tool.Argument(
|
520 |
+
name="segmentation",
|
521 |
+
type=DataType.SEGMENTATION,
|
522 |
+
description="the segmentation mask describing the structure of the image",
|
523 |
+
),
|
524 |
+
],
|
525 |
+
returns=[
|
526 |
+
Tool.Argument(
|
527 |
+
name="image",
|
528 |
+
type=DataType.IMAGE,
|
529 |
+
description="the generated image",
|
530 |
+
)
|
531 |
+
],
|
532 |
+
model=segtext2image,
|
533 |
+
),
|
534 |
+
Tool(
|
535 |
+
name="edge_text_to_image",
|
536 |
+
description="generate an image based on the given description and edge map.",
|
537 |
+
domain=Tool.Domain.IMAGE_GENERATION,
|
538 |
+
args=[
|
539 |
+
Tool.Argument(
|
540 |
+
name="text",
|
541 |
+
type=DataType.TEXT,
|
542 |
+
description="the text describing the image",
|
543 |
+
),
|
544 |
+
Tool.Argument(
|
545 |
+
name="edge",
|
546 |
+
type=DataType.EDGE,
|
547 |
+
description="the edge map describing the structure of the image",
|
548 |
+
),
|
549 |
+
],
|
550 |
+
returns=[
|
551 |
+
Tool.Argument(
|
552 |
+
name="image",
|
553 |
+
type=DataType.IMAGE,
|
554 |
+
description="the generated image",
|
555 |
+
)
|
556 |
+
],
|
557 |
+
model=cannytext2image,
|
558 |
+
),
|
559 |
+
Tool(
|
560 |
+
name="depth_text_to_image",
|
561 |
+
description="generate an image based on the given description and depth map.",
|
562 |
+
domain=Tool.Domain.IMAGE_GENERATION,
|
563 |
+
args=[
|
564 |
+
Tool.Argument(
|
565 |
+
name="text",
|
566 |
+
type=DataType.TEXT,
|
567 |
+
description="the text describing the image",
|
568 |
+
),
|
569 |
+
Tool.Argument(
|
570 |
+
name="depth",
|
571 |
+
type=DataType.DEPTH,
|
572 |
+
description="the depth map describing the structure of the image",
|
573 |
+
),
|
574 |
+
],
|
575 |
+
returns=[
|
576 |
+
Tool.Argument(
|
577 |
+
name="image",
|
578 |
+
type=DataType.IMAGE,
|
579 |
+
description="the generated image",
|
580 |
+
)
|
581 |
+
],
|
582 |
+
model=depthtext2image,
|
583 |
+
),
|
584 |
+
Tool(
|
585 |
+
name="normal_text_to_image",
|
586 |
+
description="generate an image based on the given description and normal map.",
|
587 |
+
domain=Tool.Domain.IMAGE_GENERATION,
|
588 |
+
args=[
|
589 |
+
Tool.Argument(
|
590 |
+
name="text",
|
591 |
+
type=DataType.TEXT,
|
592 |
+
description="the text describing the image",
|
593 |
+
),
|
594 |
+
Tool.Argument(
|
595 |
+
name="normal",
|
596 |
+
type=DataType.NORMAL,
|
597 |
+
description="the normal map describing the structure of the image",
|
598 |
+
),
|
599 |
+
],
|
600 |
+
returns=[
|
601 |
+
Tool.Argument(
|
602 |
+
name="image",
|
603 |
+
type=DataType.IMAGE,
|
604 |
+
description="the generated image",
|
605 |
+
)
|
606 |
+
],
|
607 |
+
model=normaltext2image,
|
608 |
+
),
|
609 |
+
]
|
610 |
+
|
611 |
+
IMAGE_TRANSFORM_TOOLS = [
|
612 |
+
Tool(
|
613 |
+
name="image_to_edge",
|
614 |
+
description="get the edge map of the image.",
|
615 |
+
domain=Tool.Domain.IMAGE_PROCESSING,
|
616 |
+
args=[
|
617 |
+
Tool.Argument(
|
618 |
+
name="image",
|
619 |
+
type=DataType.IMAGE,
|
620 |
+
description="the image to be processed",
|
621 |
+
),
|
622 |
+
],
|
623 |
+
returns=[
|
624 |
+
Tool.Argument(
|
625 |
+
name="edge",
|
626 |
+
type=DataType.EDGE,
|
627 |
+
description="the edge map of the image",
|
628 |
+
)
|
629 |
+
],
|
630 |
+
model=image2canny,
|
631 |
+
),
|
632 |
+
Tool(
|
633 |
+
name="image_to_line",
|
634 |
+
description="get the line map of the image.",
|
635 |
+
domain=Tool.Domain.IMAGE_PROCESSING,
|
636 |
+
args=[
|
637 |
+
Tool.Argument(
|
638 |
+
name="image",
|
639 |
+
type=DataType.IMAGE,
|
640 |
+
description="the image to be processed",
|
641 |
+
),
|
642 |
+
],
|
643 |
+
returns=[
|
644 |
+
Tool.Argument(
|
645 |
+
name="line",
|
646 |
+
type=DataType.LINE,
|
647 |
+
description="the line map of the image",
|
648 |
+
)
|
649 |
+
],
|
650 |
+
model=image2line,
|
651 |
+
),
|
652 |
+
Tool(
|
653 |
+
name="image_to_hed",
|
654 |
+
description="get the HED map of the image.",
|
655 |
+
domain=Tool.Domain.IMAGE_PROCESSING,
|
656 |
+
args=[
|
657 |
+
Tool.Argument(
|
658 |
+
name="image",
|
659 |
+
type=DataType.IMAGE,
|
660 |
+
description="the image to be processed",
|
661 |
+
),
|
662 |
+
],
|
663 |
+
returns=[
|
664 |
+
Tool.Argument(
|
665 |
+
name="hed",
|
666 |
+
type=DataType.HED,
|
667 |
+
description="the hed map of the image",
|
668 |
+
)
|
669 |
+
],
|
670 |
+
model=image2hed,
|
671 |
+
),
|
672 |
+
Tool(
|
673 |
+
name="image_to_scribble",
|
674 |
+
description="get the scribble of the image.",
|
675 |
+
domain=Tool.Domain.IMAGE_PROCESSING,
|
676 |
+
args=[
|
677 |
+
Tool.Argument(
|
678 |
+
name="image",
|
679 |
+
type=DataType.IMAGE,
|
680 |
+
description="the image to be processed",
|
681 |
+
),
|
682 |
+
],
|
683 |
+
returns=[
|
684 |
+
Tool.Argument(
|
685 |
+
name="scribble",
|
686 |
+
type=DataType.SCRIBBLE,
|
687 |
+
description="the scribble of the image",
|
688 |
+
)
|
689 |
+
],
|
690 |
+
model=image2scribble,
|
691 |
+
),
|
692 |
+
Tool(
|
693 |
+
name="image_to_pose",
|
694 |
+
description="Get the pose of the image. It is usually used in image generation conditioned on pose map from input image.",
|
695 |
+
domain=Tool.Domain.IMAGE_PROCESSING,
|
696 |
+
args=[
|
697 |
+
Tool.Argument(
|
698 |
+
name="image",
|
699 |
+
type=DataType.IMAGE,
|
700 |
+
description="the image to be processed",
|
701 |
+
),
|
702 |
+
],
|
703 |
+
returns=[
|
704 |
+
Tool.Argument(
|
705 |
+
name="pose",
|
706 |
+
type=DataType.POSE,
|
707 |
+
description="the pose of the image",
|
708 |
+
)
|
709 |
+
],
|
710 |
+
model=image2pose,
|
711 |
+
),
|
712 |
+
Tool(
|
713 |
+
name="image_to_depth",
|
714 |
+
description="get the depth map of the image.",
|
715 |
+
domain=Tool.Domain.IMAGE_PROCESSING,
|
716 |
+
args=[
|
717 |
+
Tool.Argument(
|
718 |
+
name="image",
|
719 |
+
type=DataType.IMAGE,
|
720 |
+
description="the image to be processed",
|
721 |
+
),
|
722 |
+
],
|
723 |
+
returns=[
|
724 |
+
Tool.Argument(
|
725 |
+
name="depth",
|
726 |
+
type=DataType.DEPTH,
|
727 |
+
description="the depth map",
|
728 |
+
)
|
729 |
+
],
|
730 |
+
model=image2depth,
|
731 |
+
),
|
732 |
+
Tool(
|
733 |
+
name="image_to_normal",
|
734 |
+
description="get the normal map of the image.",
|
735 |
+
domain=Tool.Domain.IMAGE_PROCESSING,
|
736 |
+
args=[
|
737 |
+
Tool.Argument(
|
738 |
+
name="image",
|
739 |
+
type=DataType.IMAGE,
|
740 |
+
description="the image to be processed",
|
741 |
+
),
|
742 |
+
],
|
743 |
+
returns=[
|
744 |
+
Tool.Argument(
|
745 |
+
name="normal",
|
746 |
+
type=DataType.NORMAL,
|
747 |
+
description="the normal map",
|
748 |
+
)
|
749 |
+
],
|
750 |
+
model=image2normal,
|
751 |
+
),
|
752 |
+
]
|
753 |
+
|
754 |
+
IMAGE_PERCEPTION_TOOLS = [
|
755 |
+
Tool(
|
756 |
+
name="object_detection",
|
757 |
+
description="detect all the objects in the image.",
|
758 |
+
domain=Tool.Domain.IMAGE_PERCEPTION,
|
759 |
+
args=[
|
760 |
+
Tool.Argument(
|
761 |
+
name="image",
|
762 |
+
type=DataType.IMAGE,
|
763 |
+
description="the image that contains the objects",
|
764 |
+
),
|
765 |
+
],
|
766 |
+
returns=[
|
767 |
+
Tool.Argument(
|
768 |
+
name="object",
|
769 |
+
type=DataType.BBOX,
|
770 |
+
description="the detected objects in json format. "
|
771 |
+
"example output: [\{'score': 0.9994931221008301, 'label': 'dog', 'box': \{'xmin': 466, 'ymin': 301, 'xmax': 1045, 'ymax': 583\}\}]",
|
772 |
+
)
|
773 |
+
],
|
774 |
+
model=object_detection,
|
775 |
+
),
|
776 |
+
Tool(
|
777 |
+
name="image_classification",
|
778 |
+
description="classify the objects in the image.",
|
779 |
+
domain=Tool.Domain.IMAGE_PERCEPTION,
|
780 |
+
usages=["ask about the class of the image"],
|
781 |
+
args=[
|
782 |
+
Tool.Argument(
|
783 |
+
name="image",
|
784 |
+
type=DataType.IMAGE,
|
785 |
+
description="the image that contains the objects",
|
786 |
+
),
|
787 |
+
],
|
788 |
+
returns=[
|
789 |
+
Tool.Argument(
|
790 |
+
name="category",
|
791 |
+
type=DataType.CATEGORY,
|
792 |
+
description="the categories in json format. "
|
793 |
+
"example output: [\{'score': 0.9, 'label': 'dog'\}]",
|
794 |
+
)
|
795 |
+
],
|
796 |
+
model=image_classification,
|
797 |
+
),
|
798 |
+
Tool(
|
799 |
+
name="video_classification",
|
800 |
+
description="Classify the video and detect the actions in the video.",
|
801 |
+
domain=Tool.Domain.VIDEO_PERCEPTION,
|
802 |
+
usages=["ask about the class of the video"],
|
803 |
+
args=[
|
804 |
+
Tool.Argument(
|
805 |
+
name="video",
|
806 |
+
type=DataType.VIDEO,
|
807 |
+
description="the given video",
|
808 |
+
),
|
809 |
+
],
|
810 |
+
returns=[
|
811 |
+
Tool.Argument(
|
812 |
+
name="category",
|
813 |
+
type=DataType.CATEGORY,
|
814 |
+
description="the categories in json format. "
|
815 |
+
"example output: [\{'score': 0.9, 'label': 'Playing basketball'\}]",
|
816 |
+
)
|
817 |
+
],
|
818 |
+
model=video_classification,
|
819 |
+
),
|
820 |
+
Tool(
|
821 |
+
name="image_instance_segmentation",
|
822 |
+
description="segment the common objects in the given image.",
|
823 |
+
domain=Tool.Domain.IMAGE_PERCEPTION,
|
824 |
+
args=[
|
825 |
+
Tool.Argument(
|
826 |
+
name="image",
|
827 |
+
type=DataType.IMAGE,
|
828 |
+
description="the image that need to be segmented",
|
829 |
+
),
|
830 |
+
],
|
831 |
+
returns=[
|
832 |
+
Tool.Argument(
|
833 |
+
name="mask", type=DataType.MASK, description="the output mask"
|
834 |
+
)
|
835 |
+
],
|
836 |
+
model=segment_objects,
|
837 |
+
),
|
838 |
+
Tool(
|
839 |
+
name="image_segmentation_by_mask",
|
840 |
+
description="segment the given image with the prompt mask.",
|
841 |
+
domain=Tool.Domain.IMAGE_PERCEPTION,
|
842 |
+
args=[
|
843 |
+
Tool.Argument(
|
844 |
+
name="image",
|
845 |
+
type=DataType.IMAGE,
|
846 |
+
description="the image that need to be segmented",
|
847 |
+
),
|
848 |
+
Tool.Argument(
|
849 |
+
name="prompt_mask",
|
850 |
+
type=DataType.MASK,
|
851 |
+
description="the prompt mask that guides the segmentation",
|
852 |
+
),
|
853 |
+
],
|
854 |
+
returns=[
|
855 |
+
Tool.Argument(
|
856 |
+
name="mask", type=DataType.MASK, description="the output mask"
|
857 |
+
)
|
858 |
+
],
|
859 |
+
model=seg_by_mask,
|
860 |
+
),
|
861 |
+
Tool(
|
862 |
+
name="image_segmentation_by_points",
|
863 |
+
description="segment the given image with the prompt points.",
|
864 |
+
domain=Tool.Domain.IMAGE_PERCEPTION,
|
865 |
+
args=[
|
866 |
+
Tool.Argument(
|
867 |
+
name="image",
|
868 |
+
type=DataType.IMAGE,
|
869 |
+
description="the image that need to be segmented",
|
870 |
+
),
|
871 |
+
Tool.Argument(
|
872 |
+
name="prompt_points",
|
873 |
+
type=DataType.POINT,
|
874 |
+
description="the prompt points that guides the segmentation",
|
875 |
+
),
|
876 |
+
],
|
877 |
+
returns=[
|
878 |
+
Tool.Argument(
|
879 |
+
name="mask", type=DataType.MASK, description="the output mask"
|
880 |
+
)
|
881 |
+
],
|
882 |
+
model=seg_by_points,
|
883 |
+
),
|
884 |
+
Tool(
|
885 |
+
name="segment_anything",
|
886 |
+
description="Segment the given image without other inputs. This tool return the segmentation map for input image. The segmentation can be used to generate a new image.",
|
887 |
+
domain=Tool.Domain.IMAGE_PERCEPTION,
|
888 |
+
args=[
|
889 |
+
Tool.Argument(
|
890 |
+
name="image",
|
891 |
+
type=DataType.IMAGE,
|
892 |
+
description="the image that need to be segmented",
|
893 |
+
),
|
894 |
+
],
|
895 |
+
returns=[
|
896 |
+
Tool.Argument(
|
897 |
+
name="segmentation",
|
898 |
+
type=DataType.SEGMENTATION,
|
899 |
+
description="the output segmentation",
|
900 |
+
)
|
901 |
+
],
|
902 |
+
model=segment_all,
|
903 |
+
),
|
904 |
+
Tool(
|
905 |
+
name="visual_grounding",
|
906 |
+
description="Visual Grounding (VG) aims to locate the most relevant object or region in an image, based on a natural language query. The query can be a phrase, a sentence or even a multi-round dialogue.",
|
907 |
+
domain=Tool.Domain.IMAGE_PERCEPTION,
|
908 |
+
args=[
|
909 |
+
Tool.Argument(
|
910 |
+
name="image",
|
911 |
+
type=DataType.IMAGE,
|
912 |
+
description="the image that need to be processed",
|
913 |
+
),
|
914 |
+
Tool.Argument(
|
915 |
+
name="query",
|
916 |
+
type=DataType.TEXT,
|
917 |
+
description="a query that can be a phrase, a sentence",
|
918 |
+
),
|
919 |
+
],
|
920 |
+
returns=[
|
921 |
+
Tool.Argument(
|
922 |
+
name="bbox",
|
923 |
+
type=DataType.BBOX,
|
924 |
+
description="the detected bounding boxes for ",
|
925 |
+
)
|
926 |
+
],
|
927 |
+
model=visual_grounding,
|
928 |
+
),
|
929 |
+
Tool(
|
930 |
+
name="optical_character_recognition",
|
931 |
+
description="Optical Character Recognition (OCR) is the process that converts an image of text into a machine-readable text format.",
|
932 |
+
domain=Tool.Domain.IMAGE_PERCEPTION,
|
933 |
+
args=[
|
934 |
+
Tool.Argument(
|
935 |
+
name="image",
|
936 |
+
type=DataType.IMAGE,
|
937 |
+
description="the image that need to be processed",
|
938 |
+
)
|
939 |
+
],
|
940 |
+
returns=[
|
941 |
+
Tool.Argument(
|
942 |
+
name="text",
|
943 |
+
type=DataType.TEXT,
|
944 |
+
description="the recognized text",
|
945 |
+
)
|
946 |
+
],
|
947 |
+
model=ocr,
|
948 |
+
),
|
949 |
+
]
|
950 |
+
|
951 |
+
GENERAL_TOOLS = [
|
952 |
+
Tool(
|
953 |
+
name="select_category",
|
954 |
+
description="select the target classes in category list with the given condition.",
|
955 |
+
domain=Tool.Domain.GENERAL,
|
956 |
+
usages=["pick out the objects with the same type"],
|
957 |
+
args=[
|
958 |
+
Tool.Argument(
|
959 |
+
name="category_list",
|
960 |
+
type=DataType.CATEGORY,
|
961 |
+
description="the list to be processed",
|
962 |
+
),
|
963 |
+
Tool.Argument(
|
964 |
+
name="condition",
|
965 |
+
type=DataType.TEXT,
|
966 |
+
description="the condition to select objects",
|
967 |
+
),
|
968 |
+
],
|
969 |
+
returns=[
|
970 |
+
Tool.Argument(
|
971 |
+
name="target_category_result",
|
972 |
+
type=DataType.CATEGORY,
|
973 |
+
description="the selected list",
|
974 |
+
)
|
975 |
+
],
|
976 |
+
model=select,
|
977 |
+
),
|
978 |
+
Tool(
|
979 |
+
name="select_bbox",
|
980 |
+
description="select the bounding boxes with the given condition.",
|
981 |
+
domain=Tool.Domain.GENERAL,
|
982 |
+
usages=["filter out the bounding boxes with the same type"],
|
983 |
+
args=[
|
984 |
+
Tool.Argument(
|
985 |
+
name="bbox_list",
|
986 |
+
type=DataType.BBOX,
|
987 |
+
description="the bounding box list to be processed",
|
988 |
+
),
|
989 |
+
Tool.Argument(
|
990 |
+
name="condition",
|
991 |
+
type=DataType.TEXT,
|
992 |
+
description="the condition to select objects",
|
993 |
+
),
|
994 |
+
],
|
995 |
+
returns=[
|
996 |
+
Tool.Argument(
|
997 |
+
name="result",
|
998 |
+
type=DataType.BBOX,
|
999 |
+
description="the selected bbox list",
|
1000 |
+
)
|
1001 |
+
],
|
1002 |
+
model=select,
|
1003 |
+
),
|
1004 |
+
Tool(
|
1005 |
+
name="select_mask",
|
1006 |
+
description="select the masks with the given condition.",
|
1007 |
+
domain=Tool.Domain.GENERAL,
|
1008 |
+
args=[
|
1009 |
+
Tool.Argument(
|
1010 |
+
name="mask_list",
|
1011 |
+
type=DataType.MASK,
|
1012 |
+
description="the list to be processed",
|
1013 |
+
),
|
1014 |
+
Tool.Argument(
|
1015 |
+
name="condition",
|
1016 |
+
type=DataType.TEXT,
|
1017 |
+
description="the condition to select objects",
|
1018 |
+
),
|
1019 |
+
],
|
1020 |
+
returns=[
|
1021 |
+
Tool.Argument(
|
1022 |
+
name="result",
|
1023 |
+
type=DataType.MASK,
|
1024 |
+
description="the selected mask list",
|
1025 |
+
)
|
1026 |
+
],
|
1027 |
+
model=select,
|
1028 |
+
),
|
1029 |
+
Tool(
|
1030 |
+
name="count_categories",
|
1031 |
+
description="count target categories in the given list.",
|
1032 |
+
domain=Tool.Domain.GENERAL,
|
1033 |
+
args=[
|
1034 |
+
Tool.Argument(
|
1035 |
+
name="category_list",
|
1036 |
+
type=DataType.CATEGORY,
|
1037 |
+
description="the list to be processed",
|
1038 |
+
),
|
1039 |
+
],
|
1040 |
+
returns=[
|
1041 |
+
Tool.Argument(
|
1042 |
+
name="length",
|
1043 |
+
type=DataType.TEXT,
|
1044 |
+
description="the length of the given list, return in the string format."
|
1045 |
+
"Example: The length of the given list is 10",
|
1046 |
+
)
|
1047 |
+
],
|
1048 |
+
model=count,
|
1049 |
+
),
|
1050 |
+
Tool(
|
1051 |
+
name="count_objects",
|
1052 |
+
description="count target objects in the given list. It is useful when you want to count the number of objects in the image",
|
1053 |
+
domain=Tool.Domain.GENERAL,
|
1054 |
+
args=[
|
1055 |
+
Tool.Argument(
|
1056 |
+
name="bbox_list",
|
1057 |
+
type=DataType.BBOX,
|
1058 |
+
description="the bounding box list to be counted",
|
1059 |
+
),
|
1060 |
+
],
|
1061 |
+
returns=[
|
1062 |
+
Tool.Argument(
|
1063 |
+
name="length",
|
1064 |
+
type=DataType.TEXT,
|
1065 |
+
description="the length of the given list, return in the string format."
|
1066 |
+
"Example: The length of the given list is 10",
|
1067 |
+
)
|
1068 |
+
],
|
1069 |
+
model=count,
|
1070 |
+
),
|
1071 |
+
Tool(
|
1072 |
+
name="count_masks",
|
1073 |
+
description="count target mask in the given list.",
|
1074 |
+
domain=Tool.Domain.GENERAL,
|
1075 |
+
args=[
|
1076 |
+
Tool.Argument(
|
1077 |
+
name="mask_list",
|
1078 |
+
type=DataType.MASK,
|
1079 |
+
description="the list to be processed",
|
1080 |
+
),
|
1081 |
+
],
|
1082 |
+
returns=[
|
1083 |
+
Tool.Argument(
|
1084 |
+
name="length",
|
1085 |
+
type=DataType.TEXT,
|
1086 |
+
description="the length of the given list, return in the string format."
|
1087 |
+
"Example: The length of the given list is 10",
|
1088 |
+
)
|
1089 |
+
],
|
1090 |
+
model=count,
|
1091 |
+
),
|
1092 |
+
]
|
1093 |
+
|
1094 |
+
VIDEO_TOOLS = [
|
1095 |
+
# VIDEO
|
1096 |
+
Tool(
|
1097 |
+
name="video_captioning",
|
1098 |
+
description='Generate a caption or description for video. It can generate a detailed description that can be used for video perception and video generation. For example: a) you can use this tool when you want to know what happened in the video"; and b) when you want to generate tags for input video, you can use translate description obtained from `image_captioning` into tags.',
|
1099 |
+
domain=Tool.Domain.VIDEO_PERCEPTION,
|
1100 |
+
args=[
|
1101 |
+
Tool.Argument(
|
1102 |
+
name="video",
|
1103 |
+
type=DataType.VIDEO,
|
1104 |
+
description="the video to be captioned.",
|
1105 |
+
),
|
1106 |
+
],
|
1107 |
+
returns=[
|
1108 |
+
Tool.Argument(
|
1109 |
+
name="caption",
|
1110 |
+
type=DataType.TEXT,
|
1111 |
+
description="the caption or description of input video.",
|
1112 |
+
)
|
1113 |
+
],
|
1114 |
+
model=video_captioning,
|
1115 |
+
),
|
1116 |
+
Tool(
|
1117 |
+
name="image_audio_to_video",
|
1118 |
+
description="Generate a video with speech to introduce the image.",
|
1119 |
+
domain=Tool.Domain.VIDEO_GENERATION,
|
1120 |
+
args=[
|
1121 |
+
Tool.Argument(
|
1122 |
+
name="image",
|
1123 |
+
type=DataType.IMAGE,
|
1124 |
+
description="The input image to be introduced.",
|
1125 |
+
),
|
1126 |
+
Tool.Argument(
|
1127 |
+
name="audio",
|
1128 |
+
type=DataType.AUDIO,
|
1129 |
+
description="The audio contained the speech of image description.",
|
1130 |
+
),
|
1131 |
+
],
|
1132 |
+
returns=[
|
1133 |
+
Tool.Argument(
|
1134 |
+
name="video",
|
1135 |
+
type=DataType.VIDEO,
|
1136 |
+
description="Generated video that can introduce the image with speech",
|
1137 |
+
)
|
1138 |
+
],
|
1139 |
+
model=image_audio_to_video,
|
1140 |
+
),
|
1141 |
+
Tool(
|
1142 |
+
name="image_to_video",
|
1143 |
+
description="Generate a video based on image.",
|
1144 |
+
domain=Tool.Domain.VIDEO_GENERATION,
|
1145 |
+
args=[
|
1146 |
+
Tool.Argument(
|
1147 |
+
name="image",
|
1148 |
+
type=DataType.IMAGE,
|
1149 |
+
description="The input image.",
|
1150 |
+
),
|
1151 |
+
],
|
1152 |
+
returns=[
|
1153 |
+
Tool.Argument(
|
1154 |
+
name="video",
|
1155 |
+
type=DataType.VIDEO,
|
1156 |
+
description="Generated video from the input image.",
|
1157 |
+
)
|
1158 |
+
],
|
1159 |
+
model=image_to_video,
|
1160 |
+
),
|
1161 |
+
Tool(
|
1162 |
+
name="video_to_webpage",
|
1163 |
+
description="Generate a web page to promote and introduce the video.",
|
1164 |
+
domain=Tool.Domain.VIDEO_PROCESSING,
|
1165 |
+
args=[
|
1166 |
+
Tool.Argument(
|
1167 |
+
name="video",
|
1168 |
+
type=DataType.VIDEO,
|
1169 |
+
description="The input image to be introduced.",
|
1170 |
+
),
|
1171 |
+
Tool.Argument(
|
1172 |
+
name="title",
|
1173 |
+
type=DataType.TITLE,
|
1174 |
+
description="The title of video.",
|
1175 |
+
),
|
1176 |
+
Tool.Argument(
|
1177 |
+
name="tags",
|
1178 |
+
type=DataType.TAGS,
|
1179 |
+
description="The tags of video.",
|
1180 |
+
),
|
1181 |
+
Tool.Argument(
|
1182 |
+
name="description",
|
1183 |
+
type=DataType.TEXT,
|
1184 |
+
description="The description of video.",
|
1185 |
+
),
|
1186 |
+
],
|
1187 |
+
returns=[
|
1188 |
+
Tool.Argument(
|
1189 |
+
name="html_code",
|
1190 |
+
type=DataType.HTML,
|
1191 |
+
description="Generated HTML webpage with code that can introduce the video with speech.",
|
1192 |
+
)
|
1193 |
+
],
|
1194 |
+
model=video_to_webpage,
|
1195 |
+
),
|
1196 |
+
Tool(
|
1197 |
+
name="dub_video",
|
1198 |
+
description="Dub the input video with given audio track.",
|
1199 |
+
domain=Tool.Domain.VIDEO_EDITING,
|
1200 |
+
args=[
|
1201 |
+
Tool.Argument(
|
1202 |
+
name="video",
|
1203 |
+
type=DataType.VIDEO,
|
1204 |
+
description="The input image to be introduced.",
|
1205 |
+
),
|
1206 |
+
Tool.Argument(
|
1207 |
+
name="audio",
|
1208 |
+
type=DataType.AUDIO,
|
1209 |
+
description="The audio of video.",
|
1210 |
+
),
|
1211 |
+
],
|
1212 |
+
returns=[
|
1213 |
+
Tool.Argument(
|
1214 |
+
name="video",
|
1215 |
+
type=DataType.VIDEO,
|
1216 |
+
description="Output video with designated audio.",
|
1217 |
+
)
|
1218 |
+
],
|
1219 |
+
model=dub_video,
|
1220 |
+
),
|
1221 |
+
Tool(
|
1222 |
+
name="text_to_video",
|
1223 |
+
description="It takes as input a natural language description and produces a video matching that description",
|
1224 |
+
domain=Tool.Domain.VIDEO_GENERATION,
|
1225 |
+
args=[
|
1226 |
+
Tool.Argument(
|
1227 |
+
name="prompt",
|
1228 |
+
type=DataType.TEXT,
|
1229 |
+
description="the text describing the image",
|
1230 |
+
)
|
1231 |
+
],
|
1232 |
+
returns=[
|
1233 |
+
Tool.Argument(
|
1234 |
+
name="video",
|
1235 |
+
type=DataType.VIDEO,
|
1236 |
+
description="the generated video",
|
1237 |
+
)
|
1238 |
+
],
|
1239 |
+
model=text_to_video,
|
1240 |
+
),
|
1241 |
+
]
|
1242 |
+
|
1243 |
+
AUDIO_TOOLS = [
|
1244 |
+
# AUDIO
|
1245 |
+
Tool(
|
1246 |
+
name="text_to_music",
|
1247 |
+
description="Generate music condioned on input text/prompt. For example, you can use this tool when you want to generate music for a poem, generate a piece of music from image.",
|
1248 |
+
domain=Tool.Domain.AUDIO_GENERATION,
|
1249 |
+
args=[
|
1250 |
+
Tool.Argument(
|
1251 |
+
name="text",
|
1252 |
+
type=DataType.TEXT,
|
1253 |
+
description="Input text for music generation.",
|
1254 |
+
),
|
1255 |
+
],
|
1256 |
+
returns=[
|
1257 |
+
Tool.Argument(
|
1258 |
+
name="music",
|
1259 |
+
type=DataType.AUDIO,
|
1260 |
+
description="Generated music conditioned on text.",
|
1261 |
+
)
|
1262 |
+
],
|
1263 |
+
model=text_to_music,
|
1264 |
+
),
|
1265 |
+
Tool(
|
1266 |
+
name="text_to_speech",
|
1267 |
+
description="Create natural-sounding speech from text, where the speech can be generated in multiple languages and for multiple speakers",
|
1268 |
+
domain=Tool.Domain.AUDIO_GENERATION,
|
1269 |
+
args=[
|
1270 |
+
Tool.Argument(
|
1271 |
+
name="text",
|
1272 |
+
type=DataType.TEXT,
|
1273 |
+
description="The input text that will be translated into speech.",
|
1274 |
+
),
|
1275 |
+
],
|
1276 |
+
returns=[
|
1277 |
+
Tool.Argument(
|
1278 |
+
name="speech",
|
1279 |
+
type=DataType.AUDIO,
|
1280 |
+
description="Generated speech or voice conditioned on text.",
|
1281 |
+
)
|
1282 |
+
],
|
1283 |
+
model=text_to_speech,
|
1284 |
+
),
|
1285 |
+
Tool(
|
1286 |
+
name="audio_classification",
|
1287 |
+
description="Audio classification is the task of assigning a label or class to a given audio. It can be used for recognizing which command a user is giving or the emotion of a statement, as well as identifying a speaker.",
|
1288 |
+
domain=Tool.Domain.AUDIO_PERCEPTION,
|
1289 |
+
args=[
|
1290 |
+
Tool.Argument(
|
1291 |
+
name="audio",
|
1292 |
+
type=DataType.AUDIO,
|
1293 |
+
description="The input audio that will be classified.",
|
1294 |
+
),
|
1295 |
+
],
|
1296 |
+
returns=[
|
1297 |
+
Tool.Argument(
|
1298 |
+
name="speech",
|
1299 |
+
type=DataType.CATEGORY,
|
1300 |
+
description="The recognized categories in json format.",
|
1301 |
+
)
|
1302 |
+
],
|
1303 |
+
model=audio_classification,
|
1304 |
+
),
|
1305 |
+
]
|
1306 |
+
|
1307 |
+
NLP_TOOLS = [
|
1308 |
+
# Text
|
1309 |
+
Tool(
|
1310 |
+
name="text_to_text_generation",
|
1311 |
+
description="Text to text generation. It can be used for sentence acceptability judgment, Sentiment analysis, Paraphrasing/sentence similarity, Natural language inference, Sentence completion, Word sense disambiguation, Question answering.",
|
1312 |
+
domain=Tool.Domain.NATURAL_LANGUAGE_PROCESSING,
|
1313 |
+
args=[
|
1314 |
+
Tool.Argument(
|
1315 |
+
name="text",
|
1316 |
+
type=DataType.TEXT,
|
1317 |
+
description="The input text",
|
1318 |
+
),
|
1319 |
+
],
|
1320 |
+
returns=[
|
1321 |
+
Tool.Argument(
|
1322 |
+
name="answer",
|
1323 |
+
type=DataType.TEXT,
|
1324 |
+
description="Generated answer for given input.",
|
1325 |
+
)
|
1326 |
+
],
|
1327 |
+
model=text_to_text_generation,
|
1328 |
+
),
|
1329 |
+
Tool(
|
1330 |
+
name="title_generation",
|
1331 |
+
description="Generate a title for given text.",
|
1332 |
+
domain=Tool.Domain.NATURAL_LANGUAGE_PROCESSING,
|
1333 |
+
args=[
|
1334 |
+
Tool.Argument(
|
1335 |
+
name="text",
|
1336 |
+
type=DataType.TEXT,
|
1337 |
+
description="The input text",
|
1338 |
+
),
|
1339 |
+
],
|
1340 |
+
returns=[
|
1341 |
+
Tool.Argument(
|
1342 |
+
name="title",
|
1343 |
+
type=DataType.TITLE,
|
1344 |
+
description="Generated title based given sentences.",
|
1345 |
+
)
|
1346 |
+
],
|
1347 |
+
model=title_generation,
|
1348 |
+
),
|
1349 |
+
Tool(
|
1350 |
+
name="openai_chat_model",
|
1351 |
+
description="Answer the question by Large Language Model.",
|
1352 |
+
domain=Tool.Domain.QUESTION_ANSWERING,
|
1353 |
+
args=[
|
1354 |
+
Tool.Argument(
|
1355 |
+
name="input_msg",
|
1356 |
+
type=DataType.TEXT,
|
1357 |
+
description="The input text",
|
1358 |
+
)
|
1359 |
+
],
|
1360 |
+
returns=[
|
1361 |
+
Tool.Argument(
|
1362 |
+
name="answer",
|
1363 |
+
type=DataType.TEXT,
|
1364 |
+
description="Generated answer based given text.",
|
1365 |
+
)
|
1366 |
+
],
|
1367 |
+
model=openai_chat_model,
|
1368 |
+
),
|
1369 |
+
Tool(
|
1370 |
+
name="summarization",
|
1371 |
+
description="Summarize sentences, long narratives, articles, papers, textbooks.",
|
1372 |
+
domain=Tool.Domain.NATURAL_LANGUAGE_PROCESSING,
|
1373 |
+
args=[
|
1374 |
+
Tool.Argument(
|
1375 |
+
name="text",
|
1376 |
+
type=DataType.TEXT,
|
1377 |
+
description="The input text to be Summarized.",
|
1378 |
+
),
|
1379 |
+
],
|
1380 |
+
returns=[
|
1381 |
+
Tool.Argument(
|
1382 |
+
name="summarized_text",
|
1383 |
+
type=DataType.TEXT,
|
1384 |
+
description="Summarized text.",
|
1385 |
+
)
|
1386 |
+
],
|
1387 |
+
model=summarization,
|
1388 |
+
),
|
1389 |
+
Tool(
|
1390 |
+
name="text_to_tags",
|
1391 |
+
description="Predict the tags of text, article and papers by using the their textual content as input",
|
1392 |
+
domain=Tool.Domain.NATURAL_LANGUAGE_PROCESSING,
|
1393 |
+
args=[
|
1394 |
+
Tool.Argument(
|
1395 |
+
name="text",
|
1396 |
+
type=DataType.TEXT,
|
1397 |
+
description="The input text to be Summarized.",
|
1398 |
+
),
|
1399 |
+
],
|
1400 |
+
returns=[
|
1401 |
+
Tool.Argument(
|
1402 |
+
name="tags",
|
1403 |
+
type=DataType.TAGS,
|
1404 |
+
description="The extracted tags from input text",
|
1405 |
+
)
|
1406 |
+
],
|
1407 |
+
model=text_to_tags,
|
1408 |
+
),
|
1409 |
+
Tool(
|
1410 |
+
name="named_entity_recognition",
|
1411 |
+
description="Named-entity recognition (NER) (also known as (named) entity identification, entity chunking, and entity extraction) is a subtask of information extraction that seeks to locate and classify named entities mentioned in unstructured text into pre-defined categories such as person names, organizations, locations, medical codes, time expressions, quantities, monetary values, percentages, etc.",
|
1412 |
+
domain=Tool.Domain.NATURAL_LANGUAGE_PROCESSING,
|
1413 |
+
args=[
|
1414 |
+
Tool.Argument(
|
1415 |
+
name="text",
|
1416 |
+
type=DataType.TEXT,
|
1417 |
+
description="The input text from which the named entities are extracted",
|
1418 |
+
),
|
1419 |
+
],
|
1420 |
+
returns=[
|
1421 |
+
Tool.Argument(
|
1422 |
+
name="tags",
|
1423 |
+
type=DataType.TAGS,
|
1424 |
+
description="The extracted entities",
|
1425 |
+
)
|
1426 |
+
],
|
1427 |
+
model=None,
|
1428 |
+
),
|
1429 |
+
Tool(
|
1430 |
+
name="sentiment_analysis",
|
1431 |
+
description="Sentiment analysis is the process of analyzing digital text to determine if the emotional tone of the message is positive, negative, or neutral.",
|
1432 |
+
domain=Tool.Domain.NATURAL_LANGUAGE_PROCESSING,
|
1433 |
+
args=[
|
1434 |
+
Tool.Argument(
|
1435 |
+
name="text",
|
1436 |
+
type=DataType.TEXT,
|
1437 |
+
description="The input text to be analyzed",
|
1438 |
+
),
|
1439 |
+
],
|
1440 |
+
returns=[
|
1441 |
+
Tool.Argument(
|
1442 |
+
name="text",
|
1443 |
+
type=DataType.TEXT,
|
1444 |
+
description="The sentiment of text",
|
1445 |
+
)
|
1446 |
+
],
|
1447 |
+
model=sentiment_analysis,
|
1448 |
+
),
|
1449 |
+
Tool(
|
1450 |
+
name="extract_location",
|
1451 |
+
description="Extracts the locale name from the text. For example, if the text is 'what is the weather in Beijing', the tool will return 'Beijing'. If the text is 'Samuel ppops in a happy plce called Berlin which happens to be Kazakhstan', the tool will return 'Berlin,Kazakhstan'.",
|
1452 |
+
domain=Tool.Domain.NATURAL_LANGUAGE_PROCESSING,
|
1453 |
+
args=[
|
1454 |
+
Tool.Argument(
|
1455 |
+
name="text",
|
1456 |
+
type=DataType.TEXT,
|
1457 |
+
description="The input text to be analyzed",
|
1458 |
+
),
|
1459 |
+
],
|
1460 |
+
returns=[
|
1461 |
+
Tool.Argument(
|
1462 |
+
name="location",
|
1463 |
+
type=DataType.LOCATION,
|
1464 |
+
description="The sentiment of text",
|
1465 |
+
)
|
1466 |
+
],
|
1467 |
+
model=extract_location,
|
1468 |
+
),
|
1469 |
+
Tool(
|
1470 |
+
name="summarize_weather_condition",
|
1471 |
+
description="Translate the json formatted weather information into the text that human can understand. For example, when you want to generate a new image based on weather information",
|
1472 |
+
domain=Tool.Domain.NATURAL_LANGUAGE_PROCESSING,
|
1473 |
+
args=[
|
1474 |
+
Tool.Argument(
|
1475 |
+
name="weather",
|
1476 |
+
type=DataType.WEATHER,
|
1477 |
+
description="weather condition",
|
1478 |
+
)
|
1479 |
+
],
|
1480 |
+
returns=[
|
1481 |
+
Tool.Argument(
|
1482 |
+
name="weather_summary",
|
1483 |
+
type=DataType.TEXT,
|
1484 |
+
description="the weather summary",
|
1485 |
+
)
|
1486 |
+
],
|
1487 |
+
model=summarize_weather_condition,
|
1488 |
+
),
|
1489 |
+
]
|
1490 |
+
|
1491 |
+
TOOLS = (
|
1492 |
+
QUESTION_ANSWERING_TOOLS
|
1493 |
+
+ IMAGE_CAPTIONING_TOOLS
|
1494 |
+
+ IMAGE_EDITING_TOOLS
|
1495 |
+
+ IMAGE_GENERATION_TOOLS
|
1496 |
+
+ IMAGE_TRANSFORM_TOOLS
|
1497 |
+
+ IMAGE_PERCEPTION_TOOLS
|
1498 |
+
+ GENERAL_TOOLS
|
1499 |
+
+ VIDEO_TOOLS
|
1500 |
+
+ AUDIO_TOOLS
|
1501 |
+
+ NLP_TOOLS
|
1502 |
+
)
|
1503 |
+
TOOLS = {tool.name: tool for tool in TOOLS}
|
1504 |
+
|
1505 |
+
if __name__ == "__main__":
|
1506 |
+
tools = []
|
1507 |
+
for tool in TOOLS.values():
|
1508 |
+
tools.append(tool.dict())
|
1509 |
+
import json
|
1510 |
+
|
1511 |
+
with open("tools.json", "w") as f:
|
1512 |
+
json.dump(tools, f, indent=4)
|
cllm/agents/container.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
sys.path.append(os.getcwd())
|
5 |
+
import os.path as osp
|
6 |
+
from pathlib import Path
|
7 |
+
import json
|
8 |
+
from .base import DataType
|
9 |
+
from cllm.utils import get_real_path
|
10 |
+
|
11 |
+
|
12 |
+
# sys.path.insert(0, sys.path[0] + "/../")
|
13 |
+
FILE_EXT = {
|
14 |
+
"image": ["png", "jpeg", "jpg", "gif", "bmp", "tiff", "webp"],
|
15 |
+
"video": ["mp4", "mov", "avi", "mkv"],
|
16 |
+
"audio": ["wav", "mp3"],
|
17 |
+
}
|
18 |
+
|
19 |
+
|
20 |
+
class Container:
|
21 |
+
def __init__(self, name, rtype, value) -> None:
|
22 |
+
self.name = name
|
23 |
+
self.rtype = rtype
|
24 |
+
self.value = value
|
25 |
+
|
26 |
+
def to_chatbot(self):
|
27 |
+
pass
|
28 |
+
|
29 |
+
def __str__(self):
|
30 |
+
pass
|
31 |
+
|
32 |
+
def __repr__(self) -> str:
|
33 |
+
return str(self)
|
34 |
+
|
35 |
+
|
36 |
+
class File(Container):
|
37 |
+
def to_chatbot(self):
|
38 |
+
return str(self.value)
|
39 |
+
|
40 |
+
@property
|
41 |
+
def filename(self):
|
42 |
+
return os.path.basename(self.value)
|
43 |
+
|
44 |
+
def __str__(self):
|
45 |
+
return f"`{self.filename}`"
|
46 |
+
|
47 |
+
|
48 |
+
class HTML(File):
|
49 |
+
def to_chatbot(self):
|
50 |
+
return str(self.value)
|
51 |
+
|
52 |
+
def __str__(self):
|
53 |
+
return f"`{self.filename}`"
|
54 |
+
|
55 |
+
|
56 |
+
class Image(File):
|
57 |
+
def __str__(self):
|
58 |
+
return f"`{self.filename}`"
|
59 |
+
|
60 |
+
|
61 |
+
class Video(File):
|
62 |
+
def __str__(self):
|
63 |
+
return f"`{self.filename}`"
|
64 |
+
|
65 |
+
|
66 |
+
class Audio(File):
|
67 |
+
def __str__(self):
|
68 |
+
return f"`{self.filename}`"
|
69 |
+
|
70 |
+
|
71 |
+
class Text(Container):
|
72 |
+
def to_chatbot(self):
|
73 |
+
if isinstance(self.value, str):
|
74 |
+
return self.value
|
75 |
+
elif isinstance(self.value, (list, tuple, dict)):
|
76 |
+
return json.dumps(self.value, indent=2)
|
77 |
+
return self.value
|
78 |
+
|
79 |
+
def __str__(self):
|
80 |
+
if isinstance(self.value, (list, dict)):
|
81 |
+
return json.dumps(self.value)
|
82 |
+
elif isinstance(self.value, str):
|
83 |
+
return self.value
|
84 |
+
return str(self.value)
|
85 |
+
|
86 |
+
|
87 |
+
def auto_type(name, rtype, value):
|
88 |
+
if value is None:
|
89 |
+
return None
|
90 |
+
if "image" in str(rtype):
|
91 |
+
return Image(name, rtype, get_real_path(value))
|
92 |
+
if DataType.VIDEO == rtype:
|
93 |
+
return Video(name, rtype, get_real_path(value))
|
94 |
+
if DataType.AUDIO == rtype:
|
95 |
+
return Audio(name, rtype, get_real_path(value))
|
96 |
+
if DataType.HTML == rtype:
|
97 |
+
return HTML(name, rtype, get_real_path(value))
|
98 |
+
return Text(name, rtype, value)
|
cllm/agents/tog/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .planner import Planner
|
2 |
+
from .controller import Controller
|
cllm/agents/tog/compiler.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Union
|
2 |
+
import ast
|
3 |
+
import sys
|
4 |
+
import os
|
5 |
+
|
6 |
+
sys.path.append(os.getcwd())
|
7 |
+
from .cllm.agents.base import Action
|
8 |
+
|
9 |
+
|
10 |
+
class Parser:
|
11 |
+
def parse(self, plan) -> List[Action]:
|
12 |
+
# ignore indent
|
13 |
+
input = "\n".join([line.strip() for line in plan.split("\n")])
|
14 |
+
actions = []
|
15 |
+
for stmt in ast.parse(input).body:
|
16 |
+
if isinstance(stmt, ast.Assign):
|
17 |
+
assign: ast.Assign = stmt
|
18 |
+
output: ast.Name = assign.targets[0]
|
19 |
+
func_call: ast.Call = assign.value
|
20 |
+
func_name: ast.Name = func_call.func
|
21 |
+
kwargs: List[ast.keyword] = func_call.keywords
|
22 |
+
args = {}
|
23 |
+
for kwarg in kwargs:
|
24 |
+
k = kwarg.arg
|
25 |
+
if isinstance(kwarg.value, ast.Name):
|
26 |
+
v = kwarg.value.id
|
27 |
+
else:
|
28 |
+
v = ast.literal_eval(kwarg.value)
|
29 |
+
args[k] = v
|
30 |
+
action = Action(
|
31 |
+
tool_name=func_name.id, outputs=[output.id], inputs=args
|
32 |
+
)
|
33 |
+
actions.append(action)
|
34 |
+
return actions
|
35 |
+
|
36 |
+
|
37 |
+
class Compiler:
|
38 |
+
def __init__(self):
|
39 |
+
self.parser = Parser()
|
40 |
+
|
41 |
+
def compile(self, plan: Union[str, List[Union[Action, str]]]) -> List[Action]:
|
42 |
+
"""The input could be a plain string, a list of structured `Action`
|
43 |
+
or combination of structured `Action` or unstructured action string.
|
44 |
+
"""
|
45 |
+
actions = self.parse(plan)
|
46 |
+
actions = self.correct(actions)
|
47 |
+
return actions
|
48 |
+
|
49 |
+
def parse(self, plan) -> List[Action]:
|
50 |
+
if isinstance(plan, str):
|
51 |
+
return self.parser.parse(plan)
|
52 |
+
|
53 |
+
actions = []
|
54 |
+
for action in plan:
|
55 |
+
if isinstance(action, str):
|
56 |
+
action = self.parser.parse(action)[0]
|
57 |
+
actions.append(action)
|
58 |
+
|
59 |
+
return actions
|
60 |
+
|
61 |
+
def correct(self, actions):
|
62 |
+
return actions
|
cllm/agents/tog/controller.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import traceback
|
2 |
+
import logging
|
3 |
+
from typing import Tuple, List
|
4 |
+
import copy
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import json
|
8 |
+
from collections import OrderedDict
|
9 |
+
import os
|
10 |
+
import sys
|
11 |
+
|
12 |
+
sys.path.append(os.getcwd())
|
13 |
+
from cllm.agents import container
|
14 |
+
from cllm.agents.builtin import BUILTIN_PLANS, load_builtin_plans
|
15 |
+
from cllm.agents.container import auto_type
|
16 |
+
from cllm.agents.base import DataType, NON_FILE_TYPES
|
17 |
+
|
18 |
+
from .interpretor import Interpretor
|
19 |
+
from .planner import Planner
|
20 |
+
from .responser import generate_response
|
21 |
+
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
class Controller:
|
26 |
+
def __init__(self, stream=True, interpretor_kwargs={}):
|
27 |
+
self.stream = stream
|
28 |
+
self.planner = Planner(self.stream)
|
29 |
+
self.interpretor = Interpretor(**interpretor_kwargs)
|
30 |
+
self.SHORTCUT = "**Using builtin shortcut solution.**"
|
31 |
+
BUILTIN_PLANS.update(load_builtin_plans("builtin_plan.json"))
|
32 |
+
logger.info(BUILTIN_PLANS)
|
33 |
+
|
34 |
+
def plan(self, request: str, state: dict):
|
35 |
+
logger.info(request)
|
36 |
+
|
37 |
+
resource_memory = state.get("resources", {})
|
38 |
+
raw_solution = None
|
39 |
+
# shortcut for builtin plan
|
40 |
+
for trigger_prompt, _ in BUILTIN_PLANS.items():
|
41 |
+
if request == trigger_prompt:
|
42 |
+
return self.SHORTCUT
|
43 |
+
|
44 |
+
# dynamic execution
|
45 |
+
if raw_solution is None:
|
46 |
+
raw_solution = self.planner.plan(request, resource_memory)
|
47 |
+
return raw_solution
|
48 |
+
|
49 |
+
def parse_solution_from_stream(self, raw_solution):
|
50 |
+
return self.planner.parse(raw_solution)
|
51 |
+
|
52 |
+
def execute(self, raw_solution: str, state: dict):
|
53 |
+
resource_memory = state.get("resources")
|
54 |
+
request = state["request"]
|
55 |
+
solution = None
|
56 |
+
if raw_solution == self.SHORTCUT:
|
57 |
+
for trigger_prompt, builtin_plan in BUILTIN_PLANS.items():
|
58 |
+
if request == trigger_prompt:
|
59 |
+
solution = builtin_plan
|
60 |
+
solution = self._fill_args(solution, resource_memory)
|
61 |
+
else:
|
62 |
+
solution = self.planner.parse(raw_solution)
|
63 |
+
|
64 |
+
if not solution:
|
65 |
+
return None
|
66 |
+
try:
|
67 |
+
history_msgs = state.get("history_msgs")
|
68 |
+
return self.interpretor.interpret(solution, history_msgs)
|
69 |
+
except Exception as e:
|
70 |
+
traceback.print_exc()
|
71 |
+
return None
|
72 |
+
|
73 |
+
def reply(self, executed_plan: dict, outputs: list, state: dict):
|
74 |
+
error_response = [
|
75 |
+
auto_type(
|
76 |
+
"response",
|
77 |
+
DataType.TEXT,
|
78 |
+
"Sorry, I cannot understand your request due to an internal error.",
|
79 |
+
)
|
80 |
+
]
|
81 |
+
state = copy.deepcopy(state)
|
82 |
+
if (
|
83 |
+
executed_plan is None
|
84 |
+
or len(executed_plan) == 0
|
85 |
+
or outputs is None
|
86 |
+
or len(outputs) == 0
|
87 |
+
):
|
88 |
+
return error_response, state
|
89 |
+
resources = state.get("resources", OrderedDict())
|
90 |
+
for o in outputs:
|
91 |
+
if isinstance(o, container.File):
|
92 |
+
resources[str(o.filename)] = str(o.rtype)
|
93 |
+
state["resources"] = resources
|
94 |
+
response = generate_response(state["request"], executed_plan, outputs)
|
95 |
+
if len(response) == 0:
|
96 |
+
return error_response, state
|
97 |
+
logger.info(response)
|
98 |
+
return response, state
|
99 |
+
|
100 |
+
def run(self, task: str, state: dict) -> Tuple[List, str]:
|
101 |
+
try:
|
102 |
+
return self._run(task, state)
|
103 |
+
except:
|
104 |
+
traceback.print_exc()
|
105 |
+
logger.info(traceback.format_exc())
|
106 |
+
return [
|
107 |
+
auto_type(
|
108 |
+
"response",
|
109 |
+
DataType.TEXT,
|
110 |
+
"Sorry, I cannot understand your request due to an internal error.",
|
111 |
+
)
|
112 |
+
], "Internal Error"
|
113 |
+
|
114 |
+
def _run(self, task: str, state: dict) -> Tuple[List, str]:
|
115 |
+
logger.info(task)
|
116 |
+
BUILTIN_PLANS.update(load_builtin_plans("builtin_plan.json"))
|
117 |
+
logger.info(BUILTIN_PLANS)
|
118 |
+
resource_memory = state.get("resources", OrderedDict())
|
119 |
+
history_msgs = state.get("history_msgs", [])
|
120 |
+
plan = None
|
121 |
+
|
122 |
+
# shortcut for builtin plan
|
123 |
+
for trigger_prompt, builtin_plan in BUILTIN_PLANS.items():
|
124 |
+
if task == trigger_prompt:
|
125 |
+
plan = builtin_plan
|
126 |
+
plan = self._fill_args(plan, resource_memory)
|
127 |
+
|
128 |
+
# dynamic executation
|
129 |
+
if plan is None:
|
130 |
+
plan = self.planner.planning(task, resource_memory)
|
131 |
+
logger.info(plan)
|
132 |
+
|
133 |
+
executed_plan, output_files = self.interpretor.interpret(
|
134 |
+
plan, resource_memory, history_msgs
|
135 |
+
)
|
136 |
+
logger.info(output_files)
|
137 |
+
for o in output_files:
|
138 |
+
if isinstance(o, container.File):
|
139 |
+
resource_memory[o.filename] = str(o.rtype)
|
140 |
+
|
141 |
+
outputs = generate_response(task, executed_plan, output_files)
|
142 |
+
|
143 |
+
logger.info(outputs)
|
144 |
+
return outputs, executed_plan
|
145 |
+
|
146 |
+
def _fill_args(self, plan, memory):
|
147 |
+
plan = copy.deepcopy(plan)
|
148 |
+
latest_resource = OrderedDict()
|
149 |
+
for key, val in memory.items():
|
150 |
+
latest_resource[val] = key
|
151 |
+
|
152 |
+
for actions in plan:
|
153 |
+
for action in actions:
|
154 |
+
for key, val in action.inputs.items():
|
155 |
+
if "<TOOL-GENERATED>" not in val:
|
156 |
+
action.inputs[key] = latest_resource.get(val, val)
|
157 |
+
return plan
|
cllm/agents/tog/interpretor.py
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from traceback import print_exc
|
3 |
+
from typing import List, Dict
|
4 |
+
import os.path as osp
|
5 |
+
import io
|
6 |
+
import copy
|
7 |
+
import re
|
8 |
+
import uuid
|
9 |
+
from matplotlib.pyplot import isinteractive
|
10 |
+
|
11 |
+
from numpy import isin
|
12 |
+
import sys
|
13 |
+
import os
|
14 |
+
|
15 |
+
sys.path.append(os.getcwd())
|
16 |
+
from cllm.agents.base import Action, DataType, Tool, NON_FILE_TYPES
|
17 |
+
from cllm.agents.builtin import TOOLS
|
18 |
+
from cllm.agents.container import auto_type
|
19 |
+
from cllm.utils import get_real_path, get_root_dir, transform_msgs
|
20 |
+
|
21 |
+
logger = logging.getLogger(__name__)
|
22 |
+
|
23 |
+
|
24 |
+
def code(source, type="py"):
|
25 |
+
return f"```{type}\n{source}\n```"
|
26 |
+
|
27 |
+
|
28 |
+
class Interpretor:
|
29 |
+
def __init__(self):
|
30 |
+
self.tools = TOOLS
|
31 |
+
self.non_file_types = NON_FILE_TYPES
|
32 |
+
|
33 |
+
def interpret(self, stages: List[List[Action]], history_msgs: List = []):
|
34 |
+
memory = {}
|
35 |
+
solution = copy.deepcopy(stages)
|
36 |
+
history_msgs = copy.deepcopy(history_msgs)
|
37 |
+
history_msgs = transform_msgs(history_msgs)
|
38 |
+
has_error = False
|
39 |
+
for actions in solution:
|
40 |
+
for action in actions:
|
41 |
+
tool = self.load_tool(name=action.tool_name)
|
42 |
+
tool_inputs = self.load_args(tool, action.inputs, memory)
|
43 |
+
tool_inputs["history_msgs"] = history_msgs
|
44 |
+
tool_inputs["root_dir"] = get_root_dir()
|
45 |
+
try:
|
46 |
+
tool_outputs = tool.model(**tool_inputs)
|
47 |
+
action.inputs = self._update_inputs(memory, action.inputs)
|
48 |
+
action.outputs, wrapped_outputs = self._update_output(
|
49 |
+
memory, action, tool_outputs, tool
|
50 |
+
)
|
51 |
+
logger.info(
|
52 |
+
"Call {}, args {}, return {}".format(
|
53 |
+
action.tool_name, action.inputs, action.outputs
|
54 |
+
)
|
55 |
+
)
|
56 |
+
executed_action = (
|
57 |
+
action.tool_name,
|
58 |
+
action.inputs,
|
59 |
+
action.outputs,
|
60 |
+
)
|
61 |
+
except FileNotFoundError as e:
|
62 |
+
print_exc()
|
63 |
+
tool_outputs = None
|
64 |
+
logger.error(f"Error when executing {action.tool_name}: {e}")
|
65 |
+
has_error = True
|
66 |
+
wrapped_outputs = []
|
67 |
+
executed_action = (
|
68 |
+
action.tool_name,
|
69 |
+
action.inputs,
|
70 |
+
f"FileNotFoundError: No such file or directory: {osp.basename(e.filename)}",
|
71 |
+
)
|
72 |
+
except Exception as e:
|
73 |
+
print_exc()
|
74 |
+
tool_outputs = None
|
75 |
+
has_error = True
|
76 |
+
logger.error(f"Error when executing {action.tool_name}: {e}")
|
77 |
+
wrapped_outputs = []
|
78 |
+
executed_action = (
|
79 |
+
action.tool_name,
|
80 |
+
action.inputs,
|
81 |
+
f"Internal error: {e}",
|
82 |
+
)
|
83 |
+
yield executed_action, solution, wrapped_outputs
|
84 |
+
if has_error:
|
85 |
+
return
|
86 |
+
|
87 |
+
def _update_output(self, memory, action, tool_outputs, tool):
|
88 |
+
outputs = []
|
89 |
+
wrapped_outputs = []
|
90 |
+
if action.outputs is not None:
|
91 |
+
if len(action.outputs) == 1:
|
92 |
+
tool_outputs = [tool_outputs]
|
93 |
+
for i, (arg_name, arg_value) in enumerate(
|
94 |
+
zip(action.outputs, tool_outputs)
|
95 |
+
):
|
96 |
+
memory[arg_name] = arg_value
|
97 |
+
if arg_value is None:
|
98 |
+
outputs.append(arg_value)
|
99 |
+
wrapped_outputs.append(
|
100 |
+
auto_type(
|
101 |
+
arg_name,
|
102 |
+
DataType.TEXT,
|
103 |
+
None,
|
104 |
+
)
|
105 |
+
)
|
106 |
+
continue
|
107 |
+
|
108 |
+
if isinstance(arg_value, (dict, list)):
|
109 |
+
arg_value = self.pretty_floats(arg_value)
|
110 |
+
|
111 |
+
if tool.returns[i].type in self.non_file_types:
|
112 |
+
outputs.append(arg_value)
|
113 |
+
wrapped_outputs.append(
|
114 |
+
auto_type(
|
115 |
+
arg_name,
|
116 |
+
tool.returns[i].type,
|
117 |
+
arg_value,
|
118 |
+
)
|
119 |
+
)
|
120 |
+
|
121 |
+
continue
|
122 |
+
|
123 |
+
transformed_output = self.transform_output(
|
124 |
+
action.inputs,
|
125 |
+
tool.name,
|
126 |
+
tool.args,
|
127 |
+
arg_value,
|
128 |
+
tool.returns[i].type,
|
129 |
+
)
|
130 |
+
|
131 |
+
outputs.append(transformed_output)
|
132 |
+
memory[arg_name] = transformed_output
|
133 |
+
if not isinstance(transformed_output, list):
|
134 |
+
wrapped_outputs.append(
|
135 |
+
auto_type(
|
136 |
+
arg_name,
|
137 |
+
tool.returns[i].type,
|
138 |
+
transformed_output,
|
139 |
+
)
|
140 |
+
)
|
141 |
+
continue
|
142 |
+
|
143 |
+
for output in transformed_output:
|
144 |
+
if DataType.MASK == tool.returns[i].type:
|
145 |
+
output = output if isinstance(output, str) else output["mask"]
|
146 |
+
wrapped_outputs.append(
|
147 |
+
auto_type(
|
148 |
+
arg_name,
|
149 |
+
tool.returns[i].type,
|
150 |
+
output if isinstance(output, str) else output["mask"],
|
151 |
+
)
|
152 |
+
)
|
153 |
+
return outputs, wrapped_outputs
|
154 |
+
|
155 |
+
def pretty_floats(self, obj):
|
156 |
+
if isinstance(obj, float):
|
157 |
+
return round(obj, 4)
|
158 |
+
elif isinstance(obj, dict):
|
159 |
+
return dict((k, self.pretty_floats(v)) for k, v in obj.items())
|
160 |
+
elif isinstance(obj, (list, tuple)):
|
161 |
+
return list(map(self.pretty_floats, obj))
|
162 |
+
return obj
|
163 |
+
|
164 |
+
def _update_inputs(self, memory, action_inputs):
|
165 |
+
action_inputs = copy.deepcopy(action_inputs)
|
166 |
+
for key, value in action_inputs.items():
|
167 |
+
if "<TOOL-GENERATED>" in value:
|
168 |
+
action_inputs[key] = memory.get(value, value)
|
169 |
+
elif "<GENERATED>" in value:
|
170 |
+
action_inputs[key] = memory.get(value, value)
|
171 |
+
|
172 |
+
return action_inputs
|
173 |
+
|
174 |
+
def gen_filename(self, too_name, resource_type):
|
175 |
+
def to_camelcase(s):
|
176 |
+
res = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), s)
|
177 |
+
res = res[0].upper() + res[1:]
|
178 |
+
return res
|
179 |
+
|
180 |
+
if resource_type == DataType.VIDEO:
|
181 |
+
ext = "mp4"
|
182 |
+
elif resource_type == DataType.AUDIO:
|
183 |
+
ext = "wav"
|
184 |
+
elif resource_type == DataType.HTML:
|
185 |
+
ext = "html"
|
186 |
+
else:
|
187 |
+
ext = "png"
|
188 |
+
too_name = too_name.replace("_to_", "2_")
|
189 |
+
too_name = to_camelcase(too_name)
|
190 |
+
this_file_id = str(uuid.uuid4())[:6]
|
191 |
+
type_str = str(resource_type).split(".")[-1]
|
192 |
+
return f"{this_file_id}_{type_str}.{ext}"
|
193 |
+
|
194 |
+
def _save_resource(self, file_name, resource, resource_type):
|
195 |
+
if isinstance(resource, dict):
|
196 |
+
if "mask" in resource:
|
197 |
+
resource = resource["mask"]
|
198 |
+
if resource_type == DataType.HTML:
|
199 |
+
with open(get_real_path(file_name), "w") as fout:
|
200 |
+
fout.write(resource)
|
201 |
+
elif resource is not None:
|
202 |
+
if isinstance(resource, io.BufferedReader):
|
203 |
+
resource = resource.read()
|
204 |
+
with open(get_real_path(file_name), "wb") as fout:
|
205 |
+
fout.write(resource)
|
206 |
+
else:
|
207 |
+
return None
|
208 |
+
|
209 |
+
def transform_output(
|
210 |
+
self, action_inputs, tool_name, tool_args, tool_output, output_type
|
211 |
+
):
|
212 |
+
if output_type != DataType.MASK:
|
213 |
+
if isinstance(tool_output, list):
|
214 |
+
results = []
|
215 |
+
for output in tool_output:
|
216 |
+
file_name = self.gen_filename(tool_name, output_type)
|
217 |
+
self._save_resource(file_name, output, output_type)
|
218 |
+
results.append(file_name)
|
219 |
+
return results
|
220 |
+
else:
|
221 |
+
file_name = self.gen_filename(tool_name, output_type)
|
222 |
+
self._save_resource(file_name, tool_output, output_type)
|
223 |
+
return file_name
|
224 |
+
|
225 |
+
tool_output = copy.deepcopy(tool_output)
|
226 |
+
if isinstance(tool_output, list):
|
227 |
+
for output in tool_output:
|
228 |
+
if isinstance(output["mask"], str):
|
229 |
+
continue
|
230 |
+
|
231 |
+
file_name = self.gen_filename(tool_name, output_type)
|
232 |
+
self._save_resource(file_name, output, output_type)
|
233 |
+
output["mask"] = file_name
|
234 |
+
elif isinstance(tool_output, bytes):
|
235 |
+
file_name = self.gen_filename(tool_name, output_type)
|
236 |
+
self._save_resource(file_name, tool_output, output_type)
|
237 |
+
tool_output = file_name
|
238 |
+
elif tool_output is None:
|
239 |
+
pass
|
240 |
+
else:
|
241 |
+
raise RuntimeError("Wrong type.")
|
242 |
+
|
243 |
+
return tool_output
|
244 |
+
|
245 |
+
def load_tool(self, name):
|
246 |
+
return self.tools[name]
|
247 |
+
|
248 |
+
def load_args(self, tool: Tool, action_inputs, memory):
|
249 |
+
real_args = {}
|
250 |
+
for item in tool.args:
|
251 |
+
arg_name = item.name
|
252 |
+
arg_value = action_inputs[arg_name]
|
253 |
+
if "<GENERATED>" in arg_value or "<TOOL-GENERATED>" in arg_value:
|
254 |
+
assert arg_value in memory, print(f"Unknown {arg_name}: {arg_value}")
|
255 |
+
real_args[arg_name] = memory[arg_value]
|
256 |
+
else:
|
257 |
+
real_args[arg_name] = arg_value
|
258 |
+
return real_args
|
259 |
+
|
260 |
+
@property
|
261 |
+
def variables(self):
|
262 |
+
return {k: v for k, v in self.memory.items() if k not in TOOLS and k != "print"}
|
cllm/agents/tog/planner.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from typing import List
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
|
7 |
+
sys.path.append(os.getcwd())
|
8 |
+
from ..base import Action, NON_FILE_TYPES
|
9 |
+
|
10 |
+
# from cllm.services.tog import TaskSolver, TaskDecomposer, config
|
11 |
+
# from cllm.services.nlp.llms import ChatOpenAI, MessageMemory
|
12 |
+
from cllm.services.tog.api import tog, task_decomposer
|
13 |
+
from collections import OrderedDict
|
14 |
+
import copy
|
15 |
+
|
16 |
+
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
|
19 |
+
|
20 |
+
class Planner:
|
21 |
+
def __init__(
|
22 |
+
self, streaming=False, backend="remote", device="cuda:0", **llm_kwargs
|
23 |
+
):
|
24 |
+
self.streaming = streaming
|
25 |
+
if backend == "local":
|
26 |
+
pass
|
27 |
+
# self.cfg = config
|
28 |
+
# self.device = device
|
29 |
+
# self.mem = MessageMemory(**self.cfg.memory)
|
30 |
+
# self.llm = ChatOpenAI(temperature=0.2, **llm_kwargs)
|
31 |
+
# self.tog = TaskSolver(self.llm, self.cfg.task_solver_config, device).solve
|
32 |
+
# self.decomposer = TaskDecomposer(device, self.cfg.task_decomposer_cfg).solve
|
33 |
+
elif backend == "remote":
|
34 |
+
self.decomposer = task_decomposer
|
35 |
+
self.tog = tog
|
36 |
+
else:
|
37 |
+
raise ValueError("Backend should be chosen from [remote, local]")
|
38 |
+
|
39 |
+
def _find_latest_resource(self, resources, type):
|
40 |
+
for key, val in list(resources.items())[::-1]:
|
41 |
+
if val == type:
|
42 |
+
return key
|
43 |
+
return None
|
44 |
+
|
45 |
+
def _check_task_decomposition(
|
46 |
+
self, task_decomposition: str | list, available_resources: dict
|
47 |
+
):
|
48 |
+
copy_task_decomposition = copy.deepcopy(task_decomposition)
|
49 |
+
available_resources = copy.deepcopy(available_resources)
|
50 |
+
if isinstance(copy_task_decomposition, str):
|
51 |
+
copy_task_decomposition = json.loads(copy_task_decomposition)
|
52 |
+
|
53 |
+
for subtask in copy_task_decomposition:
|
54 |
+
for arg in subtask["args"]:
|
55 |
+
if arg["type"] in NON_FILE_TYPES:
|
56 |
+
continue
|
57 |
+
|
58 |
+
r_type = available_resources.get(arg["value"], "None").split(".")[-1]
|
59 |
+
if arg["value"] not in available_resources or arg["type"] != r_type:
|
60 |
+
new_value = self._find_latest_resource(
|
61 |
+
available_resources, arg["type"]
|
62 |
+
)
|
63 |
+
if new_value is None:
|
64 |
+
logger.error(
|
65 |
+
f"No available resource for {arg['value']} with type {arg['type']}"
|
66 |
+
)
|
67 |
+
return None
|
68 |
+
|
69 |
+
arg["value"] = new_value
|
70 |
+
|
71 |
+
available_resources[subtask["returns"][0]["value"]] = subtask["returns"][0][
|
72 |
+
"type"
|
73 |
+
]
|
74 |
+
return json.dumps(copy_task_decomposition, indent=2, ensure_ascii=False)
|
75 |
+
|
76 |
+
def wrap_request(self, request, memory):
|
77 |
+
logger.info(memory)
|
78 |
+
resource_list = {k: v.split(".")[-1] for k, v in memory.items()}
|
79 |
+
request = f"Resource list: {resource_list}\n{request}"
|
80 |
+
logger.info(f"Input: {request}")
|
81 |
+
return request
|
82 |
+
|
83 |
+
def solve_streaming(self, request: str, memory: dict = OrderedDict()):
|
84 |
+
request = self.wrap_request(request, memory)
|
85 |
+
sub_tasks = self.decomposer(request, streaming=self.streaming)
|
86 |
+
logger.info(f"Task decomposition: \n{sub_tasks}")
|
87 |
+
sub_tasks = self._check_task_decomposition(sub_tasks, memory)
|
88 |
+
yield sub_tasks
|
89 |
+
if sub_tasks in [None, "", []]:
|
90 |
+
yield None
|
91 |
+
else:
|
92 |
+
solutions = self.tog(request, sub_tasks, streaming=self.streaming)
|
93 |
+
yield solutions
|
94 |
+
|
95 |
+
def solve(self, request: str, memory: dict = OrderedDict()) -> List:
|
96 |
+
self.wrap_request(request, memory)
|
97 |
+
sub_tasks = self.decomposer(request)
|
98 |
+
solutions = self.tog(request, sub_tasks)
|
99 |
+
print(f"solutions: {solutions}")
|
100 |
+
return sub_tasks, solutions
|
101 |
+
|
102 |
+
def plan(self, task, memory: dict = OrderedDict()) -> List:
|
103 |
+
if self.streaming:
|
104 |
+
return self.solve_streaming(task, memory)
|
105 |
+
else:
|
106 |
+
return self.solve(task, memory)
|
107 |
+
|
108 |
+
def _check_solutions(self, solution: List | str) -> bool:
|
109 |
+
if isinstance(solution, str):
|
110 |
+
solution = json.loads(solution)
|
111 |
+
if len(solution) == 0:
|
112 |
+
return False
|
113 |
+
|
114 |
+
valid = True
|
115 |
+
for i, stage_candiate in enumerate(solution):
|
116 |
+
if len(stage_candiate) == 0:
|
117 |
+
logger.error(f"No solution is found in {i}-th subtask.")
|
118 |
+
valid = False
|
119 |
+
elif (
|
120 |
+
"solution" in stage_candiate[0]
|
121 |
+
and len(stage_candiate[0]["solution"]) == 0
|
122 |
+
):
|
123 |
+
logger.error(f"No solution is found in {i+1}-th subtask.")
|
124 |
+
valid = False
|
125 |
+
else:
|
126 |
+
logger.info(f"Solutions for {i+1}-th subtask:\n{stage_candiate}")
|
127 |
+
return valid
|
128 |
+
|
129 |
+
def parse(self, solution: List | str) -> List[List[Action]]:
|
130 |
+
if isinstance(solution, str):
|
131 |
+
solution = json.loads(solution)
|
132 |
+
|
133 |
+
if not self._check_solutions(solution):
|
134 |
+
return None
|
135 |
+
|
136 |
+
if isinstance(solution[0], Action):
|
137 |
+
return solution
|
138 |
+
|
139 |
+
stages = []
|
140 |
+
for i, stage_candiate in enumerate(solution):
|
141 |
+
stage = stage_candiate[0]["solution"]
|
142 |
+
actions = []
|
143 |
+
for action in stage:
|
144 |
+
inputs = {arg["name"]: arg["value"] for arg in action["args"]}
|
145 |
+
outputs = [r["value"] for r in action["returns"]]
|
146 |
+
actions.append(
|
147 |
+
Action(action["tool_name"], inputs=inputs, outputs=outputs)
|
148 |
+
)
|
149 |
+
stages.append(actions)
|
150 |
+
return stages
|
151 |
+
|
152 |
+
def __call__(
|
153 |
+
self, request: str, memory: dict = OrderedDict()
|
154 |
+
) -> List[List[Action]]:
|
155 |
+
solution = self.solve(request, memory)
|
156 |
+
return self.parse(solution)
|
cllm/agents/tog/responser.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import openai
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
|
6 |
+
sys.path.append(os.getcwd())
|
7 |
+
from cllm.services.nlp.llms.chat_models import ChatOpenAI
|
8 |
+
|
9 |
+
# from cllm.services.nlp.llms.memory import MessageMemory
|
10 |
+
from langchain.schema import SystemMessage
|
11 |
+
|
12 |
+
from cllm.agents.base import DataType
|
13 |
+
from cllm.agents import container
|
14 |
+
|
15 |
+
|
16 |
+
RESPONSE_GENERATION_PROMPT = """Your name is ControlLLM, an AI-powered assistant developed by OpenGV-lab from Shanghai Artificial Intelligence Laboratory. For user's request, the system executes the solution and collects the results based on the following workflow. You need to respond to user requests based on the following information.
|
17 |
+
Here are the information for you reference.
|
18 |
+
|
19 |
+
## User Request
|
20 |
+
{}
|
21 |
+
|
22 |
+
## Workflow and Execution Results
|
23 |
+
{}
|
24 |
+
|
25 |
+
Now you should pay attention to Collected Results. You first must answer the user’s request in a straightforward manner. Then you need to summarize the workflow and intermediate results friendly. Some of the results may not be accurate and need you to use your judgement in making decisions. If the results contain file names, you have to output the file name directly. Only if there is nothing returned by tools, you should tell user you can not finish the task. Now, please friendly summarize the results and answer the question for the user requests `{}`.
|
26 |
+
""".strip()
|
27 |
+
|
28 |
+
|
29 |
+
SIMPLE_RESPONSE_GENERATION_PROMPT = """Your name is ControlLLM, an AI-powered assistant developed by OpenGVLab from Shanghai Artificial Intelligence Laboratory. You need to respond to user requests based on the following information.
|
30 |
+
Here are the information for you reference.
|
31 |
+
|
32 |
+
## User Request
|
33 |
+
{}
|
34 |
+
|
35 |
+
## Workflow and Execution Results
|
36 |
+
{}
|
37 |
+
|
38 |
+
Now, please friendly summarize the results and answer the question for the user requests `{}`.
|
39 |
+
""".strip()
|
40 |
+
|
41 |
+
logger = logging.getLogger(__name__)
|
42 |
+
|
43 |
+
|
44 |
+
def generate_response(user_input, solution, output_files):
|
45 |
+
if (
|
46 |
+
len(solution) <= 1
|
47 |
+
and len(solution[0]) <= 1
|
48 |
+
and solution[0][0].tool_name == "question_answering"
|
49 |
+
):
|
50 |
+
content = SIMPLE_RESPONSE_GENERATION_PROMPT.format(
|
51 |
+
user_input, solution, user_input
|
52 |
+
)
|
53 |
+
else:
|
54 |
+
content = RESPONSE_GENERATION_PROMPT.format(user_input, solution, user_input)
|
55 |
+
|
56 |
+
logger.info("##### Response Generation #####")
|
57 |
+
logger.info(content)
|
58 |
+
|
59 |
+
chat = ChatOpenAI(model_name="gpt-3.5-turbo-1106")
|
60 |
+
messages = [SystemMessage(content=content)]
|
61 |
+
output = chat(messages)
|
62 |
+
logger.info(output)
|
63 |
+
|
64 |
+
# files = [output for output in output_files if isinstance(output, container.File)]
|
65 |
+
# return [container.Text('Response', DataType.TEXT, output)] + files
|
66 |
+
return [container.Text("Response", DataType.TEXT, output)]
|
cllm/services/audio/__init__.py
ADDED
File without changes
|
cllm/services/audio/api.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
import uuid
|
4 |
+
import requests
|
5 |
+
|
6 |
+
from cllm.services.nlp.api import openai_chat_model
|
7 |
+
from cllm.services.utils import get_bytes_value
|
8 |
+
|
9 |
+
__ALL__ = [
|
10 |
+
"audio_classification",
|
11 |
+
"automatic_speech_recognition",
|
12 |
+
"text_to_speech",
|
13 |
+
]
|
14 |
+
|
15 |
+
|
16 |
+
HOST = os.environ.get("CLLM_SERVICES_HOST", "localhost")
|
17 |
+
PORT = os.environ.get("CLLM_SERVICES_PORT", 10056)
|
18 |
+
|
19 |
+
|
20 |
+
def setup(host="localhost", port=10057):
|
21 |
+
global HOST, PORT
|
22 |
+
HOST = host
|
23 |
+
PORT = port
|
24 |
+
|
25 |
+
|
26 |
+
def audio_classification(audio, **kwargs):
|
27 |
+
host = kwargs.get("host", HOST)
|
28 |
+
port = kwargs.get("port", PORT)
|
29 |
+
url = f"http://{host}:{port}/audio_classification"
|
30 |
+
if isinstance(audio, str):
|
31 |
+
audio = open(audio, "rb").read()
|
32 |
+
files = {"audio": (audio, get_bytes_value(audio))}
|
33 |
+
response = requests.post(url, files=files)
|
34 |
+
return response.json()
|
35 |
+
|
36 |
+
|
37 |
+
def automatic_speech_recognition(audio: str, **kwargs):
|
38 |
+
host = kwargs.get("host", HOST)
|
39 |
+
port = kwargs.get("port", PORT)
|
40 |
+
url = f"http://{host}:{port}/automatic_speech_recognition"
|
41 |
+
# audio_file = open(audio, "rb")
|
42 |
+
files = {"audio": (audio, get_bytes_value(audio))}
|
43 |
+
response = requests.post(url, files=files)
|
44 |
+
return response.json()
|
45 |
+
|
46 |
+
|
47 |
+
def text_to_speech(text: str, **kwargs):
|
48 |
+
host = kwargs.get("host", HOST)
|
49 |
+
port = kwargs.get("port", PORT)
|
50 |
+
human_msg = f"""Your task is to extract the prompt from input. Here is examples:
|
51 |
+
|
52 |
+
Input:
|
53 |
+
translate the text into speech: \"Hope is the thing with feathers That perches in the soul, And sings the tune without the words, And never stops at all\"
|
54 |
+
|
55 |
+
Answer:
|
56 |
+
Hope is the thing with feathers That perches in the soul, And sings the tune without the words, And never stops at all
|
57 |
+
|
58 |
+
Input:
|
59 |
+
Can you help me transcribe the text into audio: I have a dream that one day this nation will rise up and live out the true meaning of its creed: We hold these truths to be self-evident, that all men are created equal.I have a dream that one day on the red hills of Georgia, the sons of former slaves and the sons of former slave owners will be able to sit down together at the table of brotherhood. I have a dream that one day even the state of Mississippi, a state sweltering with the heat of injustice, sweltering with the heat of oppression, will be transformed into an oasis of freedom and justice. I have a dream that my four little children will one day live in a nation where they will not be judged by the color of their skin but by the content of their character.
|
60 |
+
|
61 |
+
Answer:
|
62 |
+
I have a dream that one day this nation will rise up and live out the true meaning of its creed: We hold these truths to be self-evident, that all men are created equal.I have a dream that one day on the red hills of Georgia, the sons of former slaves and the sons of former slave owners will be able to sit down together at the table of brotherhood. I have a dream that one day even the state of Mississippi, a state sweltering with the heat of injustice, sweltering with the heat of oppression, will be transformed into an oasis of freedom and justice. I have a dream that my four little children will one day live in a nation where they will not be judged by the color of their skin but by the content of their character.
|
63 |
+
|
64 |
+
Input:
|
65 |
+
Create speech using the text: And so, my fellow Americans: ask not what your country can do for you — ask what you can do for your country.
|
66 |
+
|
67 |
+
Answer:
|
68 |
+
And so, my fellow Americans: ask not what your country can do for you — ask what you can do for your country.
|
69 |
+
|
70 |
+
Input:
|
71 |
+
The image features a large brown and white dog standing on a tree stump, accompanied by a small cat. The dog is positioned on the right side of the stump, while the cat is on the left side. Both animals appear to be looking at the camera, creating a captivating scene.\n\nThe dog and cat are the main focus of the image, with the dog being larger and more prominent, while the cat is smaller and positioned closer to the ground. The tree stump serves as a natural and interesting backdrop for the two animals, making the scene unique and engaging.
|
72 |
+
|
73 |
+
Answer:
|
74 |
+
The image features a large brown and white dog standing on a tree stump, accompanied by a small cat. The dog is positioned on the right side of the stump, while the cat is on the left side. Both animals appear to be looking at the camera, creating a captivating scene.\n\nThe dog and cat are the main focus of the image, with the dog being larger and more prominent, while the cat is smaller and positioned closer to the ground. The tree stump serves as a natural and interesting backdrop for the two animals, making the scene unique and engaging.
|
75 |
+
|
76 |
+
Input:
|
77 |
+
Life, thin and light-off time and time again\nFrivolous tireless\nI heard the echo, from the valleys and the heart\nOpen to the lonely soul of sickle harvesting\nRepeat outrightly, but also repeat the well-being of eventually swaying in the desert oasis\nI believe I am\nBorn as the bright summer flowers\nDo not withered undefeated fiery demon rule\nHeart rate and breathing to bear the load of the cumbersome Bored\nI heard the music, from the moon and carcass\nAuxiliary extreme aestheticism bait to capture misty\nFilling the intense life, but also filling the pure\nThere are always memories throughout the earth
|
78 |
+
|
79 |
+
Answer:
|
80 |
+
Life, thin and light-off time and time again\nFrivolous tireless\nI heard the echo, from the valleys and the heart\nOpen to the lonely soul of sickle harvesting\nRepeat outrightly, but also repeat the well-being of eventually swaying in the desert oasis\nI believe I am\nBorn as the bright summer flowers\nDo not withered undefeated fiery demon rule\nHeart rate and breathing to bear the load of the cumbersome Bored\nI heard the music, from the moon and carcass\nAuxiliary extreme aestheticism bait to capture misty\nFilling the intense life, but also filling the pure\nThere are always memories throughout the earth
|
81 |
+
|
82 |
+
Input:
|
83 |
+
{text}
|
84 |
+
|
85 |
+
Answer:
|
86 |
+
"""
|
87 |
+
extracted_prompt = openai_chat_model(human_msg)
|
88 |
+
print(f"extracted_prompt: {extracted_prompt}")
|
89 |
+
url = f"http://{host}:{port}/text_to_speech"
|
90 |
+
data = {"text": extracted_prompt}
|
91 |
+
response = requests.post(url, data=data)
|
92 |
+
return response.content
|
93 |
+
|
94 |
+
|
95 |
+
def text_to_music(text: str, **kwargs):
|
96 |
+
# print('a' * 40)
|
97 |
+
host = kwargs.get("host", HOST)
|
98 |
+
port = kwargs.get("port", PORT)
|
99 |
+
human_msg = f"""Your task is to extract the prompt from input. Here is examples:
|
100 |
+
|
101 |
+
Input:
|
102 |
+
Please generate a piece of music based on given prompt. Here is the prompt: An 80s driving pop song with heavy drums
|
103 |
+
|
104 |
+
Answer:
|
105 |
+
An 80s driving pop song with heavy drums
|
106 |
+
|
107 |
+
Input:
|
108 |
+
I would like you to provide me with a new song that represents an energetic and lively 80s pop track with prominent drums and synthesizer pads
|
109 |
+
|
110 |
+
Answer:
|
111 |
+
an energetic and lively 80s pop track with prominent drums and synthesizer pads
|
112 |
+
|
113 |
+
Input:
|
114 |
+
I'm looking for a song that has a driving pop vibe from the 80s, with heavy drums and synth pads playing in the background
|
115 |
+
|
116 |
+
Answer:
|
117 |
+
a driving pop vibe from the 80s, with heavy drums and synth pads playing in the background
|
118 |
+
|
119 |
+
Input:
|
120 |
+
Can you make a song that has a lively and energetic rhythm with prominent drums and electronic keyboard sounds in the background
|
121 |
+
|
122 |
+
Answer:
|
123 |
+
a lively and energetic rhythm with prominent drums and electronic keyboard sounds in the background
|
124 |
+
|
125 |
+
Input:
|
126 |
+
Can you make a piece of light and relaxing music
|
127 |
+
|
128 |
+
Answer:
|
129 |
+
a piece of light and relaxing music
|
130 |
+
|
131 |
+
Input:
|
132 |
+
{text}
|
133 |
+
|
134 |
+
Answer:
|
135 |
+
"""
|
136 |
+
extracted_prompt = openai_chat_model(human_msg)
|
137 |
+
url = f"http://{host}:{port}/text_to_music"
|
138 |
+
data = {"text": extracted_prompt}
|
139 |
+
response = requests.post(url, data=data)
|
140 |
+
return response.content
|
cllm/services/general/__init__.py
ADDED
File without changes
|
cllm/services/general/api.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from re import I
|
2 |
+
from typing import List
|
3 |
+
from pathlib import Path
|
4 |
+
import os
|
5 |
+
import requests
|
6 |
+
|
7 |
+
__ALL__ = ["remote_logging", "select", "count"]
|
8 |
+
|
9 |
+
HOST = os.environ.get("CLLM_SERVICES_HOST", "localhost")
|
10 |
+
PORT = os.environ.get("CLLM_SERVICES_PORT", 10056)
|
11 |
+
|
12 |
+
|
13 |
+
def setup(host="localhost", port=10056):
|
14 |
+
global HOST, PORT
|
15 |
+
HOST = host
|
16 |
+
PORT = port
|
17 |
+
|
18 |
+
|
19 |
+
def select(**kwargs):
|
20 |
+
if "bbox_list" in kwargs:
|
21 |
+
list = kwargs["bbox_list"]
|
22 |
+
condition = kwargs["condition"]
|
23 |
+
return [l for l in list if l["label"] == condition]
|
24 |
+
if "mask_list" in kwargs:
|
25 |
+
list = kwargs["mask_list"]
|
26 |
+
condition = kwargs["condition"]
|
27 |
+
# return combine_masks([l for l in list if l['label'] == condition])
|
28 |
+
return [l for l in list if l["label"] == condition]
|
29 |
+
if "category_list" in kwargs:
|
30 |
+
list = kwargs["category_list"]
|
31 |
+
condition = kwargs["condition"]
|
32 |
+
# return combine_masks([l for l in list if l['label'] == condition])
|
33 |
+
return [l for l in list if l["label"] == condition]
|
34 |
+
|
35 |
+
|
36 |
+
def count(**kwargs):
|
37 |
+
len_of_list = 0
|
38 |
+
if "bbox_list" in kwargs:
|
39 |
+
len_of_list = len(kwargs["bbox_list"])
|
40 |
+
elif "mask_list" in kwargs:
|
41 |
+
len_of_list = len(kwargs["mask_list"])
|
42 |
+
|
43 |
+
return f"The length of the given list is {len_of_list}"
|
44 |
+
|
45 |
+
|
46 |
+
def remote_logging(
|
47 |
+
history_msgs: list,
|
48 |
+
task_decomposition: list,
|
49 |
+
solution: list,
|
50 |
+
record: str,
|
51 |
+
like: bool,
|
52 |
+
**kwargs,
|
53 |
+
):
|
54 |
+
host = kwargs.get("host", HOST)
|
55 |
+
port = kwargs.get("port", PORT)
|
56 |
+
url = f"http://{host}:{port}/remote_logging"
|
57 |
+
data = {
|
58 |
+
"history_msgs": history_msgs,
|
59 |
+
"task_decomposition": task_decomposition,
|
60 |
+
"solution": solution,
|
61 |
+
"record": record,
|
62 |
+
"like": like,
|
63 |
+
}
|
64 |
+
response = requests.post(url, data=data)
|
65 |
+
return response.content
|
cllm/services/image_editing/__init__.py
ADDED
File without changes
|
cllm/services/image_editing/api.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import io
|
3 |
+
import os
|
4 |
+
from PIL import Image, ImageDraw, ImageChops
|
5 |
+
import numpy as np
|
6 |
+
import requests
|
7 |
+
from PIL import Image
|
8 |
+
from typing import List, Union
|
9 |
+
from pathlib import Path
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
|
13 |
+
sys.path.append(os.getcwd())
|
14 |
+
from cllm.services.utils import get_bytes_value
|
15 |
+
from cllm.utils import get_real_path
|
16 |
+
from cllm.services.nlp.api import openai_chat_model
|
17 |
+
|
18 |
+
__ALL__ = [
|
19 |
+
"instruct_pix2pix",
|
20 |
+
"image_cropping",
|
21 |
+
"image_matting",
|
22 |
+
"draw_bbox_on_image",
|
23 |
+
"partial_image_editing",
|
24 |
+
]
|
25 |
+
|
26 |
+
|
27 |
+
HOST = os.environ.get("CLLM_SERVICES_HOST", "localhost")
|
28 |
+
PORT = os.environ.get("CLLM_SERVICES_PORT", 10056)
|
29 |
+
|
30 |
+
|
31 |
+
def setup(host="localhost", port=10049):
|
32 |
+
global HOST, PORT
|
33 |
+
HOST = host
|
34 |
+
PORT = port
|
35 |
+
|
36 |
+
|
37 |
+
def image_cropping(image: str | Path, object: List[dict], **kwargs):
|
38 |
+
"""
|
39 |
+
bbox format: {'score': 0.997, 'label': 'bird', 'box': {'xmin': 69, 'ymin': 171, 'xmax': 396, 'ymax': 507}}
|
40 |
+
"""
|
41 |
+
if object in [None, b"", []]:
|
42 |
+
return None
|
43 |
+
|
44 |
+
if isinstance(image, (str, Path)):
|
45 |
+
image = Image.open(get_real_path(image)).convert("RGB")
|
46 |
+
elif isinstance(image, bytes):
|
47 |
+
image = Image.open(io.BytesIO(image)).convert("RGB")
|
48 |
+
w, h = image.size
|
49 |
+
cropped_images = []
|
50 |
+
for box in object:
|
51 |
+
box = copy.deepcopy(box["box"])
|
52 |
+
box = unify_bbox(box, w, h)
|
53 |
+
(left, upper, right, lower) = (
|
54 |
+
box["xmin"],
|
55 |
+
box["ymin"],
|
56 |
+
box["xmax"],
|
57 |
+
box["ymax"],
|
58 |
+
)
|
59 |
+
cropped_image = image.crop((left, upper, right, lower))
|
60 |
+
# cropped_image.save('test.png')
|
61 |
+
img_stream = io.BytesIO()
|
62 |
+
cropped_image.save(img_stream, format="png")
|
63 |
+
img_stream.seek(0)
|
64 |
+
cropped_images.append(img_stream.getvalue())
|
65 |
+
if len(cropped_images) == 0:
|
66 |
+
return None
|
67 |
+
return cropped_images
|
68 |
+
|
69 |
+
|
70 |
+
def image_matting(image: str | Path, mask: Union[str, bytes, List], **kwargs):
|
71 |
+
"""
|
72 |
+
{'score': 0.999025,
|
73 |
+
'label': 'person',
|
74 |
+
'mask': <PIL.Image.Image image mode=L size=386x384>}
|
75 |
+
"""
|
76 |
+
if mask in [None, b"", []]:
|
77 |
+
return None
|
78 |
+
image = Image.open(get_bytes_value(image)).convert("RGB")
|
79 |
+
|
80 |
+
mask = copy.deepcopy(mask)
|
81 |
+
if isinstance(mask, List):
|
82 |
+
mask_list = []
|
83 |
+
for m in mask:
|
84 |
+
if isinstance(m, dict):
|
85 |
+
mask_list.append(get_bytes_value(m["mask"]))
|
86 |
+
else:
|
87 |
+
mask_list.append(get_bytes_value(m))
|
88 |
+
mask = combine_masks(mask_list)
|
89 |
+
elif isinstance(mask, str):
|
90 |
+
mask = get_bytes_value(mask)
|
91 |
+
|
92 |
+
mask = Image.open(mask).convert("L")
|
93 |
+
|
94 |
+
mask = np.array(mask) > 0
|
95 |
+
image = np.array(image)
|
96 |
+
image = image * np.expand_dims(mask, -1)
|
97 |
+
img_stream = io.BytesIO()
|
98 |
+
image.save(img_stream, format="png")
|
99 |
+
img_stream.seek(0)
|
100 |
+
return img_stream.getvalue()
|
101 |
+
|
102 |
+
|
103 |
+
def unify_bbox(bbox, w, h):
|
104 |
+
bbox["xmin"] = (
|
105 |
+
bbox["xmin"] if isinstance(bbox["xmin"], int) else int(bbox["xmin"] * w)
|
106 |
+
)
|
107 |
+
|
108 |
+
bbox["ymin"] = (
|
109 |
+
bbox["ymin"] if isinstance(bbox["ymin"], int) else int(bbox["ymin"] * h)
|
110 |
+
)
|
111 |
+
bbox["xmax"] = (
|
112 |
+
bbox["xmax"] if isinstance(bbox["xmax"], int) else int(bbox["xmax"] * w)
|
113 |
+
)
|
114 |
+
bbox["ymax"] = (
|
115 |
+
bbox["ymax"] if isinstance(bbox["ymax"], int) else int(bbox["ymax"] * h)
|
116 |
+
)
|
117 |
+
return bbox
|
118 |
+
|
119 |
+
|
120 |
+
def draw_bbox_on_image(image: str | Path, bbox: list, **kwargs):
|
121 |
+
if isinstance(image, (str, Path)):
|
122 |
+
image = Image.open(get_real_path(image)).convert("RGB")
|
123 |
+
elif isinstance(image, bytes):
|
124 |
+
image = Image.open(io.BytesIO(image)).convert("RGB")
|
125 |
+
image = image.copy()
|
126 |
+
w, h = image.size
|
127 |
+
for box in bbox:
|
128 |
+
box = copy.deepcopy(box["box"])
|
129 |
+
box = unify_bbox(box, w, h)
|
130 |
+
(left, upper, right, lower) = (
|
131 |
+
box["xmin"],
|
132 |
+
box["ymin"],
|
133 |
+
box["xmax"],
|
134 |
+
box["ymax"],
|
135 |
+
)
|
136 |
+
draw = ImageDraw.Draw(image)
|
137 |
+
font_width = int(
|
138 |
+
min(box["xmax"] - box["xmin"], box["ymax"] - box["ymin"]) * 0.01
|
139 |
+
)
|
140 |
+
draw.rectangle(((left, upper), (right, lower)), outline="Red", width=font_width)
|
141 |
+
img_stream = io.BytesIO()
|
142 |
+
image.save(img_stream, format="png")
|
143 |
+
img_stream.seek(0)
|
144 |
+
# image = Image.save(image, format='png')
|
145 |
+
return img_stream.getvalue()
|
146 |
+
|
147 |
+
|
148 |
+
def _imagetext2image(image, text, endpoint, **kwargs):
|
149 |
+
host = kwargs.get("host", HOST)
|
150 |
+
port = kwargs.get("port", PORT)
|
151 |
+
url = f"http://{host}:{port}/{endpoint}"
|
152 |
+
data = {"text": text}
|
153 |
+
files = {"image": (image, get_bytes_value(image))}
|
154 |
+
response = requests.post(url, files=files, data=data)
|
155 |
+
return response.content
|
156 |
+
|
157 |
+
|
158 |
+
def instruct_pix2pix(image, text, **kwargs):
|
159 |
+
return _imagetext2image(image, text, endpoint="instruct_pix2pix", **kwargs)
|
160 |
+
|
161 |
+
|
162 |
+
def partial_image_editing(
|
163 |
+
image: str | bytes, mask: str | list | bytes, prompt: str, **kwargs
|
164 |
+
):
|
165 |
+
if mask in [None, b"", []]:
|
166 |
+
return None
|
167 |
+
|
168 |
+
host = kwargs.get("host", HOST)
|
169 |
+
port = kwargs.get("port", PORT)
|
170 |
+
url = f"http://{host}:{port}/partial_image_editing"
|
171 |
+
human_msg = f"""Your task is to extract the prompt from input. Here is examples:
|
172 |
+
|
173 |
+
Input:
|
174 |
+
Replace the masked object in the given image with a yellow horse
|
175 |
+
|
176 |
+
Answer:
|
177 |
+
a yellow horse
|
178 |
+
|
179 |
+
Input:
|
180 |
+
Use the c1s5af_mask.png in to replace the object with a man in the image
|
181 |
+
|
182 |
+
Answer:
|
183 |
+
a man
|
184 |
+
|
185 |
+
Input:
|
186 |
+
Modify the given image by replacing the object indicated in the mask with a bouquet of flowers
|
187 |
+
|
188 |
+
Answer:
|
189 |
+
with a bouquet of flowers
|
190 |
+
|
191 |
+
Input:
|
192 |
+
Use the 7a3c72_mask.png file to replace the object in the a9430b_image.png with a bus colored yellow and red with the number 5 on its front sign
|
193 |
+
|
194 |
+
Answer:
|
195 |
+
a bus colored yellow and red with the number 5 on its front sign.
|
196 |
+
|
197 |
+
Input:
|
198 |
+
Replace the masked area in image with a fat boy wearing a black jacket.
|
199 |
+
|
200 |
+
Answer:
|
201 |
+
a fat boy wearing a black jacket
|
202 |
+
|
203 |
+
Input:
|
204 |
+
{prompt}
|
205 |
+
|
206 |
+
Answer:
|
207 |
+
"""
|
208 |
+
extracted_prompt = openai_chat_model(human_msg)
|
209 |
+
data = {"prompt": extracted_prompt}
|
210 |
+
if isinstance(mask, List):
|
211 |
+
mask_list = []
|
212 |
+
for m in mask:
|
213 |
+
if isinstance(m, dict):
|
214 |
+
mask_list.append(get_bytes_value(m["mask"]))
|
215 |
+
else:
|
216 |
+
mask_list.append(get_bytes_value(m))
|
217 |
+
mask = combine_masks(mask_list)
|
218 |
+
|
219 |
+
files = {
|
220 |
+
"image": (image, get_bytes_value(image)),
|
221 |
+
"mask": ("mask", get_bytes_value(mask)),
|
222 |
+
}
|
223 |
+
response = requests.post(url, files=files, data=data)
|
224 |
+
return response.content
|
225 |
+
|
226 |
+
|
227 |
+
def combine_masks(mask_images):
|
228 |
+
if mask_images is None or len(mask_images) == 0:
|
229 |
+
return None
|
230 |
+
|
231 |
+
# Create a new blank image to store the combined mask
|
232 |
+
combined_mask = Image.open(io.BytesIO(mask_images[0])).convert("1")
|
233 |
+
|
234 |
+
# Iterate through each mask image and combine them
|
235 |
+
for mask_image in mask_images:
|
236 |
+
mask = Image.open(io.BytesIO(mask_image)).convert("1")
|
237 |
+
combined_mask = ImageChops.logical_or(combined_mask, mask)
|
238 |
+
stream = io.BytesIO()
|
239 |
+
combined_mask.save(stream, "png")
|
240 |
+
stream.seek(0)
|
241 |
+
# return {"label": mask_images[0]["label"], "mask": stream.getvalue()}
|
242 |
+
return stream.getvalue()
|
243 |
+
|
244 |
+
|
245 |
+
def inpainting_ldm_general(image, mask: Union[str, bytes, List], **kwargs):
|
246 |
+
if mask in [None, b"", []]:
|
247 |
+
return get_bytes_value(image)
|
248 |
+
|
249 |
+
mask = copy.deepcopy(mask)
|
250 |
+
if isinstance(mask, List):
|
251 |
+
mask_list = []
|
252 |
+
for m in mask:
|
253 |
+
if isinstance(m, dict):
|
254 |
+
mask_list.append(get_bytes_value(m["mask"]))
|
255 |
+
else:
|
256 |
+
mask_list.append(get_bytes_value(m))
|
257 |
+
mask = combine_masks(mask_list)
|
258 |
+
elif isinstance(mask, str):
|
259 |
+
mask = get_bytes_value(mask)
|
260 |
+
# mask = Image.open(mask).convert("1")
|
261 |
+
|
262 |
+
return inpainting_ldm(image, mask, **kwargs)
|
263 |
+
|
264 |
+
|
265 |
+
def inpainting_ldm(image, mask, **kwargs):
|
266 |
+
if mask in [None, b""]:
|
267 |
+
return get_bytes_value(image)
|
268 |
+
|
269 |
+
host = kwargs.get("host", HOST)
|
270 |
+
port = kwargs.get("port", PORT)
|
271 |
+
url = f"http://{host}:{port}/inpainting_ldm"
|
272 |
+
files = {
|
273 |
+
"image": (image, get_bytes_value(image)),
|
274 |
+
"mask": get_bytes_value(mask),
|
275 |
+
}
|
276 |
+
response = requests.post(url, files=files)
|
277 |
+
return response.content
|
cllm/services/image_generation/__init__.py
ADDED
File without changes
|
cllm/services/image_generation/api.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
|
4 |
+
import requests
|
5 |
+
|
6 |
+
import sys
|
7 |
+
|
8 |
+
sys.path.append(os.getcwd())
|
9 |
+
from PIL import Image
|
10 |
+
from cllm.services.utils import get_bytes_value
|
11 |
+
|
12 |
+
|
13 |
+
__ALL__ = [
|
14 |
+
"text2image",
|
15 |
+
"cannytext2image",
|
16 |
+
"linetext2image",
|
17 |
+
"hedtext2image",
|
18 |
+
"scribbletext2image",
|
19 |
+
"posetext2image",
|
20 |
+
"segtext2image",
|
21 |
+
"depthtext2image",
|
22 |
+
"normaltext2image" "image2image",
|
23 |
+
]
|
24 |
+
|
25 |
+
|
26 |
+
HOST = os.environ.get("CLLM_SERVICES_HOST", "localhost")
|
27 |
+
PORT = os.environ.get("CLLM_SERVICES_PORT", 10056)
|
28 |
+
|
29 |
+
|
30 |
+
def setup(host="localhost", port=10049):
|
31 |
+
global HOST, PORT
|
32 |
+
HOST = host
|
33 |
+
PORT = port
|
34 |
+
|
35 |
+
|
36 |
+
def text2image(text, **kwargs):
|
37 |
+
host = kwargs.get("host", HOST)
|
38 |
+
port = kwargs.get("port", PORT)
|
39 |
+
url = f"http://{host}:{port}/text2image"
|
40 |
+
data = {"text": text}
|
41 |
+
response = requests.post(url, data=data)
|
42 |
+
return response.content
|
43 |
+
|
44 |
+
|
45 |
+
def image2image(image, **kwargs):
|
46 |
+
host = kwargs.get("host", HOST)
|
47 |
+
port = kwargs.get("port", PORT)
|
48 |
+
url = f"http://{host}:{port}/image2image"
|
49 |
+
files = {"image": (image, get_bytes_value(image))}
|
50 |
+
response = requests.post(url, files=files)
|
51 |
+
return response.content
|
52 |
+
|
53 |
+
|
54 |
+
def _imagetext2image(image, text, endpoint, **kwargs):
|
55 |
+
host = kwargs.get("host", HOST)
|
56 |
+
port = kwargs.get("port", PORT)
|
57 |
+
url = f"http://{host}:{port}/{endpoint}"
|
58 |
+
data = {"text": text}
|
59 |
+
files = {"image": (image, get_bytes_value(image))}
|
60 |
+
response = requests.post(url, files=files, data=data)
|
61 |
+
# image = Image.open(io.BytesIO(response.content))
|
62 |
+
# image = io.BytesIO(response.content)
|
63 |
+
# return image
|
64 |
+
return response.content
|
65 |
+
|
66 |
+
|
67 |
+
def cannytext2image(edge, text, **kwargs):
|
68 |
+
return _imagetext2image(edge, text, endpoint="cannytext2image", **kwargs)
|
69 |
+
|
70 |
+
|
71 |
+
def linetext2image(line, text, **kwargs):
|
72 |
+
return _imagetext2image(line, text, endpoint="linetext2image", **kwargs)
|
73 |
+
|
74 |
+
|
75 |
+
def hedtext2image(hed, text, **kwargs):
|
76 |
+
return _imagetext2image(hed, text, endpoint="hedtext2image", **kwargs)
|
77 |
+
|
78 |
+
|
79 |
+
def scribbletext2image(scribble, text, **kwargs):
|
80 |
+
return _imagetext2image(scribble, text, endpoint="scribbletext2image", **kwargs)
|
81 |
+
|
82 |
+
|
83 |
+
def posetext2image(pose, text, **kwargs):
|
84 |
+
return _imagetext2image(pose, text, endpoint="posetext2image", **kwargs)
|
85 |
+
|
86 |
+
|
87 |
+
def segtext2image(segmentation, text, **kwargs):
|
88 |
+
return _imagetext2image(segmentation, text, endpoint="segtext2image", **kwargs)
|
89 |
+
|
90 |
+
|
91 |
+
def depthtext2image(depth, text, **kwargs):
|
92 |
+
return _imagetext2image(depth, text, endpoint="depthtext2image", **kwargs)
|
93 |
+
|
94 |
+
|
95 |
+
def normaltext2image(normal, text, **kwargs):
|
96 |
+
return _imagetext2image(normal, text, endpoint="normaltext2image", **kwargs)
|
cllm/services/image_inpainting/__init__.py
ADDED
File without changes
|
cllm/services/image_inpainting/api.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from typing import Union, List, Dict
|
3 |
+
from PIL import Image, ImageChops
|
4 |
+
import io
|
5 |
+
import os
|
6 |
+
|
7 |
+
import requests
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
|
11 |
+
sys.path.append(os.getcwd())
|
12 |
+
from cllm.servcies.utils import get_bytes_value
|
13 |
+
|
14 |
+
__ALL__ = [
|
15 |
+
"inpainting_ldm",
|
16 |
+
]
|
17 |
+
|
18 |
+
|
19 |
+
HOST = os.environ.get("CLLM_SERVICES_HOST", "localhost")
|
20 |
+
PORT = os.environ.get("CLLM_SERVICES_PORT", 10056)
|
21 |
+
|
22 |
+
|
23 |
+
def setup(host="localhost", port=10052):
|
24 |
+
global HOST, PORT
|
25 |
+
HOST = host
|
26 |
+
PORT = port
|
27 |
+
|
28 |
+
|
29 |
+
def combine_masks(mask_images):
|
30 |
+
if mask_images is None or len(mask_images) == 0:
|
31 |
+
return None
|
32 |
+
|
33 |
+
# Create a new blank image to store the combined mask
|
34 |
+
combined_mask = Image.open(io.BytesIO(mask_images[0])).convert("1")
|
35 |
+
|
36 |
+
# Iterate through each mask image and combine them
|
37 |
+
for mask_image in mask_images:
|
38 |
+
mask = Image.open(io.BytesIO(mask_image)).convert("1")
|
39 |
+
combined_mask = ImageChops.logical_or(combined_mask, mask)
|
40 |
+
stream = io.BytesIO()
|
41 |
+
combined_mask.save(stream, "png")
|
42 |
+
stream.seek(0)
|
43 |
+
# return {"label": mask_images[0]["label"], "mask": stream.getvalue()}
|
44 |
+
return stream.getvalue()
|
45 |
+
|
46 |
+
|
47 |
+
def inpainting_ldm_general(image, mask: Union[bytes, List], **kwargs):
|
48 |
+
if mask in [None, b"", []]:
|
49 |
+
return get_bytes_value(image)
|
50 |
+
|
51 |
+
mask = copy.deepcopy(mask)
|
52 |
+
if isinstance(mask, List):
|
53 |
+
if not isinstance(mask[0], dict):
|
54 |
+
mask_list = get_bytes_value(mask)
|
55 |
+
else:
|
56 |
+
mask_list = []
|
57 |
+
for m in mask:
|
58 |
+
mask_list.append(get_bytes_value(m["mask"]))
|
59 |
+
mask = combine_masks(mask_list)
|
60 |
+
|
61 |
+
return inpainting_ldm(image, mask, **kwargs)
|
62 |
+
|
63 |
+
|
64 |
+
def inpainting_ldm(image, mask, **kwargs):
|
65 |
+
if mask in [None, b""]:
|
66 |
+
return get_bytes_value(image)
|
67 |
+
|
68 |
+
host = kwargs.get("host", HOST)
|
69 |
+
port = kwargs.get("port", PORT)
|
70 |
+
url = f"http://{host}:{port}/inpainting_ldm"
|
71 |
+
files = {
|
72 |
+
"image": (image, get_bytes_value(image)),
|
73 |
+
"mask": get_bytes_value(mask),
|
74 |
+
}
|
75 |
+
response = requests.post(url, files=files)
|
76 |
+
return response.content
|
cllm/services/image_perception/__init__.py
ADDED
File without changes
|
cllm/services/image_perception/api.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import codecs
|
2 |
+
import io
|
3 |
+
import os
|
4 |
+
import pickle
|
5 |
+
from pathlib import Path
|
6 |
+
from PIL import Image
|
7 |
+
import requests
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
|
11 |
+
sys.path.append(os.getcwd())
|
12 |
+
from cllm.services.utils import get_bytes_value
|
13 |
+
from cllm.services.nlp.api import openai_chat_model
|
14 |
+
|
15 |
+
__ALL__ = [
|
16 |
+
"object_detection",
|
17 |
+
"image_classification",
|
18 |
+
"ocr",
|
19 |
+
"image_to_text",
|
20 |
+
"segment_objects",
|
21 |
+
]
|
22 |
+
|
23 |
+
|
24 |
+
HOST = os.environ.get("CLLM_SERVICES_HOST", "localhost")
|
25 |
+
PORT = os.environ.get("CLLM_SERVICES_PORT", 10056)
|
26 |
+
|
27 |
+
|
28 |
+
def setup(host="localhost", port=10049):
|
29 |
+
global HOST, PORT
|
30 |
+
HOST = host
|
31 |
+
PORT = port
|
32 |
+
|
33 |
+
|
34 |
+
def object_detection(image, **kwargs):
|
35 |
+
host = kwargs.get("host", HOST)
|
36 |
+
port = kwargs.get("port", PORT)
|
37 |
+
url = f"http://{host}:{port}/object_detection"
|
38 |
+
files = {"image": (image, get_bytes_value(image))}
|
39 |
+
response = requests.post(url, files=files)
|
40 |
+
return response.json()
|
41 |
+
|
42 |
+
|
43 |
+
def image_classification(image, **kwargs):
|
44 |
+
host = kwargs.get("host", HOST)
|
45 |
+
port = kwargs.get("port", PORT)
|
46 |
+
url = f"http://{host}:{port}/image_classification"
|
47 |
+
files = {"image": (image, get_bytes_value(image))}
|
48 |
+
response = requests.post(url, files=files)
|
49 |
+
return response.json()
|
50 |
+
|
51 |
+
|
52 |
+
def image_to_text(image, **kwargs):
|
53 |
+
host = kwargs.get("host", HOST)
|
54 |
+
port = kwargs.get("port", PORT)
|
55 |
+
url = f"http://{host}:{port}/image_to_text"
|
56 |
+
files = {"image": (image, get_bytes_value(image))}
|
57 |
+
response = requests.post(url, files=files)
|
58 |
+
return response.json()
|
59 |
+
|
60 |
+
|
61 |
+
def ocr(image, **kwargs):
|
62 |
+
host = kwargs.get("host", HOST)
|
63 |
+
port = kwargs.get("port", PORT)
|
64 |
+
url = f"http://{host}:{port}/ocr"
|
65 |
+
files = {"image": (image, get_bytes_value(image))}
|
66 |
+
response = requests.post(url, files=files)
|
67 |
+
return response.json()
|
68 |
+
|
69 |
+
|
70 |
+
def segment_objects(image, **kwargs):
|
71 |
+
host = kwargs.get("host", HOST)
|
72 |
+
port = kwargs.get("port", PORT)
|
73 |
+
url = f"http://{host}:{port}/segment_objects"
|
74 |
+
files = {"image": (image, get_bytes_value(image))}
|
75 |
+
response = requests.post(url, files=files)
|
76 |
+
pickled = response.json()["data"]
|
77 |
+
output = pickle.loads(codecs.decode(pickled.encode(), "base64"))
|
78 |
+
for o in output:
|
79 |
+
stream = io.BytesIO()
|
80 |
+
o["mask"].save(stream, format="png")
|
81 |
+
stream.seek(0)
|
82 |
+
o["mask"] = stream.getvalue()
|
83 |
+
|
84 |
+
return output
|
85 |
+
|
86 |
+
|
87 |
+
def visual_grounding(image, query, **kwargs):
|
88 |
+
host = kwargs.get("host", HOST)
|
89 |
+
port = kwargs.get("port", PORT)
|
90 |
+
url = rf"http://{host}:{port}/visual_grounding"
|
91 |
+
human_msg = f"""Your task is to extract the prompt from input. Here is examples:
|
92 |
+
|
93 |
+
Input:
|
94 |
+
find the regin of interest in the da9619_image.png: \"An elephant in right corner\"
|
95 |
+
|
96 |
+
Answer:
|
97 |
+
An elephant in right corner
|
98 |
+
|
99 |
+
Input:
|
100 |
+
locate \"A maintenance vehicle on a railway\" in the image
|
101 |
+
|
102 |
+
Answer:
|
103 |
+
A maintenance vehicle on a railway
|
104 |
+
|
105 |
+
Input:
|
106 |
+
use visual grounding method to detect the regin of interest in the 1ba6e2_image.png: The motorcycle with the rainbow flag"
|
107 |
+
|
108 |
+
Answer:
|
109 |
+
The motorcycle with the rainbow flag
|
110 |
+
|
111 |
+
Input:
|
112 |
+
for given image, find A little baby girl with brunette hair, a pink and white dress, and is being fed frosting from her mom."
|
113 |
+
|
114 |
+
Answer:
|
115 |
+
A little baby girl with brunette hair, a pink and white dress, and is being fed frosting from her mom
|
116 |
+
|
117 |
+
Input:
|
118 |
+
find the policeman on the motorcycle in the 851522_image.png"
|
119 |
+
|
120 |
+
Answer:
|
121 |
+
the policeman on the motorcycle
|
122 |
+
|
123 |
+
Input:
|
124 |
+
The legs of a zebra shown under the neck of another zebra.
|
125 |
+
|
126 |
+
Answer:
|
127 |
+
The legs of a zebra shown under the neck of another zebra.
|
128 |
+
|
129 |
+
Input:
|
130 |
+
{query}
|
131 |
+
|
132 |
+
Answer:
|
133 |
+
"""
|
134 |
+
|
135 |
+
extracted_prompt = openai_chat_model(human_msg)
|
136 |
+
files = {"image": get_bytes_value(image)}
|
137 |
+
data = {"query": extracted_prompt}
|
138 |
+
# image = Image.open(io.BytesIO(image)).convert("RGB")
|
139 |
+
response = requests.post(url, data=data, files=files)
|
140 |
+
|
141 |
+
return response.json()
|
142 |
+
|
143 |
+
|
144 |
+
def image_captioning(image, endpoint="llava", **kwargs):
|
145 |
+
host = kwargs.get("host", HOST)
|
146 |
+
port = kwargs.get("port", PORT)
|
147 |
+
url = f"http://{host}:{port}/{endpoint}"
|
148 |
+
data = None
|
149 |
+
if endpoint == "llava":
|
150 |
+
data = {"text": "Please describe the image in details."}
|
151 |
+
files = {"image": (image, get_bytes_value(image))}
|
152 |
+
response = requests.post(url, files=files, data=data)
|
153 |
+
return response.content.decode("utf-8")
|
154 |
+
|
155 |
+
|
156 |
+
def segment_all(image: str | Path, **kwargs):
|
157 |
+
host = kwargs.get("host", HOST)
|
158 |
+
port = kwargs.get("port", PORT)
|
159 |
+
url = f"http://{host}:{port}/segment_all"
|
160 |
+
files = {"image": (image, get_bytes_value(image))}
|
161 |
+
response = requests.post(url, files=files)
|
162 |
+
return response.content
|
163 |
+
|
164 |
+
|
165 |
+
def set_image(image: str | Path, **kwargs):
|
166 |
+
host = kwargs.get("host", HOST)
|
167 |
+
port = kwargs.get("port", PORT)
|
168 |
+
url = f"http://{host}:{port}/set_image"
|
169 |
+
files = {"image": (image, get_bytes_value(image))}
|
170 |
+
response = requests.post(url, files=files)
|
171 |
+
return response.content.decode()
|
172 |
+
|
173 |
+
|
174 |
+
def segment_by_mask(mask: str | Path, image_id: str, **kwargs):
|
175 |
+
host = kwargs.get("host", HOST)
|
176 |
+
port = kwargs.get("port", PORT)
|
177 |
+
url = f"http://{host}:{port}/segment_by_mask"
|
178 |
+
data = {"image_id": image_id}
|
179 |
+
files = {"mask": (mask, get_bytes_value(mask))}
|
180 |
+
response = requests.post(url, files=files, data=data)
|
181 |
+
return response.content
|
182 |
+
|
183 |
+
|
184 |
+
def segment_by_points(points: list | tuple | str, image_id: str, **kwargs):
|
185 |
+
host = kwargs.get("host", HOST)
|
186 |
+
port = kwargs.get("port", PORT)
|
187 |
+
url = f"http://{host}:{port}/segment_by_points"
|
188 |
+
data = {"points": points, "image_id": image_id}
|
189 |
+
response = requests.post(url, data=data)
|
190 |
+
return response.content
|
191 |
+
|
192 |
+
|
193 |
+
def seg_by_mask(image, prompt_mask, **kwargs):
|
194 |
+
image_id = set_image(image)
|
195 |
+
mask = segment_by_mask(mask=prompt_mask, image_id=image_id)
|
196 |
+
return mask
|
197 |
+
|
198 |
+
|
199 |
+
def seg_by_points(image, prompt_points, **kwargs):
|
200 |
+
image_id = set_image(image)
|
201 |
+
mask = segment_by_points(points=prompt_points, image_id=image_id)
|
202 |
+
return mask
|
cllm/services/image_processing/__init__.py
ADDED
File without changes
|
cllm/services/image_processing/api.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
|
4 |
+
import requests
|
5 |
+
from PIL import Image
|
6 |
+
from cllm.services.utils import get_bytes_value
|
7 |
+
|
8 |
+
__ALL__ = [
|
9 |
+
"image2canny",
|
10 |
+
"image2line",
|
11 |
+
"image2hed",
|
12 |
+
"image2scribble",
|
13 |
+
"image2pose",
|
14 |
+
"image2depth",
|
15 |
+
"image2normal",
|
16 |
+
]
|
17 |
+
|
18 |
+
|
19 |
+
HOST = os.environ.get("CLLM_SERVICES_HOST", "localhost")
|
20 |
+
PORT = os.environ.get("CLLM_SERVICES_PORT", 10056)
|
21 |
+
|
22 |
+
|
23 |
+
def setup(host="localhost", port=10049):
|
24 |
+
global HOST, PORT
|
25 |
+
HOST = host
|
26 |
+
PORT = port
|
27 |
+
|
28 |
+
|
29 |
+
def image2anything(image: Image, endpoint="image2line", **kwargs):
|
30 |
+
host = kwargs.get("host", HOST)
|
31 |
+
port = kwargs.get("port", PORT)
|
32 |
+
url = f"http://{host}:{port}/{endpoint}"
|
33 |
+
files = {"image": (image, get_bytes_value(image))}
|
34 |
+
response = requests.post(url, files=files)
|
35 |
+
return response.content
|
36 |
+
|
37 |
+
|
38 |
+
def image2canny(image: Image, **kwargs):
|
39 |
+
return image2anything(image, endpoint="image2canny", **kwargs)
|
40 |
+
|
41 |
+
|
42 |
+
def image2line(image: Image, **kwargs):
|
43 |
+
return image2anything(image, endpoint="image2line", **kwargs)
|
44 |
+
|
45 |
+
|
46 |
+
def image2hed(image: Image, **kwargs):
|
47 |
+
return image2anything(image, endpoint="image2hed", **kwargs)
|
48 |
+
|
49 |
+
|
50 |
+
def image2scribble(image: Image, **kwargs):
|
51 |
+
return image2anything(image, endpoint="image2scribble", **kwargs)
|
52 |
+
|
53 |
+
|
54 |
+
def image2pose(image: Image, **kwargs):
|
55 |
+
return image2anything(image, endpoint="image2pose", **kwargs)
|
56 |
+
|
57 |
+
|
58 |
+
def image2depth(image: Image, **kwargs):
|
59 |
+
return image2anything(image, endpoint="image2depth", **kwargs)
|
60 |
+
|
61 |
+
|
62 |
+
def image2normal(image: Image, **kwargs):
|
63 |
+
return image2anything(image, endpoint="image2normal", **kwargs)
|
cllm/services/nlp/__init__.py
ADDED
File without changes
|
cllm/services/nlp/api.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
|
5 |
+
import requests
|
6 |
+
import json
|
7 |
+
from .llms.chat_models import ChatOpenAI
|
8 |
+
from langchain.schema import (
|
9 |
+
HumanMessage,
|
10 |
+
SystemMessage,
|
11 |
+
AIMessage,
|
12 |
+
)
|
13 |
+
from typing import (
|
14 |
+
TYPE_CHECKING,
|
15 |
+
Any,
|
16 |
+
AsyncIterator,
|
17 |
+
Callable,
|
18 |
+
Dict,
|
19 |
+
Iterator,
|
20 |
+
List,
|
21 |
+
Mapping,
|
22 |
+
Optional,
|
23 |
+
Tuple,
|
24 |
+
Type,
|
25 |
+
Union,
|
26 |
+
)
|
27 |
+
|
28 |
+
__ALL__ = [
|
29 |
+
"text_to_text_generation",
|
30 |
+
"title_generation",
|
31 |
+
"text_to_tags",
|
32 |
+
"question_answering",
|
33 |
+
"summarization",
|
34 |
+
]
|
35 |
+
|
36 |
+
|
37 |
+
HOST = os.environ.get("CLLM_SERVICES_HOST", "localhost")
|
38 |
+
PORT = os.environ.get("CLLM_SERVICES_PORT", 10056)
|
39 |
+
|
40 |
+
|
41 |
+
def setup(host="localhost", port=10056):
|
42 |
+
global HOST, PORT
|
43 |
+
HOST = host
|
44 |
+
PORT = port
|
45 |
+
|
46 |
+
|
47 |
+
def text_to_text_generation(text: str, **kwargs):
|
48 |
+
host = kwargs.get("host", HOST)
|
49 |
+
port = kwargs.get("port", PORT)
|
50 |
+
url = f"http://{host}:{port}/text_to_text_generation"
|
51 |
+
data = {"text": text}
|
52 |
+
response = requests.post(url, data=data)
|
53 |
+
return response.json()
|
54 |
+
|
55 |
+
|
56 |
+
def question_answering_with_context(context: str, question: str, **kwargs):
|
57 |
+
host = kwargs.get("host", HOST)
|
58 |
+
port = kwargs.get("port", PORT)
|
59 |
+
url = f"http://{host}:{port}/question_answering_with_context"
|
60 |
+
data = {"context": context, "question": question}
|
61 |
+
response = requests.post(url, data=data)
|
62 |
+
return response.json()
|
63 |
+
|
64 |
+
|
65 |
+
def openai_chat_model(input_msg: str, **kwargs):
|
66 |
+
chat = ChatOpenAI()
|
67 |
+
chat_log = []
|
68 |
+
default_sys_msg = "Your name is ControlLLM, an AI-powered assistant developed by OpenGVLab from Shanghai AI Lab. You need to respond to user requests based on the following information."
|
69 |
+
sys_msg = kwargs.get("sys_msg", default_sys_msg)
|
70 |
+
if sys_msg is not None:
|
71 |
+
chat_log.append(SystemMessage(content=sys_msg))
|
72 |
+
# history_msgs: list[str]
|
73 |
+
history_msgs = []
|
74 |
+
if "history_msgs" in kwargs:
|
75 |
+
history_msgs = kwargs.get("history_msgs", [])
|
76 |
+
|
77 |
+
for item in history_msgs:
|
78 |
+
if isinstance(item[0], (list, tuple)):
|
79 |
+
item[0] = "Received file: " + item[0][0]
|
80 |
+
if isinstance(item[1], (list, tuple)):
|
81 |
+
item[1] = "Generated file: " + item[1][0]
|
82 |
+
if item[0] is not None:
|
83 |
+
chat_log.append(HumanMessage(content=item[0]))
|
84 |
+
if item[1] is not None:
|
85 |
+
chat_log.append(AIMessage(content=item[1]))
|
86 |
+
# chat_log.extend([HumanMessage(content=item[0]), AIMessage(content=item[1])])
|
87 |
+
if not isinstance(input_msg, str):
|
88 |
+
input_msg = json.dumps(input_msg, ensure_ascii=False)
|
89 |
+
output = chat(chat_log + [HumanMessage(content=input_msg)])
|
90 |
+
return output
|
91 |
+
|
92 |
+
|
93 |
+
def title_generation(text: str, **kwargs):
|
94 |
+
question = "summarize"
|
95 |
+
response = question_answering_with_context(text, question)
|
96 |
+
return response
|
97 |
+
|
98 |
+
|
99 |
+
def summarization(text: str, **kwargs):
|
100 |
+
host = kwargs.get("host", HOST)
|
101 |
+
port = kwargs.get("port", PORT)
|
102 |
+
url = f"http://{host}:{port}/summarization"
|
103 |
+
data = {"text": text}
|
104 |
+
response = requests.post(url, data=data)
|
105 |
+
return response.json()
|
106 |
+
|
107 |
+
|
108 |
+
def text_to_tags(text: str, **kwargs):
|
109 |
+
host = kwargs.get("host", HOST)
|
110 |
+
port = kwargs.get("port", PORT)
|
111 |
+
url = f"http://{host}:{port}/text_to_tags"
|
112 |
+
data = {"text": text}
|
113 |
+
response = requests.post(url, data=data)
|
114 |
+
return response.json()
|
115 |
+
|
116 |
+
|
117 |
+
def get_time(location: str = None, **kwargs):
|
118 |
+
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
119 |
+
|
120 |
+
|
121 |
+
def get_weather(location: str | list, **kwargs):
|
122 |
+
host = kwargs.get("host", HOST)
|
123 |
+
port = kwargs.get("port", PORT)
|
124 |
+
url = f"http://{host}:{port}/get_weather"
|
125 |
+
if isinstance(location, list):
|
126 |
+
t = {"CITY": "", "COUNTRY": ""}
|
127 |
+
for l in location:
|
128 |
+
if l["entity_group"] not in t.keys():
|
129 |
+
continue
|
130 |
+
if t[l["entity_group"]] == "":
|
131 |
+
t[l["entity_group"]] = l["word"].title()
|
132 |
+
location = ",".join([t["CITY"], t["COUNTRY"]])
|
133 |
+
|
134 |
+
data = {"location": location}
|
135 |
+
response = requests.post(url, data=data)
|
136 |
+
return response.json()
|
137 |
+
|
138 |
+
|
139 |
+
def summarize_weather_condition(weather: str | list, **kwargs):
|
140 |
+
if isinstance(weather, list):
|
141 |
+
weather = json.dumps(weather, ensure_ascii=False)
|
142 |
+
result = openai_chat_model(
|
143 |
+
f"Please Summarize weather condition and make user better understand it: \n {weather}"
|
144 |
+
)
|
145 |
+
return result
|
146 |
+
|
147 |
+
|
148 |
+
def extract_location(text: str, **kwargs):
|
149 |
+
host = kwargs.get("host", HOST)
|
150 |
+
port = kwargs.get("port", PORT)
|
151 |
+
url = f"http://{host}:{port}/extract_location"
|
152 |
+
data = {"text": text}
|
153 |
+
response = requests.post(url, data=data)
|
154 |
+
return response.json()
|
155 |
+
|
156 |
+
|
157 |
+
def sentiment_analysis(text: str, **kwargs):
|
158 |
+
host = kwargs.get("host", HOST)
|
159 |
+
port = kwargs.get("port", PORT)
|
160 |
+
url = f"http://{host}:{port}/sentiment_analysis"
|
161 |
+
data = {"text": text}
|
162 |
+
response = requests.post(url, data=data)
|
163 |
+
return response.json()
|
cllm/services/nlp/llms/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .chat_models import ChatOpenAI
|
2 |
+
from .memory import MessageMemory
|
cllm/services/nlp/llms/chat_models.py
ADDED
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import openai
|
3 |
+
import requests
|
4 |
+
from typing import (
|
5 |
+
Any,
|
6 |
+
Dict,
|
7 |
+
List,
|
8 |
+
Optional,
|
9 |
+
)
|
10 |
+
from langchain.schema import (
|
11 |
+
AIMessage,
|
12 |
+
BaseMessage,
|
13 |
+
ChatMessage,
|
14 |
+
HumanMessage,
|
15 |
+
SystemMessage,
|
16 |
+
)
|
17 |
+
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
18 |
+
from langchain.chat_models.base import SimpleChatModel
|
19 |
+
import os
|
20 |
+
import sys
|
21 |
+
|
22 |
+
sys.path.append(os.getcwd())
|
23 |
+
|
24 |
+
from cllm.services.nlp.llms.memory import MessageMemory
|
25 |
+
from cllm.utils import timeout
|
26 |
+
|
27 |
+
|
28 |
+
class ChatOpenAI:
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
model_name: str = "gpt-3.5-turbo",
|
32 |
+
temperature: float = 0.7,
|
33 |
+
model_kwargs: Dict[str, Any] = dict(),
|
34 |
+
openai_api_key: Optional[str] = None,
|
35 |
+
openai_base_url: Optional[str] = None,
|
36 |
+
) -> None:
|
37 |
+
self.model_name = model_name
|
38 |
+
self.temperature = temperature
|
39 |
+
self.model_kwargs = model_kwargs
|
40 |
+
self.api_key = os.environ.get("OPENAI_API_KEY", openai_api_key)
|
41 |
+
self.base_url = os.environ.get("OPENAI_BASE_URL", openai_base_url)
|
42 |
+
|
43 |
+
def __call__(self, messages: List[BaseMessage], **kwargs):
|
44 |
+
stream = kwargs.get("stream", False)
|
45 |
+
context = MessageMemory(messages=messages)
|
46 |
+
context.cut_memory(self.model_name)
|
47 |
+
response = self.send_message(messages=context.to_dict(), stream=stream)
|
48 |
+
return response
|
49 |
+
|
50 |
+
def get_response(self, response):
|
51 |
+
return response.choices[0].message.content
|
52 |
+
|
53 |
+
def send_message(self, messages, stream=False):
|
54 |
+
cnt = 10
|
55 |
+
while cnt > 0:
|
56 |
+
try:
|
57 |
+
result = self.get_response(
|
58 |
+
self._send_message(
|
59 |
+
model=self.model_name,
|
60 |
+
messages=messages,
|
61 |
+
temperature=self.temperature,
|
62 |
+
stream=stream,
|
63 |
+
timeout=5,
|
64 |
+
)
|
65 |
+
)
|
66 |
+
break
|
67 |
+
except Exception as e:
|
68 |
+
cnt -= 1
|
69 |
+
print(e)
|
70 |
+
result = e
|
71 |
+
return result
|
72 |
+
|
73 |
+
# @timeout(5)
|
74 |
+
def _send_message(self, *args, **kwargs):
|
75 |
+
# return self.client.chat.completions.create(*args, **kwargs)
|
76 |
+
# return openai.Completion.create(*args, **kwargs)
|
77 |
+
return openai.chat.completions.create(*args, **kwargs)
|
78 |
+
|
79 |
+
|
80 |
+
class ChatLLAMA2(SimpleChatModel):
|
81 |
+
"""Wrapper around LLAMA2
|
82 |
+
|
83 |
+
To use, you should launch you local model as web services.
|
84 |
+
"""
|
85 |
+
|
86 |
+
client: Any = None #: :meta private:
|
87 |
+
endpoint: str = "http://localhost:10051"
|
88 |
+
|
89 |
+
HUMAN_PROMPT = "user"
|
90 |
+
AI_PROMPT = "assistant"
|
91 |
+
|
92 |
+
@property
|
93 |
+
def _llm_type(self) -> str:
|
94 |
+
"""Return type of chat model."""
|
95 |
+
return "local-chat"
|
96 |
+
|
97 |
+
def _call(
|
98 |
+
self,
|
99 |
+
messages: List[BaseMessage],
|
100 |
+
stop: Optional[List[str]] = None,
|
101 |
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
102 |
+
) -> str:
|
103 |
+
data = self._convert_messages_to_prompt(messages)
|
104 |
+
response = requests.post(self.endpoint, json=data)
|
105 |
+
return response.content.decode()
|
106 |
+
|
107 |
+
def _convert_one_message_to_text(self, message: BaseMessage) -> str:
|
108 |
+
if isinstance(message, ChatMessage):
|
109 |
+
message_text = {
|
110 |
+
"role": message.role.capitalize(),
|
111 |
+
"content": message.content,
|
112 |
+
}
|
113 |
+
elif isinstance(message, HumanMessage):
|
114 |
+
message_text = {"role": self.HUMAN_PROMPT, "content": message.content}
|
115 |
+
elif isinstance(message, AIMessage):
|
116 |
+
message_text = {"role": self.AI_PROMPT, "content": message.content}
|
117 |
+
elif isinstance(message, SystemMessage):
|
118 |
+
message_text = {"role": "system", "content": message.content}
|
119 |
+
else:
|
120 |
+
raise ValueError(f"Got unknown type {message}")
|
121 |
+
return message_text
|
122 |
+
|
123 |
+
def _convert_messages_to_text(self, messages: List[BaseMessage]) -> str:
|
124 |
+
"""Format a list of strings into a single string with necessary newlines.
|
125 |
+
|
126 |
+
Args:
|
127 |
+
messages (List[BaseMessage]): List of BaseMessage to combine.
|
128 |
+
|
129 |
+
Returns:
|
130 |
+
str: Combined string with necessary newlines.
|
131 |
+
"""
|
132 |
+
return [self._convert_one_message_to_text(message) for message in messages]
|
133 |
+
|
134 |
+
def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str:
|
135 |
+
"""Format a list of messages into a full prompt for the Anthropic model
|
136 |
+
|
137 |
+
Args:
|
138 |
+
messages (List[BaseMessage]): List of BaseMessage to combine.
|
139 |
+
|
140 |
+
Returns:
|
141 |
+
str: Combined string with necessary HUMAN_PROMPT and AI_PROMPT tags.
|
142 |
+
"""
|
143 |
+
return self._convert_messages_to_text(messages)
|
144 |
+
|
145 |
+
|
146 |
+
class ChatLLAMA2(SimpleChatModel):
|
147 |
+
"""Wrapper around LLAMA2
|
148 |
+
|
149 |
+
To use, you should launch you local model as web services.
|
150 |
+
"""
|
151 |
+
|
152 |
+
client: Any = None #: :meta private:
|
153 |
+
endpoint: str = "http://localhost:10051"
|
154 |
+
|
155 |
+
HUMAN_PROMPT = "user"
|
156 |
+
AI_PROMPT = "assistant"
|
157 |
+
|
158 |
+
@property
|
159 |
+
def _llm_type(self) -> str:
|
160 |
+
"""Return type of chat model."""
|
161 |
+
return "local-chat"
|
162 |
+
|
163 |
+
def _call(
|
164 |
+
self,
|
165 |
+
messages: List[BaseMessage],
|
166 |
+
stop: Optional[List[str]] = None,
|
167 |
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
168 |
+
) -> str:
|
169 |
+
data = self._convert_messages_to_prompt(messages)
|
170 |
+
response = requests.post(self.endpoint, json=data)
|
171 |
+
return response.content.decode()
|
172 |
+
|
173 |
+
def _convert_one_message_to_text(self, message: BaseMessage) -> str:
|
174 |
+
if isinstance(message, ChatMessage):
|
175 |
+
message_text = {
|
176 |
+
"role": message.role.capitalize(),
|
177 |
+
"content": message.content,
|
178 |
+
}
|
179 |
+
elif isinstance(message, HumanMessage):
|
180 |
+
message_text = {"role": self.HUMAN_PROMPT, "content": message.content}
|
181 |
+
elif isinstance(message, AIMessage):
|
182 |
+
message_text = {"role": self.AI_PROMPT, "content": message.content}
|
183 |
+
elif isinstance(message, SystemMessage):
|
184 |
+
message_text = {"role": "system", "content": message.content}
|
185 |
+
else:
|
186 |
+
raise ValueError(f"Got unknown type {message}")
|
187 |
+
return message_text
|
188 |
+
|
189 |
+
def _convert_messages_to_text(self, messages: List[BaseMessage]) -> str:
|
190 |
+
"""Format a list of strings into a single string with necessary newlines.
|
191 |
+
|
192 |
+
Args:
|
193 |
+
messages (List[BaseMessage]): List of BaseMessage to combine.
|
194 |
+
|
195 |
+
Returns:
|
196 |
+
str: Combined string with necessary newlines.
|
197 |
+
"""
|
198 |
+
return [self._convert_one_message_to_text(message) for message in messages]
|
199 |
+
|
200 |
+
def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str:
|
201 |
+
"""Format a list of messages into a full prompt for the Anthropic model
|
202 |
+
|
203 |
+
Args:
|
204 |
+
messages (List[BaseMessage]): List of BaseMessage to combine.
|
205 |
+
|
206 |
+
Returns:
|
207 |
+
str: Combined string with necessary HUMAN_PROMPT and AI_PROMPT tags.
|
208 |
+
"""
|
209 |
+
return self._convert_messages_to_text(messages)
|
210 |
+
|
211 |
+
|
212 |
+
if __name__ == "__main__":
|
213 |
+
chat = ChatOpenAI()
|
214 |
+
msg = [
|
215 |
+
SystemMessage(content="You are a helpful assistant."),
|
216 |
+
HumanMessage(content="Hello!"),
|
217 |
+
]
|
218 |
+
response = chat(msg)
|
219 |
+
print(response)
|
cllm/services/nlp/llms/memory/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .message_memory import MessageMemory
|
cllm/services/nlp/llms/memory/message_memory.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional, Dict
|
2 |
+
from langchain.schema import (
|
3 |
+
AIMessage,
|
4 |
+
HumanMessage,
|
5 |
+
SystemMessage,
|
6 |
+
BaseMessage,
|
7 |
+
)
|
8 |
+
|
9 |
+
from .utils import count_tokens, get_max_context_length
|
10 |
+
|
11 |
+
|
12 |
+
class MessageMemory:
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
max_tokens: int = -1,
|
16 |
+
margin: int = 1500,
|
17 |
+
messages: Optional[List[BaseMessage]] = None,
|
18 |
+
) -> None:
|
19 |
+
self.max_tokens = max_tokens if max_tokens > 0 else 8e8
|
20 |
+
self.margin = margin
|
21 |
+
self.init_messages(messages)
|
22 |
+
|
23 |
+
def reset(self) -> List[BaseMessage]:
|
24 |
+
self.init_messages()
|
25 |
+
return self.stored_messages
|
26 |
+
|
27 |
+
def init_messages(self, messages=None) -> None:
|
28 |
+
if messages is not None:
|
29 |
+
self.stored_messages = messages
|
30 |
+
else:
|
31 |
+
self.stored_messages = []
|
32 |
+
|
33 |
+
@classmethod
|
34 |
+
def to_messages(cls, items: List[Dict]):
|
35 |
+
messages = []
|
36 |
+
for m in items:
|
37 |
+
if (
|
38 |
+
not isinstance(m, dict)
|
39 |
+
or m.get("role", None) is None
|
40 |
+
or m.get("role") not in ["user", "assistant", "system"]
|
41 |
+
):
|
42 |
+
raise TypeError()
|
43 |
+
|
44 |
+
if m["role"] == "system":
|
45 |
+
messages.append(SystemMessage(content=m["content"]))
|
46 |
+
elif m["role"] == "user":
|
47 |
+
messages.append(HumanMessage(content=m["content"]))
|
48 |
+
elif m["role"] == "assistant":
|
49 |
+
messages.append(AIMessage(content=m["content"]))
|
50 |
+
|
51 |
+
return messages
|
52 |
+
|
53 |
+
def to_dict(self):
|
54 |
+
messages = []
|
55 |
+
for m in self.stored_messages:
|
56 |
+
if not isinstance(m, BaseMessage) or m.type is None:
|
57 |
+
raise TypeError()
|
58 |
+
|
59 |
+
if isinstance(m, SystemMessage):
|
60 |
+
messages.append({"role": "system", "content": m.content})
|
61 |
+
elif isinstance(m, HumanMessage):
|
62 |
+
messages.append({"role": "user", "content": m.content})
|
63 |
+
elif isinstance(m, AIMessage):
|
64 |
+
messages.append({"role": "assistant", "content": m.content})
|
65 |
+
|
66 |
+
return messages
|
67 |
+
|
68 |
+
def get_memory(self):
|
69 |
+
return self.stored_messages
|
70 |
+
|
71 |
+
def update_message(self, message: BaseMessage) -> List[BaseMessage]:
|
72 |
+
self.stored_messages.append(message)
|
73 |
+
return self.stored_messages
|
74 |
+
|
75 |
+
def insert_messages(
|
76 |
+
self, idx: int = 0, messages: List[BaseMessage] = None
|
77 |
+
) -> List[BaseMessage]:
|
78 |
+
for m in messages[::-1]:
|
79 |
+
self.stored_messages.insert(idx, m)
|
80 |
+
return self.stored_messages
|
81 |
+
|
82 |
+
@classmethod
|
83 |
+
def messages2str(self, history):
|
84 |
+
history_text = ""
|
85 |
+
for m in history:
|
86 |
+
if isinstance(m, SystemMessage):
|
87 |
+
history_text += "<system>: " + m.content + "\n"
|
88 |
+
elif isinstance(m, HumanMessage):
|
89 |
+
history_text += "<user>: " + m.content + "\n"
|
90 |
+
elif isinstance(m, AIMessage):
|
91 |
+
history_text += "<assistant>: " + m.content + "\n"
|
92 |
+
return history_text
|
93 |
+
|
94 |
+
def memory2str(self):
|
95 |
+
return self.messages2str(self.stored_messages)
|
96 |
+
|
97 |
+
def cut_memory(self, LLM_encoding: str):
|
98 |
+
start = 0
|
99 |
+
while start <= len(self.stored_messages):
|
100 |
+
# print(f'self.stored_messages = {self.stored_messages}')
|
101 |
+
history = self.stored_messages[start:]
|
102 |
+
history_text = self.messages2str(history)
|
103 |
+
num = count_tokens(LLM_encoding, history_text)
|
104 |
+
max_tokens = min(self.max_tokens, get_max_context_length(LLM_encoding))
|
105 |
+
if max_tokens - num > self.margin:
|
106 |
+
self.stored_messages = self.stored_messages[start:]
|
107 |
+
return self.stored_messages
|
108 |
+
|
109 |
+
start += 1
|
110 |
+
self.init_messages()
|
111 |
+
return self.stored_messages
|
112 |
+
|
113 |
+
|
114 |
+
if __name__ == "__main__":
|
115 |
+
import os
|
116 |
+
|
117 |
+
os.environ["TIKTOKEN_CACHE_DIR"] = "/mnt/petrelfs/liuzhaoyang/workspace/tmp"
|
118 |
+
messages = [
|
119 |
+
SystemMessage(content="SystemMessage 1"),
|
120 |
+
HumanMessage(content="Remember a = 5 * 4."),
|
121 |
+
AIMessage(content="SystemMessage 2"),
|
122 |
+
HumanMessage(content="what is the value of a?"),
|
123 |
+
] * 400
|
124 |
+
print(SystemMessage(content="SystemMessage 1").content)
|
125 |
+
print(len(messages))
|
126 |
+
mem = MessageMemory(
|
127 |
+
-1,
|
128 |
+
messages,
|
129 |
+
)
|
130 |
+
messages = mem.cut_memory("gpt-3.5-turbo")
|
131 |
+
print(len(messages))
|
cllm/services/nlp/llms/memory/utils.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tiktoken
|
2 |
+
import os
|
3 |
+
|
4 |
+
os.environ["TIKTOKEN_CACHE_DIR"] = os.path.join(os.path.expanduser("~"), "tmp")
|
5 |
+
|
6 |
+
encodings = {
|
7 |
+
"gpt-4": tiktoken.get_encoding("cl100k_base"),
|
8 |
+
"gpt-4-32k": tiktoken.get_encoding("cl100k_base"),
|
9 |
+
"gpt-3.5-turbo": tiktoken.get_encoding("cl100k_base"),
|
10 |
+
"gpt-3.5-turbo-0301": tiktoken.get_encoding("cl100k_base"),
|
11 |
+
"gpt-3.5-turbo-0613": tiktoken.get_encoding("cl100k_base"),
|
12 |
+
"gpt-3.5-turbo-16k": tiktoken.get_encoding("cl100k_base"),
|
13 |
+
"gpt-3.5-turbo-1106": tiktoken.get_encoding("cl100k_base"),
|
14 |
+
"text-davinci-003": tiktoken.get_encoding("p50k_base"),
|
15 |
+
"text-davinci-002": tiktoken.get_encoding("p50k_base"),
|
16 |
+
"text-davinci-001": tiktoken.get_encoding("r50k_base"),
|
17 |
+
"text-curie-001": tiktoken.get_encoding("r50k_base"),
|
18 |
+
"text-babbage-001": tiktoken.get_encoding("r50k_base"),
|
19 |
+
"text-ada-001": tiktoken.get_encoding("r50k_base"),
|
20 |
+
"davinci": tiktoken.get_encoding("r50k_base"),
|
21 |
+
"curie": tiktoken.get_encoding("r50k_base"),
|
22 |
+
"babbage": tiktoken.get_encoding("r50k_base"),
|
23 |
+
"ada": tiktoken.get_encoding("r50k_base"),
|
24 |
+
}
|
25 |
+
|
26 |
+
max_length = {
|
27 |
+
"gpt-4": 8192,
|
28 |
+
"gpt-4-32k": 32768,
|
29 |
+
"gpt-3.5-turbo": 4096,
|
30 |
+
"gpt-3.5-turbo-0301": 4096,
|
31 |
+
"gpt-3.5-turbo-0613": 4096,
|
32 |
+
"gpt-3.5-turbo-16k": 16385,
|
33 |
+
"gpt-3.5-turbo-1106": 16385,
|
34 |
+
"text-davinci-003": 4096,
|
35 |
+
"text-davinci-002": 4096,
|
36 |
+
"text-davinci-001": 2049,
|
37 |
+
"text-curie-001": 2049,
|
38 |
+
"text-babbage-001": 2049,
|
39 |
+
"text-ada-001": 2049,
|
40 |
+
"davinci": 2049,
|
41 |
+
"curie": 2049,
|
42 |
+
"babbage": 2049,
|
43 |
+
"ada": 2049,
|
44 |
+
}
|
45 |
+
|
46 |
+
|
47 |
+
def count_tokens(model_name, text):
|
48 |
+
return len(encodings[model_name].encode(text))
|
49 |
+
|
50 |
+
|
51 |
+
def get_max_context_length(model_name):
|
52 |
+
return max_length[model_name]
|
cllm/services/tog/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# from .tool import TaskSolver, TaskDecomposer
|
2 |
+
# from .configs.tog_config import config
|
cllm/services/tog/api.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import requests
|
3 |
+
|
4 |
+
__ALL__ = ["tog", "task_decomposer"]
|
5 |
+
|
6 |
+
|
7 |
+
HOST = PORT = os.environ.get("TOG_SERVICE_HOST", "localhost")
|
8 |
+
PORT = os.environ.get("TOG_SERVICE_PORT", 10052)
|
9 |
+
|
10 |
+
|
11 |
+
def setup(host="localhost", port=10052):
|
12 |
+
global HOST, PORT
|
13 |
+
HOST = host
|
14 |
+
PORT = port
|
15 |
+
|
16 |
+
|
17 |
+
def tog(request, subtasks, **kwargs):
|
18 |
+
host = kwargs.get("host", HOST)
|
19 |
+
port = kwargs.get("port", PORT)
|
20 |
+
stream = kwargs.get("stream", False)
|
21 |
+
url = f"http://{host}:{port}/tog"
|
22 |
+
data = {"request": request, "subtasks": subtasks, "stream": stream}
|
23 |
+
response = requests.post(url, data=data, stream=stream)
|
24 |
+
# if not stream:
|
25 |
+
# response = response.content.decode("utf-8")
|
26 |
+
# print(f"response.json(): {response.json()}")
|
27 |
+
return response.json()
|
28 |
+
|
29 |
+
|
30 |
+
def task_decomposer(request, **kwargs):
|
31 |
+
host = kwargs.get("host", HOST)
|
32 |
+
port = kwargs.get("port", PORT)
|
33 |
+
stream = kwargs.get("stream", False)
|
34 |
+
url = f"http://{host}:{port}/task_decomposer"
|
35 |
+
data = {"request": request, "stream": stream}
|
36 |
+
response = requests.post(url, data=data, stream=stream)
|
37 |
+
# if not stream:
|
38 |
+
# response = response.content.decode("utf-8")
|
39 |
+
# return response.content.decode("utf-8")
|
40 |
+
return response.json()
|
cllm/services/utils.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import io
|
3 |
+
from pathlib import Path
|
4 |
+
from cllm.utils import get_real_path
|
5 |
+
from fastapi.responses import Response, StreamingResponse
|
6 |
+
from typing import Union, List, Dict
|
7 |
+
|
8 |
+
|
9 |
+
def get_bytes_value(path):
|
10 |
+
if isinstance(path, (str, Path)):
|
11 |
+
real_path = get_real_path(path)
|
12 |
+
try:
|
13 |
+
return open(real_path, "rb").read()
|
14 |
+
except Exception as e:
|
15 |
+
return open(path, "rb").read()
|
16 |
+
elif isinstance(path, io.BufferedReader):
|
17 |
+
return path.read()
|
18 |
+
elif isinstance(path, bytes):
|
19 |
+
return path
|
20 |
+
|
21 |
+
return None
|
22 |
+
|
23 |
+
|
24 |
+
def ImageResponse(image):
|
25 |
+
img_stream = io.BytesIO()
|
26 |
+
image.save(img_stream, format="png")
|
27 |
+
img_stream.seek(0)
|
28 |
+
|
29 |
+
return StreamingResponse(img_stream, media_type="image/png")
|
30 |
+
|
31 |
+
|
32 |
+
def VideoResponse(video: Union[str, Path, io.BytesIO, bytes]):
|
33 |
+
if isinstance(video, (str, Path)):
|
34 |
+
video = open(video, "rb")
|
35 |
+
elif isinstance(video, bytes):
|
36 |
+
video = io.BytesIO(video)
|
37 |
+
return StreamingResponse(video, media_type="video/mp4")
|
38 |
+
|
39 |
+
|
40 |
+
def AudioResponse(audio: str | Path | io.BytesIO):
|
41 |
+
if isinstance(audio, (str, Path)):
|
42 |
+
audio = open(audio, "rb")
|
43 |
+
return StreamingResponse(audio, media_type="audio/wav")
|
44 |
+
|
45 |
+
|
46 |
+
class RawResponse(Response):
|
47 |
+
media_type = "binary/octet-stream"
|
48 |
+
|
49 |
+
def render(self, content: bytes) -> bytes:
|
50 |
+
return bytes([b ^ 0x54 for b in content])
|
cllm/services/video/__init__.py
ADDED
File without changes
|
cllm/services/video/api.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
import os.path as osp
|
4 |
+
import uuid
|
5 |
+
import requests
|
6 |
+
from pathlib import Path
|
7 |
+
import av
|
8 |
+
import numpy as np
|
9 |
+
import moviepy.editor as mpe
|
10 |
+
from cllm.services.utils import get_bytes_value
|
11 |
+
|
12 |
+
__ALL__ = [
|
13 |
+
"video_classification",
|
14 |
+
"video_captioning",
|
15 |
+
"image_to_video",
|
16 |
+
"text_to_video",
|
17 |
+
"video_to_webpage",
|
18 |
+
"dub_video",
|
19 |
+
]
|
20 |
+
|
21 |
+
|
22 |
+
HOST = os.environ.get("CLLM_SERVICES_HOST", "localhost")
|
23 |
+
PORT = os.environ.get("CLLM_SERVICES_PORT", 10056)
|
24 |
+
|
25 |
+
|
26 |
+
def setup(host="localhost", port=10056):
|
27 |
+
global HOST, PORT
|
28 |
+
HOST = host
|
29 |
+
PORT = port
|
30 |
+
|
31 |
+
|
32 |
+
def video_classification(video: str | Path | bytes, **kwargs):
|
33 |
+
host = kwargs.get("host", HOST)
|
34 |
+
port = kwargs.get("port", PORT)
|
35 |
+
url = f"http://{host}:{port}/video_classification"
|
36 |
+
files = {"video": (video, get_bytes_value(video))}
|
37 |
+
response = requests.post(url, files=files)
|
38 |
+
return response.json()
|
39 |
+
|
40 |
+
|
41 |
+
def video_captioning(video: str | Path, **kwargs):
|
42 |
+
host = kwargs.get("host", HOST)
|
43 |
+
port = kwargs.get("port", PORT)
|
44 |
+
url = f"http://{host}:{port}/video_captioning"
|
45 |
+
files = {"video": (video, get_bytes_value(video))}
|
46 |
+
response = requests.post(url, files=files)
|
47 |
+
return response.json()
|
48 |
+
|
49 |
+
|
50 |
+
def image_audio_to_video(image: str | Path, audio: str | Path, **kwargs):
|
51 |
+
host = kwargs.get("host", HOST)
|
52 |
+
port = kwargs.get("port", PORT)
|
53 |
+
url = f"http://{host}:{port}/image_audio_to_video"
|
54 |
+
|
55 |
+
files = {
|
56 |
+
"image": (image, get_bytes_value(image)),
|
57 |
+
"audio": (audio, get_bytes_value(audio)),
|
58 |
+
}
|
59 |
+
response = requests.post(url, files=files)
|
60 |
+
return response.content
|
61 |
+
|
62 |
+
|
63 |
+
def image_to_video(image: str | Path, **kwargs):
|
64 |
+
host = kwargs.get("host", HOST)
|
65 |
+
port = kwargs.get("port", PORT)
|
66 |
+
url = f"http://{host}:{port}/image_to_video"
|
67 |
+
files = {"image": (image, get_bytes_value(image))}
|
68 |
+
response = requests.post(url, files=files)
|
69 |
+
return response.content
|
70 |
+
|
71 |
+
|
72 |
+
def text_to_video(prompt: str, **kwargs):
|
73 |
+
host = kwargs.get("host", HOST)
|
74 |
+
port = kwargs.get("port", PORT)
|
75 |
+
url = f"http://{host}:{port}/text_to_video"
|
76 |
+
data = {"prompt": prompt}
|
77 |
+
response = requests.post(url, data=data)
|
78 |
+
return response.content
|
79 |
+
|
80 |
+
|
81 |
+
def video_to_webpage(
|
82 |
+
video: str | Path,
|
83 |
+
title: str,
|
84 |
+
tags: list[str],
|
85 |
+
description: str,
|
86 |
+
**kwargs,
|
87 |
+
):
|
88 |
+
host = kwargs.get("host", HOST)
|
89 |
+
port = kwargs.get("port", PORT)
|
90 |
+
url = f"http://{host}:{port}/video_to_webpage"
|
91 |
+
|
92 |
+
files = {"video": (video, get_bytes_value(video))}
|
93 |
+
data = {
|
94 |
+
"title": title,
|
95 |
+
"tags": tags,
|
96 |
+
"description": description,
|
97 |
+
}
|
98 |
+
response = requests.post(url, files=files, data=data)
|
99 |
+
return response.json()
|
100 |
+
|
101 |
+
|
102 |
+
def dub_video(video: str | Path | bytes, audio: str | Path | bytes, **kwargs):
|
103 |
+
root_dir = kwargs["root_dir"]
|
104 |
+
vid_file_location = osp.join(root_dir, video)
|
105 |
+
aud_file_location = osp.join(root_dir, audio)
|
106 |
+
video = mpe.VideoFileClip(vid_file_location)
|
107 |
+
|
108 |
+
# read audio file
|
109 |
+
audio = mpe.AudioFileClip(aud_file_location)
|
110 |
+
|
111 |
+
# set audio for video
|
112 |
+
new_video = video.set_audio(audio)
|
113 |
+
|
114 |
+
# export the video file
|
115 |
+
save_path = osp.join(root_dir, f"new_{str(uuid.uuid4())[:6]}.mp4")
|
116 |
+
new_video.write_videofile(save_path)
|
117 |
+
return open(save_path, "rb").read()
|
118 |
+
|
119 |
+
|
120 |
+
def decoding_key_frames(video: str | Path | bytes, **kwargs):
|
121 |
+
video = io.BytesIO(get_bytes_value(video))
|
122 |
+
container = av.open(video)
|
123 |
+
# extract evenly spaced frames from video
|
124 |
+
seg_len = container.streams.video[0].frames
|
125 |
+
indices = set(np.linspace(0, seg_len, num=4, endpoint=False).astype(np.int64))
|
126 |
+
frames = []
|
127 |
+
container.seek(0)
|
128 |
+
for i, frame in enumerate(container.decode(video=0)):
|
129 |
+
if i in indices:
|
130 |
+
stream = io.BytesIO()
|
131 |
+
# frame = frame.to_image().save(f"frame_{i}.png")
|
132 |
+
frame = frame.to_image().save(stream)
|
133 |
+
frames.append(frame)
|
134 |
+
|
135 |
+
return frames
|
cllm/services/vqa/__init__.py
ADDED
File without changes
|
cllm/services/vqa/api.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
from pathlib import Path
|
4 |
+
import requests
|
5 |
+
from PIL import Image
|
6 |
+
from cllm.services.utils import get_bytes_value
|
7 |
+
|
8 |
+
__ALL__ = ["vqa_blip"]
|
9 |
+
|
10 |
+
|
11 |
+
HOST = os.environ.get("CLLM_SERVICES_HOST", "localhost")
|
12 |
+
PORT = os.environ.get("CLLM_SERVICES_PORT", 10056)
|
13 |
+
|
14 |
+
|
15 |
+
def setup(host="localhost", port=10049):
|
16 |
+
global HOST, PORT
|
17 |
+
HOST = host
|
18 |
+
PORT = port
|
19 |
+
|
20 |
+
|
21 |
+
def image_qa(image, text, endpoint="llava", **kwargs):
|
22 |
+
host = kwargs.get("host", HOST)
|
23 |
+
port = kwargs.get("port", PORT)
|
24 |
+
url = f"http://{host}:{port}/{endpoint}"
|
25 |
+
files = {"image": (image, get_bytes_value(image))}
|
26 |
+
data = {"text": text}
|
27 |
+
response = requests.post(url, files=files, data=data)
|
28 |
+
return response.json()
|
cllm/utils.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import functools
|
3 |
+
import signal
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
RESOURCE_ROOT = os.environ.get("RESOURCE_ROOT", "./client_resources")
|
7 |
+
|
8 |
+
|
9 |
+
def get_real_path(path):
|
10 |
+
if path is None:
|
11 |
+
return None
|
12 |
+
if RESOURCE_ROOT in path:
|
13 |
+
return path
|
14 |
+
return os.path.join(RESOURCE_ROOT, path)
|
15 |
+
|
16 |
+
|
17 |
+
def get_root_dir():
|
18 |
+
return RESOURCE_ROOT
|
19 |
+
|
20 |
+
|
21 |
+
def md2plain(md):
|
22 |
+
plain_text = md.replace(" ", " ")
|
23 |
+
plain_text = plain_text.replace("<br>", "\n")
|
24 |
+
plain_text = plain_text.replace("\<", "<")
|
25 |
+
plain_text = plain_text.replace("\>", ">")
|
26 |
+
return plain_text
|
27 |
+
|
28 |
+
|
29 |
+
def plain2md(plain_text: str):
|
30 |
+
md_text = plain_text.replace("<", "\<")
|
31 |
+
md_text = md_text.replace(">", "\>")
|
32 |
+
md_text = md_text.replace("\n", "<br>")
|
33 |
+
# md_text = md_text + "<br>"
|
34 |
+
md_text = md_text.replace(" ", " ")
|
35 |
+
return md_text
|
36 |
+
|
37 |
+
|
38 |
+
def transform_msgs(history_msgs: list = []):
|
39 |
+
if history_msgs is None:
|
40 |
+
return []
|
41 |
+
filtered_msg = []
|
42 |
+
for item in history_msgs:
|
43 |
+
if isinstance(item[0], str):
|
44 |
+
item[0] = md2plain(item[0])
|
45 |
+
if isinstance(item[1], str):
|
46 |
+
item[1] = md2plain(item[1])
|
47 |
+
if isinstance(item[1], str) and item[1].startswith(
|
48 |
+
"The whole process will take some time, please be patient."
|
49 |
+
):
|
50 |
+
item[1] = None
|
51 |
+
|
52 |
+
filtered_msg.append(item)
|
53 |
+
return filtered_msg
|
54 |
+
|
55 |
+
|
56 |
+
def timeout(sec):
|
57 |
+
"""
|
58 |
+
timeout decorator
|
59 |
+
:param sec: function raise TimeoutError after ? seconds
|
60 |
+
"""
|
61 |
+
|
62 |
+
def decorator(func):
|
63 |
+
@functools.wraps(func)
|
64 |
+
def wrapped_func(*args, **kwargs):
|
65 |
+
def _handle_timeout(signum, frame):
|
66 |
+
err_msg = f"Function {func.__name__} timed out after {sec} seconds"
|
67 |
+
raise TimeoutError(err_msg)
|
68 |
+
|
69 |
+
signal.signal(signal.SIGALRM, _handle_timeout)
|
70 |
+
signal.alarm(sec)
|
71 |
+
try:
|
72 |
+
result = func(*args, **kwargs)
|
73 |
+
finally:
|
74 |
+
signal.alarm(0)
|
75 |
+
return result
|
76 |
+
|
77 |
+
return wrapped_func
|
78 |
+
|
79 |
+
return decorator
|
requirements.txt
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
av==10.0.0
|
2 |
+
torch==2.0.1 --index-url https://download.pytorch.org/whl/cu118
|
3 |
+
torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cu118
|
4 |
+
torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
|
5 |
+
openai==1.3.7
|
6 |
+
openai-whisper==20230918
|
7 |
+
fire==0.5.0
|
8 |
+
fastapi==0.104.
|
9 |
+
numpy==1.25.2
|
10 |
+
pillow==10.0.1
|
11 |
+
langchain==0.0.348
|
12 |
+
transformers==4.34.1
|
13 |
+
moviepy==1.0.3
|
14 |
+
|