Spaces:
Sleeping
Sleeping
lindsay-qu
committed on
Commit
•
e0f406c
1
Parent(s):
5cb2ac9
Upload 86 files
Browse files
This view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +58 -0
- core/__init__.py +9 -0
- core/__pycache__/__init__.cpython-311.pyc +0 -0
- core/__pycache__/__init__.cpython-38.pyc +0 -0
- core/__pycache__/__init__.cpython-39.pyc +0 -0
- core/chain/__init__.py +1 -0
- core/chain/__pycache__/__init__.cpython-311.pyc +0 -0
- core/chain/__pycache__/__init__.cpython-38.pyc +0 -0
- core/chain/__pycache__/__init__.cpython-39.pyc +0 -0
- core/chain/__pycache__/base_chain.cpython-311.pyc +0 -0
- core/chain/__pycache__/base_chain.cpython-38.pyc +0 -0
- core/chain/__pycache__/base_chain.cpython-39.pyc +0 -0
- core/chain/base_chain.py +10 -0
- core/chain/simple_chain.py +19 -0
- core/chatbot/__init__.py +2 -0
- core/chatbot/__pycache__/__init__.cpython-311.pyc +0 -0
- core/chatbot/__pycache__/__init__.cpython-39.pyc +0 -0
- core/chatbot/__pycache__/base_chatbot.cpython-311.pyc +0 -0
- core/chatbot/__pycache__/base_chatbot.cpython-39.pyc +0 -0
- core/chatbot/__pycache__/retrieval_chatbot.cpython-311.pyc +0 -0
- core/chatbot/__pycache__/retrieval_chatbot.cpython-39.pyc +0 -0
- core/chatbot/base_chatbot.py +12 -0
- core/chatbot/retrieval_chatbot.py +69 -0
- core/memory/__init__.py +2 -0
- core/memory/__pycache__/__init__.cpython-311.pyc +0 -0
- core/memory/__pycache__/__init__.cpython-39.pyc +0 -0
- core/memory/__pycache__/base_memory.cpython-311.pyc +0 -0
- core/memory/__pycache__/base_memory.cpython-39.pyc +0 -0
- core/memory/__pycache__/chat_memory.cpython-311.pyc +0 -0
- core/memory/__pycache__/chat_memory.cpython-39.pyc +0 -0
- core/memory/base_memory.py +18 -0
- core/memory/chat_memory.py +22 -0
- core/memory/plan_memory.py +38 -0
- core/planner/__init__.py +1 -0
- core/planner/__pycache__/__init__.cpython-311.pyc +0 -0
- core/planner/__pycache__/__init__.cpython-39.pyc +0 -0
- core/planner/__pycache__/base_planner.cpython-311.pyc +0 -0
- core/planner/__pycache__/base_planner.cpython-39.pyc +0 -0
- core/planner/base_planner.py +6 -0
- core/refiner/__init__.py +2 -0
- core/refiner/__pycache__/__init__.cpython-311.pyc +0 -0
- core/refiner/__pycache__/__init__.cpython-39.pyc +0 -0
- core/refiner/__pycache__/base_refiner.cpython-311.pyc +0 -0
- core/refiner/__pycache__/base_refiner.cpython-39.pyc +0 -0
- core/refiner/__pycache__/simple_refiner.cpython-311.pyc +0 -0
- core/refiner/__pycache__/simple_refiner.cpython-39.pyc +0 -0
- core/refiner/base_refiner.py +11 -0
- core/refiner/recursive_refiner.py +0 -0
- core/refiner/simple_refiner.py +50 -0
- core/retriever/__init__.py +3 -0
app.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Standard library imports first, then third-party, then local packages.
# FIX: the original imported `time` twice; the duplicate is removed.
import asyncio
import os
import time

import gradio as gr
import openai

import core
import models

# Fail fast: both variables must be present in the environment, otherwise a
# KeyError is raised at import time.
api_key = os.environ["OPENAI_API_KEY"]
api_base = os.environ["OPENAI_API_BASE"]

# NOTE(review): api_key/api_base are read but never handed to the openai
# client here — presumably the client picks them up elsewhere; confirm.

# def embed(texts: list):
#     return openai.Embedding.create(input=texts, model="text-embedding-ada-002")["data"]["embedding"]
17 |
+
def chatbot_initialize():
    """Build the retrieval chatbot used by the Gradio app.

    Returns a ``core.chatbot.RetrievalChatbot`` wired to a ChromaRetriever
    over the "pdfs_1000" collection with a biomedical embedding model.
    """
    pdf_retriever = core.retriever.ChromaRetriever(
        pdf_dir="",
        collection_name="pdfs_1000",
        split_args={"size": 2048, "overlap": 10},
        # embedding_model="text-embedding-ada-002" was the earlier choice
        embed_model=models.BiomedModel(),
    )
    return core.chatbot.RetrievalChatbot(retriever=pdf_retriever)
25 |
+
|
26 |
+
async def respond(query, chat_history, img_path_list, chat_history_string):
    """Gradio submit handler: answer *query* and update the chat widgets.

    Args:
        query: the user's question from the textbox.
        chat_history: list of (user, bot) tuples backing ``gr.Chatbot``.
        img_path_list: uploaded image files (each exposing a ``.name`` path,
            as Gradio provides) or None.
        chat_history_string: plain-text transcript shown in the copy box.

    Returns:
        ("", updated chat_history, updated chat_history_string); the empty
        string clears the query textbox.
    """
    start = time.time()
    # Chatbot is the module-level instance created in __main__; reading a
    # global needs no `global` declaration (the original's was redundant).
    result = await Chatbot.response(query, image_paths=img_path_list)
    response = result["answer"]
    # result also carries "logs" and "titles"; they were fetched here but
    # never displayed, so they are deliberately ignored now.
    chat_history.append((query, response))
    if img_path_list is None:
        chat_history_string += f"Query: {query}\nImage: None\nResponse: {response}\n\n\n"
    else:
        image_names = "\n".join(path.name for path in img_path_list)
        chat_history_string += f"Query: {query}\nImages: {image_names}\nResponse: {response}\n\n\n"
    print(f"Total: {time.time() - start}")
    return "", chat_history, chat_history_string
42 |
+
|
43 |
+
if __name__ == "__main__":
    # FIX: the original wrote `global Chatbot` at module scope, which is a
    # no-op; a plain top-level assignment already creates the module global
    # that respond() reads.
    Chatbot = chatbot_initialize()

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column(scale=2):
                chatbot = gr.Chatbot()
                msg = gr.Textbox(label="Query", show_label=True)
                imgs = gr.File(file_count='multiple', file_types=['image'], type="filepath", label='Upload Images')
                clear = gr.ClearButton([msg, chatbot])
            with gr.Column(scale=1):
                # titles = gr.Textbox(label="Referenced Article Titles", show_label=True, show_copy_button=True, interactive=False)
                history = gr.Textbox(label="Copy Chat History", show_label=True, show_copy_button=True, interactive=False, max_lines=5)
        # Wire the textbox submit event to the async handler.
        msg.submit(respond, inputs=[msg, chatbot, imgs, history], outputs=[msg, chatbot, history])
    demo.queue().launch()
core/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .chain import *
|
2 |
+
from .chatbot import *
|
3 |
+
from .memory import *
|
4 |
+
from .planner import *
|
5 |
+
from .refiner import *
|
6 |
+
from .retriever import *
|
7 |
+
|
8 |
+
from models import *
|
9 |
+
from prompts import *
|
core/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (387 Bytes). View file
|
|
core/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (291 Bytes). View file
|
|
core/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (291 Bytes). View file
|
|
core/chain/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .base_chain import BaseChain
|
core/chain/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (213 Bytes). View file
|
|
core/chain/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (181 Bytes). View file
|
|
core/chain/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (200 Bytes). View file
|
|
core/chain/__pycache__/base_chain.cpython-311.pyc
ADDED
Binary file (892 Bytes). View file
|
|
core/chain/__pycache__/base_chain.cpython-38.pyc
ADDED
Binary file (731 Bytes). View file
|
|
core/chain/__pycache__/base_chain.cpython-39.pyc
ADDED
Binary file (750 Bytes). View file
|
|
core/chain/base_chain.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
class BaseChain:
    """Abstract interface for an executable chain of steps.

    Concrete chains must implement construction, appending a step, and
    execution of the whole chain.
    """

    def __init__(self, chain: list):
        raise NotImplementedError

    def append(self, item: str):
        """Add one step to the chain; must be overridden."""
        raise NotImplementedError

    def execute(self):
        """Run the chain; must be overridden."""
        raise NotImplementedError
core/chain/simple_chain.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# FIX: `from base_chain import ...` only resolves when this directory happens
# to be on sys.path; inside the core.chain package the sibling module must be
# imported relatively (matching the style of core/chain/__init__.py).
from .base_chain import BaseChain

# Work-in-progress sketch of a concrete chain; kept commented out until step
# execution is implemented.
# class SimpleChain(BaseChain):
#     def __init__(self, chain: list[str]):
#         self.chain = chain if chain else []
#
#     def append(self, item: str):
#         self.chain.append(item)
#
#     def execute(self):
#         for item in self.chain:
#             # todo: execute item (item --> result, result --> next item)
#             item.execute(param=param)
#         return result
core/chatbot/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .base_chatbot import BaseChatbot
|
2 |
+
from .retrieval_chatbot import RetrievalChatbot
|
core/chatbot/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (293 Bytes). View file
|
|
core/chatbot/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (264 Bytes). View file
|
|
core/chatbot/__pycache__/base_chatbot.cpython-311.pyc
ADDED
Binary file (1 kB). View file
|
|
core/chatbot/__pycache__/base_chatbot.cpython-39.pyc
ADDED
Binary file (784 Bytes). View file
|
|
core/chatbot/__pycache__/retrieval_chatbot.cpython-311.pyc
ADDED
Binary file (5.38 kB). View file
|
|
core/chatbot/__pycache__/retrieval_chatbot.cpython-39.pyc
ADDED
Binary file (3.38 kB). View file
|
|
core/chatbot/base_chatbot.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from models import BaseModel
|
2 |
+
from ..memory import BaseMemory
|
3 |
+
class BaseChatbot:
    """Minimal chatbot interface: a language model plus a conversation memory."""

    def __init__(self,
                 model: BaseModel,
                 memory: BaseMemory
                 ) -> None:
        # Store the collaborators; subclasses decide how to combine them.
        self.model = model
        self.memory = memory

    def respond(self, message: str) -> str:
        """Produce a reply to *message*; concrete chatbots must override."""
        raise NotImplementedError
core/chatbot/retrieval_chatbot.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base_chatbot import BaseChatbot
|
2 |
+
from ..memory import BaseMemory, ChatMemory
|
3 |
+
from ..retriever import BaseRetriever, ChromaRetriever, FaissRetriever
|
4 |
+
from ..refiner import BaseRefiner, SimpleRefiner
|
5 |
+
from models import BaseModel, GPT4Model
|
6 |
+
from prompts import DecomposePrompt, QAPrompt, SummaryPrompt, ReferencePrompt
|
7 |
+
import ast
|
8 |
+
from utils.image_encoder import encode_image
|
9 |
+
import asyncio
|
10 |
+
import time
|
11 |
+
|
12 |
+
class RetrievalChatbot(BaseChatbot):
    """Chatbot that grounds its answers in documents fetched by a retriever.

    Every collaborator is optional; omitted ones fall back to GPT-4-backed
    defaults, so ``RetrievalChatbot()`` is usable out of the box.
    """

    def __init__(self,
                 model: BaseModel = None,
                 memory: BaseMemory = None,
                 retriever: BaseRetriever = None,
                 decomposer: BaseRefiner = None,
                 answerer: BaseRefiner = None,
                 summarizer: BaseRefiner = None,
                 ) -> None:
        self.model = model if model else GPT4Model()
        self.memory = memory if memory else ChatMemory(sys_prompt=SummaryPrompt.content)
        self.retriever = retriever if retriever else ChromaRetriever(
            pdf_dir="papers_all",
            collection_name="pdfs",
            split_args={"size": 2048, "overlap": 10},
            embed_model=GPT4Model())
        self.decomposer = decomposer if decomposer \
            else SimpleRefiner(model=GPT4Model(), sys_prompt=DecomposePrompt.content)
        self.answerer = answerer if answerer \
            else SimpleRefiner(model=GPT4Model(), sys_prompt=QAPrompt.content)
        self.summarizer = summarizer if summarizer \
            else SimpleRefiner(model=GPT4Model(), sys_prompt=SummaryPrompt.content)

    async def response(self, message: str, image_paths=None, return_logs=False) -> str:
        """Answer *message*, optionally conditioning on uploaded images.

        Args:
            message: the user's question.
            image_paths: file object(s) exposing a ``.name`` path (e.g.
                Gradio uploads), a single object or a list, or None.
            return_logs: currently unused; kept for interface stability.

        Returns:
            dict with keys "answer" (str), "titles" (set of source titles)
            and "logs" (always "" for now).  NOTE(review): the ``-> str``
            annotation predates the dict return; callers use the dict.
        """
        print("Query: {message}".format(message=message))
        # Collect reference snippets for the query.
        refs, titles = self.retriever.retrieve(message)
        retrieved_reference = ""
        for ref in refs:
            retrieved_reference += "Related research: {ref}\n".format(ref=ref)
        answerer_context = "Sub Question References: {retrieved_reference}\nQuestion: {message}\n".format(retrieved_reference=retrieved_reference, message=message)
        answer = self.answerer.refine(answerer_context, self.memory, image_paths)
        # (dead timing variables time1/time_s/time_e removed: they were
        # computed but never reported)

        # TODO: smarter memory management (summarise / trim old turns).
        if image_paths is None:
            self.memory.append([{"role": "user", "content": [
                {"type": "text", "text": f"{message}"},
            ]}, {"role": "assistant", "content": answer}])
        else:
            if not isinstance(image_paths, list):
                image_paths = [image_paths]
            memory_user = [{"type": "text", "text": f"{message}"},]
            for image_path in image_paths:
                memory_user.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encode_image(image_path.name)}"}},)
            self.memory.append([{"role": "user", "content": memory_user}, {"role": "assistant", "content": answer}])
        print("="*20)
        # FIX: the original chained .format(answer=answer) onto an f-string —
        # a no-op after interpolation.
        print(f"Final answer: {answer}")

        return {
            "answer": answer,
            "titles": set(titles),
            "logs": ""
        }
core/memory/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .base_memory import BaseMemory
|
2 |
+
from .chat_memory import ChatMemory
|
core/memory/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (278 Bytes). View file
|
|
core/memory/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (249 Bytes). View file
|
|
core/memory/__pycache__/base_memory.cpython-311.pyc
ADDED
Binary file (1.31 kB). View file
|
|
core/memory/__pycache__/base_memory.cpython-39.pyc
ADDED
Binary file (1.12 kB). View file
|
|
core/memory/__pycache__/chat_memory.cpython-311.pyc
ADDED
Binary file (1.66 kB). View file
|
|
core/memory/__pycache__/chat_memory.cpython-39.pyc
ADDED
Binary file (1.26 kB). View file
|
|
core/memory/base_memory.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class BaseMemory:
    """Abstract conversation-memory interface.

    Concrete memories implement construction, mutation (append / pop /
    clear) and persistence (load / save).
    """

    def __init__(self) -> None:
        raise NotImplementedError

    def append(self, message: str) -> None:
        """Record a new message; must be overridden."""
        raise NotImplementedError

    def pop(self) -> None:
        """Discard the most recent message; must be overridden."""
        raise NotImplementedError

    def clear(self) -> None:
        """Reset the memory; must be overridden."""
        raise NotImplementedError

    def load(self) -> None:
        """Restore persisted state; must be overridden."""
        raise NotImplementedError

    def save(self) -> None:
        """Persist current state; must be overridden."""
        raise NotImplementedError
core/memory/chat_memory.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base_memory import BaseMemory
|
2 |
+
|
3 |
+
class ChatMemory(BaseMemory):
    """Chat history kept as OpenAI-style message dicts.

    When *sys_prompt* is given, the history always starts with a system
    message carrying it.
    """

    def __init__(self, sys_prompt = None) -> None:
        self.sys_prompt = sys_prompt
        self.messages = [{"role": "system", "content": sys_prompt}] if sys_prompt else []

    def append(self, message: list) -> None:
        """Extend the history with *message*, a list of message dicts."""
        # NOTE(review): list concatenation assumed — a bare dict would be
        # merged item-by-item; confirm callers always pass a list.
        self.messages += message

    def pop(self) -> None:
        """Drop the most recent message."""
        self.messages.pop()

    def clear(self) -> None:
        """Reset to the initial state.

        FIX: the original unconditionally re-seeded
        ``[{"role": "system", "content": self.sys_prompt}]``, inserting a
        ``content=None`` system message when no prompt was configured;
        mirror ``__init__`` instead.
        """
        self.messages = [{"role": "system", "content": self.sys_prompt}] if self.sys_prompt else []

    def load(self) -> None:
        """Persistence not implemented yet."""
        pass

    def save(self) -> None:
        """Persistence not implemented yet."""
        pass
core/memory/plan_memory.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base_memory import BaseMemory
|
2 |
+
|
3 |
+
from dataclasses import dataclass
|
4 |
+
|
5 |
+
@dataclass
class Task:
    """A single named unit of work in a plan."""
    # FIX: the original declared a manual __init__ under @dataclass with no
    # field annotations, so the decorator generated nothing useful. Declaring
    # the fields lets dataclass produce __init__/__repr__/__eq__ as intended;
    # Task(name, description) keeps working positionally and by keyword.
    name: str
    description: str

class TaskChain:
    """Ordered collection of Tasks with a printable one-line-per-task summary."""

    def __init__(self, tasks: list):
        self.tasks = tasks

    def append(self, task: Task):
        """Add *task* to the end of the chain."""
        self.tasks.append(task)

    def clear(self):
        """Drop every task."""
        self.tasks = []

    def __str__(self):
        return "\n".join(f"{task.name}: {task.description}" for task in self.tasks)
20 |
+
|
21 |
+
class PlanMemory(BaseMemory):
    """Memory that keeps raw messages alongside a parsed chain of tasks."""

    def __init__(self, initial_message, initial_task) -> None:
        # `initial_message` is expected to be a list of messages — TODO confirm.
        self.messages = initial_message if initial_message else []
        self.tasks = TaskChain(initial_task) if initial_task else TaskChain([])

    def append(self, message: str) -> None:
        """Record *message* and register a task derived from it."""
        self.messages.append(message)
        # todo: parse message for tasks & add to task chain; for now the raw
        # text is wrapped in a generic placeholder Task.
        self.tasks.append(Task("Task", message))

    def clear(self) -> None:
        """Drop all messages.

        NOTE(review): the task chain is left untouched here — confirm
        whether tasks should be cleared as well.
        """
        self.messages = []

    def load(self) -> None:
        """Persistence not implemented yet."""
        pass

    def save(self) -> None:
        """Persistence not implemented yet."""
        pass
core/planner/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .base_planner import BasePlanner
|
core/planner/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (219 Bytes). View file
|
|
core/planner/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (206 Bytes). View file
|
|
core/planner/__pycache__/base_planner.cpython-311.pyc
ADDED
Binary file (767 Bytes). View file
|
|
core/planner/__pycache__/base_planner.cpython-39.pyc
ADDED
Binary file (625 Bytes). View file
|
|
core/planner/base_planner.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
class BasePlanner:
    """Abstract planner: turns a message into an ordered list of steps."""

    def __init__(self):
        raise NotImplementedError

    def plan(self, message: str) -> list[str]:
        """Return the sub-tasks derived from *message*; must be overridden."""
        raise NotImplementedError
core/refiner/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .base_refiner import BaseRefiner
|
2 |
+
from .simple_refiner import SimpleRefiner
|
core/refiner/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (287 Bytes). View file
|
|
core/refiner/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (258 Bytes). View file
|
|
core/refiner/__pycache__/base_refiner.cpython-311.pyc
ADDED
Binary file (938 Bytes). View file
|
|
core/refiner/__pycache__/base_refiner.cpython-39.pyc
ADDED
Binary file (738 Bytes). View file
|
|
core/refiner/__pycache__/simple_refiner.cpython-311.pyc
ADDED
Binary file (1.64 kB). View file
|
|
core/refiner/__pycache__/simple_refiner.cpython-39.pyc
ADDED
Binary file (895 Bytes). View file
|
|
core/refiner/base_refiner.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from models import BaseModel
|
2 |
+
class BaseRefiner:
    """Base class for prompt-driven refiners: a system prompt plus a model."""

    def __init__(self,
                 sys_prompt: str,
                 model: BaseModel,
                 ) -> None:
        # Keep both collaborators; subclasses build model contexts from them.
        self.sys_prompt = sys_prompt
        self.model = model

    def refine(self, message: str) -> str:
        """Transform *message*; concrete refiners must override."""
        raise NotImplementedError
core/refiner/recursive_refiner.py
ADDED
File without changes
|
core/refiner/simple_refiner.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from models import BaseModel
|
2 |
+
from .base_refiner import BaseRefiner
|
3 |
+
from utils.image_encoder import encode_image
|
4 |
+
import asyncio
|
5 |
+
|
6 |
+
class SimpleRefiner(BaseRefiner):
    """Single-pass refiner: system prompt + memory + user turn -> one model call."""

    def __init__(self,
                 sys_prompt: str,
                 model: BaseModel,
                 ) -> None:
        BaseRefiner.__init__(self, sys_prompt=sys_prompt, model=model)

    def _build_context(self, message: str, memory, image_paths):
        """Assemble the chat-completion message list.

        Shared by refine() and refine_async(); the original duplicated this
        logic verbatim in both methods.
        """
        # memory.messages[1:] skips the memory's own leading system message
        # in favour of this refiner's sys_prompt.
        history = [] if memory is None else memory.messages[1:]
        user_content = [{"type": "text", "text": f"{message}"},]
        if image_paths:
            if not isinstance(image_paths, list):
                image_paths = [image_paths]
            for image_path in image_paths:
                # image_path is expected to be a file object whose .name is
                # the on-disk path (e.g. a Gradio upload) — TODO confirm.
                user_content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encode_image(image_path.name)}"}})
        return ([{"role": "system", "content": self.sys_prompt}]
                + history
                + [{"role": "user", "content": user_content}])

    async def refine_async(self, message: str, memory, image_paths=None) -> str:
        """Async variant of refine(); same context, awaited model call."""
        context = self._build_context(message, memory, image_paths)
        # The original wrapped this in create_task/await/result(), which is
        # equivalent to awaiting the coroutine directly.
        return await self.model.respond_async(context)

    def refine(self, message: str, memory, image_paths=None) -> str:
        """Run one synchronous model call over the assembled context."""
        context = self._build_context(message, memory, image_paths)
        return self.model.respond(context)
core/retriever/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .base_retriever import BaseRetriever
|
2 |
+
from .chroma_retriever import ChromaRetriever
|
3 |
+
from .faiss_retriever import FaissRetriever
|