# Spaces: Paused  (web-page banner text captured with the file; not part of the module)
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import importlib | |
import json | |
import traceback | |
from abc import ABC | |
from copy import deepcopy | |
from functools import partial | |
import pandas as pd | |
from agent.component import component_class | |
from agent.component.base import ComponentBase | |
from agent.settings import flow_logger, DEBUG | |
class Canvas(ABC):
    """Executable graph of agent components described by a JSON DSL.

    The DSL lists every component (its class name, its parameters and its
    ``upstream``/``downstream`` links) together with the mutable run state:
    ``history``, ``messages``, ``reference``, ``path`` and ``answer``.

    Example::

        dsl = {
            "components": {
                "begin": {
                    "obj":{
                        "component_name": "Begin",
                        "params": {},
                    },
                    "downstream": ["answer_0"],
                    "upstream": [],
                },
                "answer_0": {
                    "obj": {
                        "component_name": "Answer",
                        "params": {}
                    },
                    "downstream": ["retrieval_0"],
                    "upstream": ["begin", "generate_0"],
                },
                "retrieval_0": {
                    "obj": {
                        "component_name": "Retrieval",
                        "params": {}
                    },
                    "downstream": ["generate_0"],
                    "upstream": ["answer_0"],
                },
                "generate_0": {
                    "obj": {
                        "component_name": "Generate",
                        "params": {}
                    },
                    "downstream": ["answer_0"],
                    "upstream": ["retrieval_0"],
                }
            },
            "history": [],
            "messages": [],
            "reference": [],
            "path": [["begin"]],
            "answer": []
        }
    """

    def __init__(self, dsl: str, tenant_id=None):
        """Parse ``dsl`` (a JSON string) and instantiate its components.

        Args:
            dsl: JSON text of the canvas DSL.  When empty/falsy, a built-in
                DSL containing only a ``Begin`` component is used.
            tenant_id: Opaque tenant identifier, returned by
                :meth:`get_tenant_id`.
        """
        self.path = []
        self.history = []
        self.messages = []
        self.answer = []
        # Initialised here so the attribute exists even before load() runs.
        self.reference = []
        self.components = {}
        # NOTE(review): this fallback DSL has no "Answer" component, while
        # load() asserts that one exists -- so an empty `dsl` currently fails
        # in load().  Confirm whether that is intended.
        self.dsl = json.loads(dsl) if dsl else {
            "components": {
                "begin": {
                    "obj": {
                        "component_name": "Begin",
                        "params": {
                            "prologue": "Hi there!"
                        }
                    },
                    "downstream": [],
                    "upstream": []
                }
            },
            "history": [],
            "messages": [],
            "reference": [],
            "path": [],
            "answer": []
        }
        self._tenant_id = tenant_id
        self._embed_id = ""
        self.load()

    def load(self):
        """Turn the parsed DSL into live component objects and run state.

        Each ``cpn["obj"]`` dict inside ``self.dsl["components"]`` is replaced
        in place by an instance built via ``component_class``; ``self.components``
        is the very same dict object as ``self.dsl["components"]``.

        Raises:
            AssertionError: if the DSL has no ``Begin`` or no ``Answer``
                component.  (Kept as ``assert`` so the exception type callers
                may rely on does not change.)
        """
        self.components = self.dsl["components"]
        cpn_nms = {cpn["obj"]["component_name"] for cpn in self.components.values()}
        assert "Begin" in cpn_nms, "There have to be an 'Begin' component."
        assert "Answer" in cpn_nms, "There have to be an 'Answer' component."

        for k, cpn in self.components.items():
            component_name = cpn["obj"]["component_name"]
            param = component_class(component_name + "Param")()
            param.update(cpn["obj"]["params"])
            param.check()
            cpn["obj"] = component_class(component_name)(self, k, param)
            if cpn["obj"].component_name == "Categorize":
                # Every category target must be reachable as a downstream node.
                for desc in param.category_description.values():
                    if desc["to"] not in cpn["downstream"]:
                        cpn["downstream"].append(desc["to"])

        self.path = self.dsl["path"]
        self.history = self.dsl["history"]
        self.messages = self.dsl["messages"]
        self.answer = self.dsl["answer"]
        self.reference = self.dsl["reference"]
        self._embed_id = self.dsl.get("embed_id", "")

    def __str__(self):
        """Serialize the canvas (components plus run state) to a JSON string.

        The current run state is first written back into ``self.dsl``; the
        returned JSON is built from deep copies, with each component object
        serialized through its own ``str()`` and re-parsed as JSON.
        """
        self.dsl["path"] = self.path
        self.dsl["history"] = self.history
        self.dsl["messages"] = self.messages
        self.dsl["answer"] = self.answer
        self.dsl["reference"] = self.reference
        self.dsl["embed_id"] = self._embed_id

        # "components" is inserted first to keep the original key order.
        dsl = {"components": {}}
        for key, value in self.dsl.items():
            if key != "components":
                dsl[key] = deepcopy(value)

        for k, cpn in self.components.items():
            # Component instances cannot be deep-copied as data; they are
            # round-tripped through their JSON string form instead.
            dsl["components"][k] = {
                field: json.loads(str(value)) if field == "obj" else deepcopy(value)
                for field, value in cpn.items()
            }
        return json.dumps(dsl, ensure_ascii=False)

    def reset(self):
        """Clear all run state and reset every component."""
        self.path = []
        self.history = []
        self.messages = []
        self.answer = []
        self.reference = []
        for cpn in self.components.values():
            cpn["obj"].reset()
        self._embed_id = ""

    def run(self, **kwargs):
        """Advance the canvas until the next Answer component has produced output.

        If an Answer is already queued in ``self.answer`` it is executed
        immediately.  Otherwise a new stage is appended to ``self.path`` and
        components are executed by following ``downstream`` links (or the
        output of switch-like components) until no further component can run;
        then the first queued Answer, if any, is executed.

        Args:
            **kwargs: Forwarded to every component's ``run``.  When
                ``stream`` is truthy, the Answer's result must be a
                ``functools.partial`` and is returned without being recorded
                in the history.

        Returns:
            The output of the last component executed: the Answer's output
            when one ran, otherwise the output of the last non-Answer
            component (``""`` if nothing ran).

        Raises:
            OverflowError: when :meth:`_find_loop` reports a repeating cycle.
            AssertionError: when a switch-like component names an unknown
                component, or a streamed answer is not a ``partial``.
        """
        ans = ""
        if self.answer:
            cpn_id = self.answer.pop(0)
            try:
                ans = self.components[cpn_id]["obj"].run(self.history, **kwargs)
            except Exception as e:
                # NOTE(review): with stream=True this fallback is presumably
                # not a `partial`, so the assertion below would fire -- confirm.
                ans = ComponentBase.be_output(str(e))
            self.path[-1].append(cpn_id)
            if kwargs.get("stream"):
                assert isinstance(ans, partial)
                return ans
            self.history.append(("assistant", ans.to_dict("records")))
            return ans

        if not self.path:
            self.components["begin"]["obj"].run(self.history, **kwargs)
            self.path.append(["begin"])

        # Every call that actually walks the graph records its own stage.
        self.path.append([])

        # Index into self.path[-1] of the component whose downstream is
        # processed next; advanced once per prepare2run() call.
        ran = -1

        def prepare2run(cpns):
            """Queue Answer components and execute the other runnable ones."""
            nonlocal ran, ans
            for c in cpns:
                # Do not re-run the component that was executed last.
                if self.path[-1] and c == self.path[-1][-1]:
                    continue
                cpn = self.components[c]["obj"]
                if cpn.component_name == "Answer":
                    self.answer.append(c)
                else:
                    if DEBUG:
                        print("RUN: ", c)
                    if cpn.component_name == "Generate":
                        # Postpone Generate until all its dependencies ran
                        # within the current stage.
                        cpids = cpn.get_dependent_components()
                        if any(dep not in self.path[-1] for dep in cpids):
                            continue
                    ans = cpn.run(self.history, **kwargs)
                    self.path[-1].append(c)
            ran += 1

        def hand_error_to_answer(exc):
            """Attach ``exc`` to the most recently visited Answer and queue it.

            Must be called from inside an ``except`` block so that
            ``traceback.print_exc()`` reports the active exception.
            """
            visited = [cid for stage in self.path for cid in stage]
            for cid in reversed(visited):
                if cid.lower().find("answer") >= 0:
                    self.get_component(cid)["obj"].set_exception(exc)
                    prepare2run([cid])
                    break
            traceback.print_exc()

        prepare2run(self.components[self.path[-2][-1]]["downstream"])
        while 0 <= ran < len(self.path[-1]):
            if DEBUG:
                print(ran, self.path)
            cpn_id = self.path[-1][ran]
            cpn = self.get_component(cpn_id)
            if not cpn["downstream"]:
                break

            loop = self._find_loop()
            if loop:
                raise OverflowError(f"Too much loops: {loop}")

            if cpn["obj"].component_name.lower() in ["switch", "categorize", "relevant"]:
                # Switch-like components name their single successor in the
                # first cell of their output frame.
                switch_out = cpn["obj"].output()[1].iloc[0, 0]
                assert switch_out in self.components, \
                    "{}'s output: {} not valid.".format(cpn_id, switch_out)
                targets = [switch_out]
            else:
                targets = cpn["downstream"]

            try:
                prepare2run(targets)
            except Exception as e:
                hand_error_to_answer(e)
                break

        if self.answer:
            cpn_id = self.answer.pop(0)
            ans = self.components[cpn_id]["obj"].run(self.history, **kwargs)
            self.path[-1].append(cpn_id)
            if kwargs.get("stream"):
                assert isinstance(ans, partial)
                return ans
            self.history.append(("assistant", ans.to_dict("records")))

        return ans

    def get_component(self, cpn_id):
        """Return the DSL entry (``obj``, ``downstream``, ...) for ``cpn_id``."""
        return self.components[cpn_id]

    def get_tenant_id(self):
        """Return the tenant id given at construction time."""
        return self._tenant_id

    def get_history(self, window_size):
        """Return the last ``window_size`` exchanges as chat messages.

        Args:
            window_size: Number of exchanges to keep; the last
                ``2 * window_size`` history entries are returned.  A value
                ``<= 0`` yields an empty list (a plain ``[-0:]`` slice would
                otherwise return the whole history).

        Returns:
            A list of ``{"role": ..., "content": ...}`` dicts.  User entries
            keep their stored content; for any other role the stored records
            are loaded into a DataFrame and their ``content`` column is
            joined with newlines.
        """
        convs = []
        if window_size <= 0:
            return convs
        for role, obj in self.history[window_size * -2:]:
            if role == "user":
                content = obj
            else:
                content = '\n'.join(pd.DataFrame(obj)['content'])
            convs.append({"role": role, "content": content})
        return convs

    def add_user_input(self, question):
        """Append a user turn to the conversation history."""
        self.history.append(("user", question))

    def set_embedding_model(self, embed_id):
        """Remember the embedding model id used by this canvas."""
        self._embed_id = embed_id

    def get_embedding_model(self):
        """Return the embedding model id (``""`` when unset)."""
        return self._embed_id

    def _find_loop(self, max_loops=2):
        """Detect a cycle repeated too often in the current stage.

        The current stage (``self.path[-1]``) is read newest-first and cut at
        the most recent component whose id contains ``"answer"``.  A cycle of
        ``width`` >= 2 components is reported when it occurs more than
        ``max_loops`` times back-to-back and at least one older step precedes
        those repetitions.

        Component ids are compared element by element, so an id that merely
        starts with another id (``"b"`` vs ``"bxa"``) is never mistaken for a
        repetition.

        Returns:
            ``False`` when no such cycle exists, otherwise a human-readable
            string showing the cycle twice, e.g. ``"a => b => a => b"`` (ids
            are shown up to their first ``":"``).
        """
        trail = self.path[-1][::-1]
        if len(trail) < 2:
            return False

        for idx, cid in enumerate(trail):
            if cid.lower().find("answer") >= 0:
                trail = trail[:idx]
                break
        if len(trail) < 2:
            return False

        for width in range(2, len(trail) // 2):
            pattern = trail[:width]
            rest = trail
            budget = max_loops
            while rest[:width] == pattern and budget >= 0:
                budget -= 1
                # Nothing older than this repetition: not enough evidence.
                if len(rest) <= width:
                    return False
                rest = rest[width:]
            if budget < 0:
                label = " => ".join(cid.split(":")[0] for cid in pattern)
                return label + " => " + label

        return False