refine summarize chain
Browse files
- app_modules/llm_summarize_chain.py +48 -1
- summarize.py +11 -9
app_modules/llm_summarize_chain.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import os
|
2 |
from typing import List, Optional
|
|
|
3 |
|
4 |
from langchain.chains.base import Chain
|
5 |
from langchain.chains.summarize import load_summarize_chain
|
@@ -7,12 +8,58 @@ from langchain.chains.summarize import load_summarize_chain
|
|
7 |
from app_modules.llm_inference import LLMInference
|
8 |
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
class SummarizeChain(LLMInference):
|
11 |
def __init__(self, llm_loader):
|
12 |
super().__init__(llm_loader)
|
13 |
|
14 |
def create_chain(self) -> Chain:
|
15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
return chain
|
17 |
|
18 |
def run_chain(self, chain, inputs, callbacks: Optional[List] = []):
|
|
|
1 |
import os
|
2 |
from typing import List, Optional
|
3 |
+
from langchain import PromptTemplate
|
4 |
|
5 |
from langchain.chains.base import Chain
|
6 |
from langchain.chains.summarize import load_summarize_chain
|
|
|
8 |
from app_modules.llm_inference import LLMInference
|
9 |
|
10 |
|
11 |
def get_llama_2_prompt_template(instruction):
    """Wrap *instruction* in the LLaMA-2 chat prompt envelope.

    Produces ``[INST]<<SYS>>\\n{system}\\n<</SYS>>\\n\\n{instruction}[/INST]``
    with a fixed helpful-assistant system prompt, as expected by LLaMA-2
    chat-tuned models.
    """
    # System block: <<SYS>> header, fixed system prompt, <</SYS>> footer.
    system_block = (
        "<<SYS>>\n"
        "You are a helpful assistant, you always only answer for the assistant then you stop. Read the text to get context"
        "\n<</SYS>>\n\n"
    )
    return "".join(["[INST]", system_block, instruction, "[/INST]"])
|
20 |
+
|
21 |
+
|
22 |
class SummarizeChain(LLMInference):
|
23 |
    def __init__(self, llm_loader):
        """Initialize the chain with *llm_loader*, whose ``llm`` attribute
        supplies the model used by ``create_chain``."""
        super().__init__(llm_loader)
|
25 |
|
26 |
def create_chain(self) -> Chain:
|
27 |
+
use_llama_2_prompt_template = (
|
28 |
+
os.environ.get("USE_LLAMA_2_PROMPT_TEMPLATE") == "true"
|
29 |
+
)
|
30 |
+
prompt_template = """Write a concise summary of the following:
|
31 |
+
{text}
|
32 |
+
CONCISE SUMMARY:"""
|
33 |
+
|
34 |
+
if use_llama_2_prompt_template:
|
35 |
+
prompt_template = get_llama_2_prompt_template(prompt_template)
|
36 |
+
prompt = PromptTemplate.from_template(prompt_template)
|
37 |
+
|
38 |
+
refine_template = (
|
39 |
+
"Your job is to produce a final summary\n"
|
40 |
+
"We have provided an existing summary up to a certain point: {existing_answer}\n"
|
41 |
+
"We have the opportunity to refine the existing summary"
|
42 |
+
"(only if needed) with some more context below.\n"
|
43 |
+
"------------\n"
|
44 |
+
"{text}\n"
|
45 |
+
"------------\n"
|
46 |
+
"Given the new context, refine the original summary."
|
47 |
+
"If the context isn't useful, return the original summary."
|
48 |
+
)
|
49 |
+
|
50 |
+
if use_llama_2_prompt_template:
|
51 |
+
refine_template = get_llama_2_prompt_template(refine_template)
|
52 |
+
refine_prompt = PromptTemplate.from_template(refine_template)
|
53 |
+
|
54 |
+
chain = load_summarize_chain(
|
55 |
+
llm=self.llm_loader.llm,
|
56 |
+
chain_type="refine",
|
57 |
+
question_prompt=prompt,
|
58 |
+
refine_prompt=refine_prompt,
|
59 |
+
return_intermediate_steps=True,
|
60 |
+
input_key="input_documents",
|
61 |
+
output_key="output_text",
|
62 |
+
)
|
63 |
return chain
|
64 |
|
65 |
def run_chain(self, chain, inputs, callbacks: Optional[List] = []):
|
summarize.py
CHANGED
@@ -15,17 +15,16 @@ from app_modules.init import app_init, get_device_types
|
|
15 |
from app_modules.llm_summarize_chain import SummarizeChain
|
16 |
|
17 |
|
18 |
-
def load_documents(source_pdfs_path,
|
19 |
loader = PyPDFDirectoryLoader(source_pdfs_path, silent_errors=True)
|
20 |
documents = loader.load()
|
21 |
-
if
|
22 |
for doc in documents:
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
break
|
29 |
return documents
|
30 |
|
31 |
|
@@ -43,8 +42,11 @@ source_pdfs_path = (
|
|
43 |
)
|
44 |
chunk_size = sys.argv[2] if len(sys.argv) > 2 else os.environ.get("CHUNCK_SIZE")
|
45 |
chunk_overlap = sys.argv[3] if len(sys.argv) > 3 else os.environ.get("CHUNK_OVERLAP")
|
|
|
|
|
|
|
46 |
|
47 |
-
sources = load_documents(source_pdfs_path,
|
48 |
|
49 |
print(f"Splitting {len(sources)} PDF pages in to chunks ...")
|
50 |
|
|
|
15 |
from app_modules.llm_summarize_chain import SummarizeChain
|
16 |
|
17 |
|
18 |
def load_documents(source_pdfs_path, keep_page_info) -> List:
    """Load every PDF under *source_pdfs_path* into langchain documents.

    Args:
        source_pdfs_path: directory scanned for PDF files
            (``PyPDFDirectoryLoader``, with ``silent_errors=True``).
        keep_page_info: when falsy, all page documents are merged into a
            single document whose content joins the pages with newlines;
            when truthy, the per-page documents are returned unchanged.

    Returns:
        A list of loaded documents — either one merged document or one
        document per page.
    """
    loader = PyPDFDirectoryLoader(source_pdfs_path, silent_errors=True)
    documents = loader.load()
    # Fix 1: the original merged pages by repeated string concatenation in a
    # loop (quadratic); a single join is linear and equivalent.
    # Fix 2: guard against an empty directory — the original indexed
    # documents[0] unconditionally and would raise IndexError.
    if not keep_page_info and documents:
        documents[0].page_content = "\n".join(
            doc.page_content for doc in documents
        )
        documents = [documents[0]]
    return documents
|
29 |
|
30 |
|
|
|
42 |
)
|
# CLI args: [1] source PDFs path (handled above), [2] chunk size,
# [3] chunk overlap, [4] keep-page-info flag; each falls back to an env var.
# NOTE(review): "CHUNCK_SIZE" looks misspelled — kept as-is for
# compatibility with existing deployments; confirm before renaming.
chunk_size = sys.argv[2] if len(sys.argv) > 2 else os.environ.get("CHUNCK_SIZE")
chunk_overlap = sys.argv[3] if len(sys.argv) > 3 else os.environ.get("CHUNK_OVERLAP")
# Bug fix: keep_page_info is the 4th CLI argument. The original read
# sys.argv[3] (the chunk-overlap slot), so the flag could never be set
# from the command line.
keep_page_info = (
    sys.argv[4] if len(sys.argv) > 4 else os.environ.get("KEEP_PAGE_INFO")
) == "true"

sources = load_documents(source_pdfs_path, keep_page_info)

print(f"Splitting {len(sources)} PDF pages in to chunks ...")
|
52 |
|