suneeln-duke committed on
Commit
77961ad
1 Parent(s): fe82b46
Files changed (5) hide show
  1. Dockerfile +22 -0
  2. main.py +78 -0
  3. requirements.txt +18 -0
  4. scripts/mongo_utils.py +145 -0
  5. scripts/rag_utils.py +135 -0
Dockerfile ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.11.2

WORKDIR /code

COPY ./requirements.txt /code/requirements.txt

RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

# Use apt-get (stable CLI for scripts, unlike `apt`) and clean the package
# lists afterwards to keep the image small.
RUN apt-get update && apt-get install -y --no-install-recommends ffmpeg \
    && rm -rf /var/lib/apt/lists/*

RUN useradd -m -u 1000 user
USER user
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH

WORKDIR $HOME/app

COPY --chown=user . $HOME/app

ENV H2O_WAVE_LISTEN=":7860"
ENV H2O_WAVE_ADDRESS="http://127.0.0.1:7860"

# The original image defined no default process and fell back to the base
# image's python REPL. main.py is a FastAPI app and uvicorn is in
# requirements.txt; 7860 matches the port in the env vars above.
# NOTE(review): confirm this is how the deployment launches the app.
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
22
+
main.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import FastAPI, Request

# The package name is lowercase "flask"; `from Flask import ...` raises
# ModuleNotFoundError on case-sensitive filesystems (e.g. the Linux container).
from flask import jsonify

from scripts import mongo_utils
from scripts import rag_utils
from dotenv import load_dotenv
import os

load_dotenv()

# FastAPI must be *instantiated*. With the bare class, @app.post('/summ')
# would call the unbound method with the path string as `self` and fail.
app = FastAPI()

# Single shared Mongo client for all request handlers.
client = mongo_utils.connect_to_mongo()
print("Connected to MongoDB")
20
+
21
def captitalize_name(name):
    """Turn a snake_case identifier into a space-separated, capitalized title.

    (Function name kept as-is — misspelled "captitalize" — so existing callers
    keep working.)
    """
    words = []
    for part in name.split("_"):
        words.append(part.capitalize())
    return " ".join(words)
25
+
26
@app.post('/summ')
async def summarize(request: Request):
    """Summarize a story excerpt using the PDF's vector store.

    Expects a JSON body with `pdf_path` and `text`; returns {'summary': str}.
    """
    # Request.body() is a coroutine returning raw bytes — subscripting it
    # raises TypeError. The JSON payload must be awaited via request.json().
    data = await request.json()
    pdf_path = data['pdf_path']
    text = data['text']

    vs = mongo_utils.get_vs(pdf_path, client)
    summary = rag_utils.summ(vs, text)

    return {'summary': summary}
37
+
38
@app.post('/clf')
async def classify(request: Request):
    """Classify whether the excerpt contains a story-affecting decision.

    Expects a JSON body with `pdf_path` and `text`; returns
    {'decision': 'yes'|'no'} (lower-cased model output).
    """
    # Await and decode the JSON body; request.body() is an un-awaited coroutine.
    data = await request.json()
    pdf_path = data['pdf_path']
    text = data['text']

    vs = mongo_utils.get_vs(pdf_path, client)
    decision = rag_utils.clf_seq(vs, text).lower()

    # FastAPI serializes dicts itself; Flask's jsonify requires a Flask app
    # context and would raise here.
    return {'decision': decision}
50
+
51
@app.post('/options')
async def options(request: Request):
    """Generate four decision paths for the given story excerpt.

    Expects a JSON body with `pdf_path` and `text`; returns
    {'options': list[str]}.
    """
    import ast

    # Await and decode the JSON body; request.body() is an un-awaited coroutine.
    data = await request.json()
    pdf_path = data['pdf_path']
    text = data['text']

    vs = mongo_utils.get_vs(pdf_path, client)
    # literal_eval safely parses the model's list-of-strings reply; the
    # original `eval` would execute arbitrary code coming back from the LLM.
    option_list = ast.literal_eval(rag_utils.gen_options(vs, text))

    # Plain dict return — FastAPI handles JSON serialization.
    return {'options': option_list}
61
+
62
@app.post('/path')
async def path(request: Request):
    """Generate the next story segment from an excerpt and a chosen decision.

    Expects a JSON body with `pdf_path`, `text`, and `decision`; returns
    {'path': str}.
    """
    # Await and decode the JSON body; request.body() is an un-awaited coroutine.
    data = await request.json()
    pdf_path = data['pdf_path']
    text = data['text']
    decision = data['decision']

    vs = mongo_utils.get_vs(pdf_path, client)
    # Distinct local name avoids shadowing this handler function.
    next_path = rag_utils.gen_path(vs, text, decision)

    # Plain dict return — FastAPI handles JSON serialization.
    return {'path': next_path}
75
+
76
+
77
if __name__ == '__main__':
    # FastAPI apps have no .run() method (that is Flask's API); serve with
    # uvicorn, which requirements.txt already pins. Port 7860 matches the
    # port configured in the Dockerfile environment.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ flask==3.0.3
4
+ pypdf==4.2.0
5
+ pypdf2==3.0.1
6
+ pymongo==4.7.0
7
+ langchain==0.1.16
8
+ langchain_community==0.0.34
9
+ langchain_core==0.1.46
10
+ langchain_openai==0.0.2
11
+ openai>=0.26.2,<=1.6.1
12
+ pandas==2.2.2
13
+ scikit-learn==1.4.2
14
+ seaborn==0.13.2
15
+ matplotlib==3.8.4
16
+ python-dotenv==1.0.1
17
+ certifi
18
+ Flask-CORS
scripts/mongo_utils.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from pymongo import MongoClient
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import MongoDBAtlasVectorSearch
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
import sys, os, certifi
from dotenv import load_dotenv
from pathlib import Path
import PyPDF2

sys.path.append("..")

load_dotenv()

# Re-export the key only when it is actually set: assigning None into
# os.environ raises TypeError, so the original line crashed on import
# whenever OPENAI_API_KEY was missing.
_openai_key = os.environ.get("OPENAI_API_KEY")
if _openai_key is not None:
    os.environ["OPENAI_API_KEY"] = _openai_key
24
+
25
+
26
def read_pages(pdf_file):
    """Extract the text of every page of *pdf_file*.

    Returns a list of strings, one per page, in document order.
    """
    reader = PyPDF2.PdfReader(pdf_file)
    # reader.pages iterates the pages in order, so no index bookkeeping needed.
    return [page.extract_text() for page in reader.pages]
40
+
41
def connect_to_mongo():
    """Open a MongoClient from the MONGO_URI env var and verify it with a ping.

    Connection failures are printed, not raised; the client is returned either way.
    """
    client = MongoClient(os.environ.get("MONGO_URI"), tlsCAFile=certifi.where())
    # Send a ping to confirm a successful connection
    try:
        client.admin.command('ping')
    except Exception as e:
        # Best-effort: report the failure but still hand back the client.
        print(e)
    else:
        print("Pinged your deployment. You successfully connected to MongoDB!")
    return client
52
+
53
def insert_pages(pdf_file, client=None):
    """Store one document per page of *pdf_file* in Mongo and return the stored docs.

    Each document carries the page text, its 0-based index, and the PDF
    file stem as `source`. The collection is named "<stem>-pages" inside
    the MONGO_PAGES_DB database.
    """
    if not client:
        client = connect_to_mongo()

    name = Path(pdf_file).stem
    docs = [
        {"text": content, "page": idx, "source": name}
        for idx, content in enumerate(read_pages(pdf_file))
    ]

    collection = client[os.environ.get("MONGO_PAGES_DB")][f"{name}-pages"]
    collection.insert_many(docs)

    # Return what is actually stored (includes Mongo-assigned _id fields).
    return list(collection.find())
71
+
72
+
73
def get_pages(name, client=None):
    """Return the page documents for *name*, inserting them on first use."""
    if not client:
        client = connect_to_mongo()

    db = client[os.environ.get("MONGO_PAGES_DB")]
    coll_name = f"{name}-pages"

    if coll_name in db.list_collection_names():
        print("using existing page collection")
        return list(db[coll_name].find())

    print("inserting pages")
    # NOTE(review): insert_pages treats *name* as a PDF path on this branch —
    # confirm callers pass the path rather than the bare stem.
    return insert_pages(name, client=client)
92
+
93
def insert_vs(pdf_file, client=None):
    """Chunk *pdf_file*, embed the chunks with OpenAI, and store them in Atlas.

    Builds a "<stem>-vs" collection inside MONGO_VS_DB and returns the
    resulting MongoDBAtlasVectorSearch instance.
    """
    name = Path(pdf_file).stem

    if not client:
        client = connect_to_mongo()

    collection = client[os.environ.get("MONGO_VS_DB")][f"{name}-vs"]

    # Load the PDF and cut it into small overlapping chunks for retrieval.
    documents = PyPDFLoader(pdf_file).load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=100)
    chunks = splitter.split_documents(documents)

    embedder = OpenAIEmbeddings(
        openai_api_key=os.environ.get("OPENAI_API_KEY"),
        disallowed_special=(),
    )

    # Create embeddings in atlas vector store
    return MongoDBAtlasVectorSearch.from_documents(
        documents=chunks,
        embedding=embedder,
        collection=collection,
        index_name=os.environ.get("MONGO_INDEX_DB"),
    )
123
+
124
def get_vs(name, client=None):
    """Return the Atlas vector store for *name*, building it on first use."""
    if not client:
        client = connect_to_mongo()

    db = client[os.environ.get("MONGO_VS_DB")]
    coll_name = f"{name}-vs"

    if coll_name not in db.list_collection_names():
        print("inserting vs")
        return insert_vs(name, client=client)

    print("using existing vs collection")
    embedder = OpenAIEmbeddings(
        openai_api_key=os.environ.get("OPENAI_API_KEY"),
        disallowed_special=(),
    )
    # Reconnect to the existing "<db>.<name>-vs" namespace instead of
    # re-embedding the documents.
    return MongoDBAtlasVectorSearch.from_connection_string(
        os.environ.get("MONGO_URI"),
        os.environ.get("MONGO_VS_DB") + "." + coll_name,
        embedder,
        index_name=os.environ.get("MONGO_INDEX_DB"),
    )
scripts/rag_utils.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from langchain_openai.chat_models import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
import os, sys
from dotenv import load_dotenv

load_dotenv()

sys.path.append("..")

# Re-export the key only when it is actually set: assigning None into
# os.environ raises TypeError, so the original line crashed on import
# whenever OPENAI_API_KEY was missing.
_openai_key = os.environ.get("OPENAI_API_KEY")
if _openai_key is not None:
    os.environ["OPENAI_API_KEY"] = _openai_key
18
+
19
def prep_config(vs):
    """Build a RAG chain over vector store *vs*.

    Returns a (retrieval_chain, output_parser) pair: the chain retrieves the
    3 most similar chunks, formats them into the prompt, queries GPT-4, and
    parses the reply to a plain string via StrOutputParser.
    """
    retriever = vs.as_retriever(
        search_type="similarity",
        search_kwargs={"k": 3},
    )

    template = """Answer the question: {question} based only on the following context:
    context: {context}
    """

    # from_template infers {context}/{question} from the template itself.
    # The original passed misspelled kwargs ("input_varaibles",
    # "output_variables") that had no effect, and created a JsonOutputParser
    # that was immediately overwritten — both removed.
    prompt = PromptTemplate.from_template(template)

    output_parser = StrOutputParser()

    model = ChatOpenAI(openai_api_key=os.environ.get("OPENAI_API_KEY"),
                       model_name='gpt-4',
                       temperature=0.3)

    def format_docs(docs):
        # Join retrieved chunks into a single context string.
        return "\n\n".join(doc.page_content for doc in docs)

    retrieval_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | model
        | output_parser
    )

    return retrieval_chain, output_parser
55
+
56
+
57
def gen_options(vs, text):
    """Ask the RAG chain for four decision paths at this point in the story.

    Returns the raw model reply (expected to be a list of strings rendered
    as text by the chain's string parser).
    """
    chain, _ = prep_config(vs)

    query = f"""
    Act as the author of a Choose Your Own Adventure Book. This book is special as it is based on existing material.
    Now, as with any choose your own adventure book, you'll have to generate decision paths based on the given story excerpt
    Your job is to generate 4 decision paths for the given point in the story.
    One among the 4 decision paths should be the original path, the other 3 should deviate from the original path in a sensible manner.
    The decision paths should be generated in a way that they are coherent with the existing story.
    Limit each decision path to a succint sentence.
    Return the 4 decision paths as a list of strings.

    Story Excerpt: {text}

    """

    return chain.invoke(query)
77
+
78
def gen_path(vs, text, decision):
    """Generate the next story segment from an excerpt and the chosen decision.

    An empty *decision* asks the model to simply continue the story.
    Returns the parsed string reply.
    """
    chain, parser = prep_config(vs)

    query = f"""
    Act as the author of a Choose Your Own Adventure Book. This book is special as it is based on existing material.
    Now, as with any choose your own adventure book, you'll have to generate new story paths based on a relevant excerpt of the story and the decision taken.
    Your job is to generate the next part of the story based on the given part of the story and the decision taken.
    The new story path should be coherent with the existing story, and should be around 6-8 sentences.
    If the decision string is empty, your task is just to generate the next part of the story based on the given part of the story.
    Return the new story path as a string.

    Story Excerpt: {text}

    Decision: {decision}
    """

    return parser.parse(chain.invoke(query))
98
+
99
def clf_seq(vs, text):
    """Classify whether *text* contains a story-affecting decision.

    Returns the parsed model reply, expected to be the label "yes" or "no".
    """
    chain, parser = prep_config(vs)

    query = f"""
    Classify whether the given chunk involves a decision that will effect the story or not.

    A decision is defined as when the character goes about making a choice between two or more options.
    The decision should be significant enough to affect the story in a major way.
    It doesn't really involve emotions, feelings or thoughts, but what the character does, or what happens to them.
    This involes interactions between characters, or the character and the environment.
    What isn't a decision is chunks describing the setting, or the character's thoughts or feelings.

    Return the answer as the corresponding decision label "yes" or "no"

    {text}
    """

    return parser.parse(chain.invoke(query))
120
+
121
def summ(vs, text):
    """Summarize *text* in 3-4 narrative sentences coherent with the story.

    Returns the parsed string reply from the RAG chain.
    """
    chain, parser = prep_config(vs)

    query = f"""
    Summarize the given text in a narrative manner as a part of storytelling.
    The summary should be around 3-4 sentences and should be coherent with the existing story.

    Return the summary as a string.
    {text}
    """

    return parser.parse(chain.invoke(query))