|
import json |
|
from abc import ABC |
|
from distutils.util import strtobool |
|
from pathlib import Path |
|
from typing import Any, Callable, Dict, List, Optional, Union |
|
|
|
import yaml |
|
from omagent_core.base import BotBase |
|
from omagent_core.models.od.schemas import Target |
|
from omagent_core.services.handlers.sql_data_handler import SQLDataHandler |
|
from omagent_core.utils.error import VQLError |
|
from omagent_core.utils.logger import logging |
|
from omagent_core.utils.plot import Annotator |
|
from PIL import Image |
|
from pydantic import BaseModel, model_validator |
|
|
|
|
|
class ArgSchema(BaseModel):
    """Define the input schema of a tool.

    Every extra field set on the model is one tool argument. Only flat
    (single-level) definitions are supported; avoid nested structures. An
    argument can be declared as:

    * a ``str`` -- used as the argument's description;
    * a ``dict`` of ``ArgInfo`` fields (``description``, ``type``, ``enum``,
      ``required``);
    * an ``ArgInfo`` instance.
    """

    class Config:
        """Configuration for this pydantic object."""

        # Arguments are declared as arbitrary extra fields, so extras are kept.
        extra = "allow"
        arbitrary_types_allowed = True

    class ArgInfo(BaseModel):
        """Definition of a single tool argument."""

        # Human-readable description. Defaults to None, which restores the
        # pydantic-v1 meaning of a bare ``Optional`` annotation (in v2 an
        # ``Optional`` field without a default is required).
        description: Optional[str] = None
        # Declared type. validate_args understands the JSON-schema names
        # "string", "integer", "number", "boolean" and the Python aliases
        # "str", "int", "float", "bool"; generate_schema emits it unchanged.
        type: str = "str"
        # Optional whitelist of allowed values.
        enum: Optional[List] = None
        # Whether the argument must be supplied; a falsy value (incl. None)
        # marks it optional.
        required: Optional[bool] = True

    @model_validator(mode="before")
    @classmethod
    def validate_all(cls, values):
        """Normalize every declared argument into an ``ArgInfo`` instance.

        Returns a new mapping; the input mapping is not modified.

        Raises:
            ValueError: if an argument is not a str, a dict or an ``ArgInfo``.
        """
        if not isinstance(values, dict):
            # Not raw field data (e.g. an existing instance); let pydantic handle it.
            return values
        normalized = {}
        for key, value in values.items():
            if isinstance(value, cls.ArgInfo):
                normalized[key] = value
            elif isinstance(value, str):
                # Shorthand: a bare string is the argument's description.
                normalized[key] = cls.ArgInfo(description=value)
            elif isinstance(value, dict):
                normalized[key] = cls.ArgInfo(**value)
            else:
                raise ValueError(
                    "The arg type must be one of string, dict or self.ArgInfo."
                )
        return normalized

    @classmethod
    def from_file(cls, schema_file: Union[str, Path]) -> "ArgSchema":
        """Build an ``ArgSchema`` from a JSON or YAML file.

        Args:
            schema_file: Path to a ``.json``, ``.yaml`` or ``.yml`` file whose
                top level maps argument names to their definitions.

        Raises:
            ValueError: if the file extension is not supported.
        """
        schema_file = Path(schema_file)
        suffix = schema_file.suffix.lower()
        if suffix == ".json":
            with open(schema_file, "r", encoding="utf-8") as f:
                schema = json.load(f)
        elif suffix in (".yaml", ".yml"):
            with open(schema_file, "r", encoding="utf-8") as f:
                # NOTE(review): FullLoader can construct some Python objects;
                # prefer yaml.safe_load if schema files may come from untrusted sources.
                schema = yaml.load(f, Loader=yaml.FullLoader)
        else:
            raise ValueError("Only support json and yaml file.")
        return cls(**schema)

    def generate_schema(self) -> tuple:
        """Split the schema into JSON-schema properties and a required list.

        Returns:
            tuple: ``(parameters, required_args)`` where ``parameters`` maps each
            argument name to its dumped definition (None-valued keys and the
            ``required`` flag removed) and ``required_args`` lists the names of
            required arguments.
        """
        required_args = []
        parameters = {}
        for key, value in self.model_dump(exclude_none=True).items():
            # "required" is not a per-property JSON-schema keyword, so it is moved
            # into the separate list. exclude_none drops required=None entirely,
            # hence the default.
            if value.pop("required", None):
                required_args.append(key)
            parameters[key] = value
        return parameters, required_args

    def validate_args(self, args: dict) -> dict:
        """Validate ``args`` against the schema and coerce values to the declared types.

        Unknown arguments are dropped with a warning, and ``None`` values are
        treated as not supplied. Values are converted with ``str``/``int``/``float``.
        Booleans use the truth table of ``distutils.util.strtobool`` but yield a
        real ``bool``. ``enum`` restrictions are checked on the converted value.

        Args:
            args: Mapping of argument name to raw value.

        Returns:
            dict: A new dict holding only known, converted arguments.

        Raises:
            ValueError: if ``args`` is not a dict, a value cannot be converted,
                a value is outside its ``enum``, or a declared type is unsupported.
            VQLError: if a required argument is missing.
        """
        if not isinstance(args, dict):
            raise ValueError(
                "ArgSchema validate only support dict, not {}".format(type(args))
            )
        # Dump once instead of once per lookup.
        spec = self.model_dump()
        required_fields = {key for key, info in spec.items() if info["required"]}
        # Python type name -> JSON-schema type name. Also used to normalize declared
        # types written with Python names (e.g. the ArgInfo default "str").
        name_mapping = {
            "str": "string",
            "int": "integer",
            "float": "number",
            "bool": "boolean",
        }
        # Same truth table as distutils.util.strtobool (distutils is removed in 3.12).
        true_values = ("y", "yes", "t", "true", "on", "1")
        false_values = ("n", "no", "f", "false", "off", "0")

        new_args = {}
        for name, value in args.items():
            if name not in spec:
                logging.warning(
                    "The input args includes an unnecessary parameter {}. Removed from the args.".format(
                        name
                    )
                )
                continue
            if value is None:
                # An explicit None means "not supplied"; the required check below
                # still reports it if the argument is mandatory.
                continue
            info = spec[name]
            declared = name_mapping.get(info["type"], info["type"])

            if name_mapping.get(type(value).__name__) == declared:
                converted = value
            elif declared == "string":
                try:
                    converted = str(value)
                except Exception as err:  # arbitrary user __str__ may raise anything
                    raise ValueError(
                        "Parameter {} type expect a str value, but got a {} {}".format(
                            name, type(value), value
                        )
                    ) from err
            elif declared == "integer":
                try:
                    converted = int(value)
                except (TypeError, ValueError) as err:
                    raise ValueError(
                        "Parameter {} type expect an int value, but got a {} {}".format(
                            name, type(value), value
                        )
                    ) from err
            elif declared == "number":
                try:
                    converted = float(value)
                except (TypeError, ValueError) as err:
                    raise ValueError(
                        "Parameter {} type expect a float value, but got a {} {}".format(
                            name, type(value), value
                        )
                    ) from err
            elif declared == "boolean":
                # Real bools already matched above; parse everything else as text.
                text = str(value).lower()
                if text in true_values:
                    converted = True
                elif text in false_values:
                    converted = False
                else:
                    raise ValueError(
                        "Parameter {} type expect a boolean value, but got a {} {}".format(
                            name, type(value), value
                        )
                    )
            else:
                raise ValueError(
                    "Parameter {} type expect one of string, integer, number and boolean, but got a {}".format(
                        name, info["type"]
                    )
                )

            if info["enum"] and converted not in info["enum"]:
                raise ValueError(
                    "The value of {} should be one of {}, but got {}".format(
                        name, str(info["enum"]), converted
                    )
                )
            new_args[name] = converted

        missing = required_fields - set(new_args)
        if missing:
            raise VQLError("The required fields {} are missing.".format(missing))
        return new_args
|
|
|
|
|
class BaseTool(BotBase, ABC):
    """Base class for agent tools.

    Subclasses either override ``_run``/``_arun`` or pass ``func`` at
    construction time. ``run``/``arun`` validate dict input against
    ``args_schema`` (when set) and call the implementation with the input and
    ``special_params`` expanded as keyword arguments.
    """

    # Tool description exposed in the generated function schema.
    description: str
    # Callable used by the default _run/_arun implementations.
    func: Optional[Callable] = None
    # Schema used to validate and coerce dict inputs; None skips validation.
    args_schema: Optional[ArgSchema]
    # Extra keyword arguments passed to every _run/_arun call.
    special_params: Dict = {}

    def model_post_init(self, __context: Any) -> None:
        """Set ``_parent`` on every ``BotBase`` attribute to point at this tool."""
        # NOTE(review): does not call super().model_post_init(); confirm that
        # BotBase does not rely on its own post-init hook.
        for attr_value in self.__dict__.values():
            if isinstance(attr_value, BotBase):
                attr_value._parent = self

    @property
    def workflow_instance_id(self) -> Optional[str]:
        """Workflow instance id of the parent object, or None without a parent."""
        if hasattr(self, "_parent"):
            return self._parent.workflow_instance_id
        return None

    @workflow_instance_id.setter
    def workflow_instance_id(self, value: str):
        # Forwarded to the parent; silently ignored when there is no parent.
        if hasattr(self, "_parent"):
            self._parent.workflow_instance_id = value

    def _run(self, **kwargs) -> str:
        """Implement this function or pass 'func' arg when initializing."""
        if self.func is None:
            raise NotImplementedError(
                "{} must implement _run or be given a 'func'.".format(
                    type(self).__name__
                )
            )
        return self.func(**kwargs)

    async def _arun(self, **kwargs) -> str:
        """Implement this function or pass 'func' arg when initializing."""
        if self.func is None:
            raise NotImplementedError(
                "{} must implement _arun or be given a 'func'.".format(
                    type(self).__name__
                )
            )
        return await self.func(**kwargs)

    def run(self, input: Any) -> str:
        """Validate ``input`` (when ``args_schema`` is set) and execute the tool.

        Args:
            input: Keyword arguments for the tool, as a dict.

        Returns:
            The result of ``_run``.

        Raises:
            ValueError: if ``args_schema`` is set and ``input`` is not a dict,
                or schema validation fails.
            VQLError: raised by ``args_schema.validate_args`` when a required
                argument is missing.
        """
        if self.args_schema is not None:
            if not isinstance(input, dict):
                raise ValueError(
                    "The input type must be dict when args_schema is specified."
                )
            # Forward the validated args: unknown params are dropped and values
            # are coerced to their declared types.
            input = self.args_schema.validate_args(input)
        return self._run(**input, **self.special_params)

    async def arun(self, input: Any) -> str:
        """Async counterpart of ``run``; awaits ``_arun`` with the validated input."""
        if self.args_schema is not None:
            if not isinstance(input, dict):
                raise ValueError(
                    "The input type must be dict when args_schema is specified."
                )
            input = self.args_schema.validate_args(input)
        return await self._arun(**input, **self.special_params)

    def generate_schema(self):
        """Build a function-calling tool schema (``{"type": "function", ...}``)."""
        if not self.args_schema:
            return {
                "type": "function",
                # Kept at the top level for callers that read it from here.
                "description": self.description,
                "function": {
                    "name": self.name,
                    "description": self.description,
                    "parameters": {
                        "type": "object",
                        "name": "input",
                        "required": ["input"],
                    },
                },
            }
        properties, required = self.args_schema.generate_schema()
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }
|
|
|
|
|
class BaseModelTool(BaseTool, ABC):
    """Base class for tools that wrap an object-detection model."""

    def visual_prompting(
        self,
        image: Image.Image,
        annotation: List[Target],
        prompting_type: str = "label_on_img",
        include_labels: Optional[Union[List, set, tuple]] = None,
        exclude_labels: Optional[Union[List, set, tuple]] = None,
        color: str = "red",
    ) -> List[Image.Image]:
        """Draw the labelled boxes of ``annotation`` onto ``image``.

        Args:
            image: Image to annotate.
            annotation: Targets to draw; targets with an empty ``bbox`` are skipped.
            prompting_type: Currently unused; only labelled boxes are drawn.
            include_labels: If given, only targets whose label is in it are drawn.
            exclude_labels: If given, targets whose label is in it are skipped.
                A target is drawn only if it passes both filters.
            color: Colour passed to ``Annotator.box_label`` (default "red").

        Returns:
            The value of ``Annotator.result()``.
            NOTE(review): annotated as a list of images; confirm against Annotator.
        """
        annotator = Annotator(image)
        for obj in annotation:
            excluded = exclude_labels is not None and obj.label in exclude_labels
            not_included = include_labels is not None and obj.label not in include_labels
            if excluded or not_included:
                continue
            if obj.bbox:
                annotator.box_label(obj.bbox, obj.label, color=color)
        return annotator.result()

    def infer(self, images: List[Image.Image], kwargs: dict) -> List[List[Target]]:
        """The model inference step. Only support OD type detection.

        This base implementation does nothing and returns None; subclasses are
        expected to override it.

        Args:
            images (List[Image.Image]): The list of input images. Image should be PIL Image object.
            kwargs (dict): The additional arguments for the model.

        Returns:
            List[List[Target]]: The detection results.
        """

    def ainfer(self, images: List[Image.Image], kwargs: dict) -> List[List[Target]]:
        """The async version of model inference step. Only support OD type detection.

        Note that this base method is a plain (non-``async``) function that does
        nothing and returns None; subclasses are expected to override it.

        Args:
            images (List[Image.Image]): The list of input images. Image should be PIL Image object.
            kwargs (dict): The additional arguments for the model.

        Returns:
            List[List[Target]]: The detection results.
        """
|
|
|
|
|
class MemoryTool(BaseTool):
    """Tool backed by a SQL data table accessed through ``memory_handler``."""

    # Handler that owns the memory table and executes SQL against it.
    memory_handler: Optional[SQLDataHandler]

    def _require_handler(self) -> SQLDataHandler:
        """Return ``memory_handler``, raising a clear error when it is unset."""
        if self.memory_handler is None:
            raise ValueError("MemoryTool requires a memory_handler to be set.")
        return self.memory_handler

    def generate_schema(self) -> dict:
        """Generate the data table schema in dict format.

        Returns:
            dict: ``{"table_name": ..., "columns": [...]}`` where each column entry
            holds its ``name``, ``type`` (the SQLAlchemy-style ``__visit_name__``)
            and ``info``.

        Raises:
            ValueError: if ``memory_handler`` is None.
        """
        table = self._require_handler().table
        return {
            "table_name": table.__tablename__,
            "columns": [
                {
                    "name": column.name,
                    "type": column.type.__visit_name__,
                    "info": column.info,
                }
                for column in table.__table__.columns
            ],
        }

    def generate_prompt(self):
        # Placeholder; not implemented.
        pass

    def _run(self):
        """Execute the handler's SQL and return its result."""
        return self._require_handler().execute_sql()

    async def _arun(self):
        """Execute the handler's SQL and return its result.

        NOTE(review): ``execute_sql`` is called synchronously, so it blocks the
        event loop for its duration.
        """
        return self._require_handler().execute_sql()
|
|