dh-mc commited on
Commit
90abc4b
·
1 Parent(s): d380674

refine summarize chain

Browse files
Files changed (2) hide show
  1. app_modules/llm_summarize_chain.py +48 -1
  2. summarize.py +11 -9
app_modules/llm_summarize_chain.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  from typing import List, Optional
 
3
 
4
  from langchain.chains.base import Chain
5
  from langchain.chains.summarize import load_summarize_chain
@@ -7,12 +8,58 @@ from langchain.chains.summarize import load_summarize_chain
7
  from app_modules.llm_inference import LLMInference
8
 
9
 
 
 
 
 
 
 
 
 
 
 
 
10
  class SummarizeChain(LLMInference):
11
  def __init__(self, llm_loader):
12
  super().__init__(llm_loader)
13
 
14
  def create_chain(self) -> Chain:
15
- chain = load_summarize_chain(self.llm_loader.llm, chain_type="refine")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  return chain
17
 
18
  def run_chain(self, chain, inputs, callbacks: Optional[List] = []):
 
1
  import os
2
  from typing import List, Optional
3
+ from langchain import PromptTemplate
4
 
5
  from langchain.chains.base import Chain
6
  from langchain.chains.summarize import load_summarize_chain
 
8
  from app_modules.llm_inference import LLMInference
9
 
10
 
11
+ def get_llama_2_prompt_template(instruction):
12
+ B_INST, E_INST = "[INST]", "[/INST]"
13
+ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
14
+
15
+ system_prompt = "You are a helpful assistant, you always only answer for the assistant then you stop. Read the text to get context"
16
+
17
+ SYSTEM_PROMPT = B_SYS + system_prompt + E_SYS
18
+ prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
19
+ return prompt_template
20
+
21
+
22
  class SummarizeChain(LLMInference):
23
  def __init__(self, llm_loader):
24
  super().__init__(llm_loader)
25
 
26
  def create_chain(self) -> Chain:
27
+ use_llama_2_prompt_template = (
28
+ os.environ.get("USE_LLAMA_2_PROMPT_TEMPLATE") == "true"
29
+ )
30
+ prompt_template = """Write a concise summary of the following:
31
+ {text}
32
+ CONCISE SUMMARY:"""
33
+
34
+ if use_llama_2_prompt_template:
35
+ prompt_template = get_llama_2_prompt_template(prompt_template)
36
+ prompt = PromptTemplate.from_template(prompt_template)
37
+
38
+ refine_template = (
39
+ "Your job is to produce a final summary\n"
40
+ "We have provided an existing summary up to a certain point: {existing_answer}\n"
41
+ "We have the opportunity to refine the existing summary"
42
+ "(only if needed) with some more context below.\n"
43
+ "------------\n"
44
+ "{text}\n"
45
+ "------------\n"
46
+ "Given the new context, refine the original summary."
47
+ "If the context isn't useful, return the original summary."
48
+ )
49
+
50
+ if use_llama_2_prompt_template:
51
+ refine_template = get_llama_2_prompt_template(refine_template)
52
+ refine_prompt = PromptTemplate.from_template(refine_template)
53
+
54
+ chain = load_summarize_chain(
55
+ llm=self.llm_loader.llm,
56
+ chain_type="refine",
57
+ question_prompt=prompt,
58
+ refine_prompt=refine_prompt,
59
+ return_intermediate_steps=True,
60
+ input_key="input_documents",
61
+ output_key="output_text",
62
+ )
63
  return chain
64
 
65
  def run_chain(self, chain, inputs, callbacks: Optional[List] = []):
summarize.py CHANGED
@@ -15,17 +15,16 @@ from app_modules.init import app_init, get_device_types
15
  from app_modules.llm_summarize_chain import SummarizeChain
16
 
17
 
18
- def load_documents(source_pdfs_path, urls) -> List:
19
  loader = PyPDFDirectoryLoader(source_pdfs_path, silent_errors=True)
20
  documents = loader.load()
21
- if urls is not None and len(urls) > 0:
22
  for doc in documents:
23
- source = doc.metadata["source"]
24
- filename = source.split("/")[-1]
25
- for url in urls:
26
- if url.endswith(filename):
27
- doc.metadata["url"] = url
28
- break
29
  return documents
30
 
31
 
@@ -43,8 +42,11 @@ source_pdfs_path = (
43
  )
44
  chunk_size = sys.argv[2] if len(sys.argv) > 2 else os.environ.get("CHUNCK_SIZE")
45
  chunk_overlap = sys.argv[3] if len(sys.argv) > 3 else os.environ.get("CHUNK_OVERLAP")
 
 
 
46
 
47
- sources = load_documents(source_pdfs_path, None)
48
 
49
  print(f"Splitting {len(sources)} PDF pages in to chunks ...")
50
 
 
15
  from app_modules.llm_summarize_chain import SummarizeChain
16
 
17
 
18
+ def load_documents(source_pdfs_path, keep_page_info) -> List:
19
  loader = PyPDFDirectoryLoader(source_pdfs_path, silent_errors=True)
20
  documents = loader.load()
21
+ if not keep_page_info:
22
  for doc in documents:
23
+ if doc is not documents[0]:
24
+ documents[0].page_content = (
25
+ documents[0].page_content + "\n" + doc.page_content
26
+ )
27
+ documents = [documents[0]]
 
28
  return documents
29
 
30
 
 
42
  )
43
  chunk_size = sys.argv[2] if len(sys.argv) > 2 else os.environ.get("CHUNCK_SIZE")
44
  chunk_overlap = sys.argv[3] if len(sys.argv) > 3 else os.environ.get("CHUNK_OVERLAP")
45
+ keep_page_info = (
46
+ sys.argv[3] if len(sys.argv) > 3 else os.environ.get("KEEP_PAGE_INFO")
47
+ ) == "true"
48
 
49
+ sources = load_documents(source_pdfs_path, keep_page_info)
50
 
51
  print(f"Splitting {len(sources)} PDF pages in to chunks ...")
52