gamingflexer commited on
Commit
3b5f78b
1 Parent(s): 7cbc824

Add scrapper extractor module

Browse files
Files changed (2) hide show
  1. .gitignore +1 -0
  2. src/scrapper/extractor.py +58 -0
.gitignore CHANGED
@@ -161,3 +161,4 @@ cython_debug/
161
  src/flagged/log.csv
162
  .vscode/PythonImportHelper-v2-Completion.json
163
  notebooks/*.pdf
 
 
161
  src/flagged/log.csv
162
  .vscode/PythonImportHelper-v2-Completion.json
163
  notebooks/*.pdf
164
+ notebooks/notebooks/papers/*.jsonl
src/scrapper/extractor.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.chat_models import ChatOpenAI
2
+ from langchain import PromptTemplate, LLMChain
3
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
4
+ import tiktoken
5
+ from typing import Union
6
+ from config import OPENAI_API_KEY
7
+ import os
8
+
9
+ def init_extractor(
10
+ template: str,
11
+ openai_api_key: Union[str, None] = None,
12
+ max_tokens: int = 1000,
13
+ chunk_size: int = 300,
14
+ chunk_overlap: int = 40
15
+ ):
16
+ if openai_api_key is None and 'OPENAI_API_KEY' not in os.environ:
17
+ raise Exception('No OpenAI API key provided')
18
+ openai_api_key = openai_api_key or os.environ['OPENAI_API_KEY']
19
+ # instantiate the OpenAI API wrapper
20
+ llm = ChatOpenAI(
21
+ model_name='gpt-3.5-turbo-16k',
22
+ openai_api_key=OPENAI_API_KEY,
23
+ max_tokens=max_tokens,
24
+ temperature=0.0
25
+ )
26
+ # initialize prompt template
27
+ prompt = PromptTemplate(
28
+ template=template,
29
+ input_variables=['refs']
30
+ )
31
+ # instantiate the LLMChain extractor model
32
+ extractor = LLMChain(
33
+ prompt=prompt,
34
+ llm=llm
35
+ )
36
+
37
+ text_splitter = tiktoken_splitter(
38
+ chunk_size=chunk_size,
39
+ chunk_overlap=chunk_overlap
40
+ )
41
+ return extractor, text_splitter
42
+
43
+ def tiktoken_splitter(chunk_size=300, chunk_overlap=40):
44
+ tokenizer = tiktoken.get_encoding('p50k_base')
45
+ # create length function
46
+ def len_fn(text):
47
+ tokens = tokenizer.encode(
48
+ text, disallowed_special=()
49
+ )
50
+ return len(tokens)
51
+ # initialize the text splitter
52
+ text_splitter = RecursiveCharacterTextSplitter(
53
+ chunk_size=chunk_size,
54
+ chunk_overlap=chunk_overlap,
55
+ length_function=len_fn,
56
+ separators=["\n\n", "\n", " ", ""]
57
+ )
58
+ return text_splitter