momegas committed on
Commit
033ca0b
1 Parent(s): f145d3d

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +114 -0
  2. requirements.txt +141 -0
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ast import List
2
+ from langchain.document_loaders import DirectoryLoader
3
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
4
+ import dotenv
5
+ from langchain.prompts import PromptTemplate
6
+ import gradio as gr
7
+ from langchain import PromptTemplate, LLMChain
8
+ import requests
9
+ from fastembed.embedding import FlagEmbedding as Embedding
10
+ import numpy as np
11
+ import os
12
+
13
+
14
# Load variables from a local .env file into the process environment.
dotenv.load_dotenv()

# Hosted inference endpoint that generates the answers.
API_URL = "https://vpb8x4glbmizmiya.eu-west-1.aws.endpoints.huggingface.cloud"

# Bearer token for the endpoint; this is None when API_TOKEN is not set.
api_token = os.getenv("API_TOKEN")

# HTTP headers sent with every request to API_URL.
headers = {
    "Authorization": f"Bearer {api_token}",
    "Content-Type": "application/json",
}
22
+
23
+
24
def query(payload, timeout=120):
    """POST *payload* as JSON to the inference endpoint and return the decoded reply.

    Args:
        payload: JSON-serialisable request body sent to ``API_URL``.
        timeout: Seconds to wait for the endpoint before giving up. Without
            a timeout, ``requests`` waits indefinitely and a stalled endpoint
            would hang the chat handler.

    Returns:
        The decoded JSON body of the response.

    Raises:
        requests.exceptions.RequestException: On connection failure, timeout,
            or a response body that is not valid JSON.
    """
    response = requests.post(
        API_URL,
        headers=headers,
        json=payload,
        timeout=timeout,
    )
    return response.json()
27
+
28
+
29
def get_top_k(query_embedding, embeddings, documents, k=3):
    """Return the (at most) *k* documents whose embeddings best match the query.

    Args:
        query_embedding: 1-D vector for the query.
        embeddings: Sequence of document vectors, one per entry in *documents*.
        documents: Items to rank; ``documents[i]`` corresponds to
            ``embeddings[i]``.
        k: Maximum number of documents to return.

    Returns:
        A list of documents ordered from highest to lowest score. It is
        shorter than *k* when fewer documents exist and empty when there is
        nothing to rank.
    """
    # Nothing to rank: np.dot would fail on an empty matrix, so bail out early.
    if k <= 0 or len(embeddings) == 0 or len(documents) == 0:
        return []

    # Dot-product score per document. NOTE(review): this equals cosine
    # similarity only if the vectors are unit-normalised - confirm for the
    # embedding model in use.
    scores = np.dot(np.asarray(embeddings), np.asarray(query_embedding))

    # Indices by descending score; slicing clamps k to the number of documents
    # instead of indexing past the end.
    ranked = np.argsort(scores)[::-1][:k]

    result = []
    for rank, index in enumerate(ranked, start=1):
        # Debug trace of the retrieved documents.
        print(f"Rank {rank}: {documents[index]}", "\n")
        result.append(documents[index])

    return result
41
+
42
+
43
# Prompt sent to the model. {context} receives the retrieved document text and
# {question} the user's message. A stray, unbalanced double quote that used to
# precede "Use the following pieces..." was removed so it no longer leaks into
# the prompt.
prompt_template = """
You are the helpful assistant representing the company Philip Morris.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Use the following pieces of context to answer the question at the end.

Context:
{context}

Question: {question}
Answer:
"""
54
+
55
+
56
# Wrap the raw template so it can be filled with the retrieved context and the
# user's question.
PROMPT = PromptTemplate(
    input_variables=["context", "question"],
    template=prompt_template,
)

# Load every .txt file matched by the recursive glob under ./documents.
loader = DirectoryLoader("./documents", glob="**/*.txt", show_progress=True)
docs = loader.load()

# Split the loaded documents into 500-character chunks with 150 characters of
# overlap between neighbouring chunks.
text_splitter = RecursiveCharacterTextSplitter(chunk_overlap=150, chunk_size=500)
texts = text_splitter.split_documents(docs)

# Embed every chunk once at start-up. The result is materialised into a list
# so it can be reused for each incoming query.
embedding_model = Embedding(model_name="BAAI/bge-base-en", max_length=512)
embeddings = list(embedding_model.embed([chunk.page_content for chunk in texts]))
67
+
68
# Chat UI: a chatbot pane, a text box for the question, and a button that
# clears both.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.ClearButton([msg, chatbot])

    def respond(message, chat_history):
        """Answer *message* from the indexed documents and record the turn.

        Embeds the message, retrieves the three best-matching chunks, asks
        the inference endpoint to answer using them as context, and appends
        the (message, reply) pair to *chat_history*.

        Args:
            message: The user's question from the text box.
            chat_history: List of (user, bot) pairs shown in the chatbot;
                mutated in place.

        Returns:
            A tuple ``("", chat_history)``: the empty string clears the text
            box and the history refreshes the chatbot.
        """
        message_embedding = list(embedding_model.embed([message]))[0]
        result_docs = get_top_k(message_embedding, embeddings, texts, k=3)
        sources = [doc.page_content for doc in result_docs]

        # Put the chunk text itself into the prompt rather than the repr of
        # the document objects. A plain string is all the endpoint needs, so
        # no message wrapper class is involved.
        prompt = PROMPT.format(context="\n\n".join(sources), question=message)

        print("Question: ", prompt)
        try:
            output = query(
                {
                    "inputs": prompt,
                    "parameters": {
                        "temperature": 0.9,
                        "top_p": 0.95,
                        "repetition_penalty": 1.2,
                        "top_k": 50,
                        "truncate": 1000,
                        "max_new_tokens": 1024,
                    },
                }
            )
        except requests.RequestException as exc:
            # Show network/decoding failures in the chat instead of crashing
            # the handler.
            output = {"error": str(exc)}
        print("Response: ", output, "\n")

        # A successful reply is read as a list of result dicts and a failure
        # as a bare dict with an "error" key; normalise both to one dict.
        first = output[0] if isinstance(output, list) and output else output
        if not isinstance(first, dict):
            first = {}

        generated_text = first.get("generated_text")
        if generated_text:
            bot_message = f"{generated_text}\n\nSources:\n{sources}\n"
        else:
            error = first.get("error", "unexpected response from the model endpoint")
            bot_message = f"There was an error: {error}"

        chat_history.append((message, bot_message))
        return "", chat_history

    # Submitting the text box runs respond() and updates both components.
    msg.submit(respond, [msg, chatbot], [msg, chatbot])


if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ aiohttp==3.8.5
3
+ aiosignal==1.3.1
4
+ altair==5.1.1
5
+ annotated-types==0.5.0
6
+ antlr4-python3-runtime==4.9.3
7
+ anyio==3.7.1
8
+ async-timeout==4.0.3
9
+ attrs==23.1.0
10
+ backoff==2.2.1
11
+ bcrypt==4.0.1
12
+ beautifulsoup4==4.12.2
13
+ certifi==2023.7.22
14
+ cffi==1.16.0
15
+ chardet==5.2.0
16
+ charset-normalizer==3.3.0
17
+ chroma-hnswlib==0.7.3
18
+ chromadb==0.4.13
19
+ click==8.1.7
20
+ coloredlogs==15.0.1
21
+ contourpy==1.1.1
22
+ cryptography==41.0.4
23
+ ctransformers==0.2.27
24
+ cycler==0.12.0
25
+ dataclasses-json==0.6.1
26
+ effdet==0.4.1
27
+ emoji==2.8.0
28
+ fastapi==0.103.2
29
+ fastembed==0.1.1
30
+ ffmpy==0.3.1
31
+ filelock==3.12.4
32
+ filetype==1.2.0
33
+ flatbuffers==23.5.26
34
+ fonttools==4.43.0
35
+ frozenlist==1.4.0
36
+ fsspec==2023.9.2
37
+ gradio==3.45.2
38
+ gradio_client==0.5.3
39
+ h11==0.14.0
40
+ httpcore==0.18.0
41
+ httptools==0.6.0
42
+ httpx==0.25.0
43
+ huggingface-hub==0.16.4
44
+ humanfriendly==10.0
45
+ idna==3.4
46
+ importlib-resources==6.1.0
47
+ iopath==0.1.10
48
+ Jinja2==3.1.2
49
+ joblib==1.3.2
50
+ jsonpatch==1.33
51
+ jsonpointer==2.4
52
+ jsonschema==4.19.1
53
+ jsonschema-specifications==2023.7.1
54
+ kiwisolver==1.4.5
55
+ langchain==0.0.305
56
+ langdetect==1.0.9
57
+ langsmith==0.0.41
58
+ layoutparser==0.3.4
59
+ lxml==4.9.3
60
+ MarkupSafe==2.1.3
61
+ marshmallow==3.20.1
62
+ matplotlib==3.8.0
63
+ monotonic==1.6
64
+ mpmath==1.3.0
65
+ multidict==6.0.4
66
+ mypy-extensions==1.0.0
67
+ networkx==3.2
68
+ nltk==3.8.1
69
+ numexpr==2.8.7
70
+ numpy==1.26.0
71
+ omegaconf==2.3.0
72
+ onnx==1.14.1
73
+ onnxruntime==1.16.0
74
+ openai==0.28.1
75
+ opencv-python==4.8.1.78
76
+ orjson==3.9.7
77
+ overrides==7.4.0
78
+ packaging==23.2
79
+ pandas==2.1.1
80
+ pdf2image==1.16.3
81
+ pdfminer.six==20221105
82
+ pdfplumber==0.10.2
83
+ Pillow==10.0.1
84
+ portalocker==2.8.2
85
+ posthog==3.0.2
86
+ protobuf==4.24.3
87
+ pulsar-client==3.3.0
88
+ py-cpuinfo==9.0.0
89
+ pycocotools==2.0.7
90
+ pycparser==2.21
91
+ pydantic==2.4.2
92
+ pydantic_core==2.10.1
93
+ pydub==0.25.1
94
+ pyparsing==3.1.1
95
+ pypdfium2==4.21.0
96
+ PyPika==0.48.9
97
+ pytesseract==0.3.10
98
+ python-dateutil==2.8.2
99
+ python-dotenv==1.0.0
100
+ python-iso639==2023.6.15
101
+ python-magic==0.4.27
102
+ python-multipart==0.0.6
103
+ pytz==2023.3.post1
104
+ PyYAML==6.0.1
105
+ rapidfuzz==3.4.0
106
+ referencing==0.30.2
107
+ regex==2023.8.8
108
+ requests==2.31.0
109
+ rpds-py==0.10.3
110
+ safetensors==0.4.0
111
+ scipy==1.11.3
112
+ semantic-version==2.10.0
113
+ six==1.16.0
114
+ sniffio==1.3.0
115
+ soupsieve==2.5
116
+ SQLAlchemy==2.0.21
117
+ starlette==0.27.0
118
+ sympy==1.12
119
+ tabulate==0.9.0
120
+ tenacity==8.2.3
121
+ tiktoken==0.5.1
122
+ timm==0.9.7
123
+ tokenizers==0.14.1
124
+ toolz==0.12.0
125
+ torch==2.1.0
126
+ torchvision==0.16.0
127
+ tqdm==4.66.1
128
+ transformers==4.34.1
129
+ typer==0.9.0
130
+ typing-inspect==0.9.0
131
+ typing_extensions==4.8.0
132
+ tzdata==2023.3
133
+ unstructured==0.10.18
134
+ unstructured-inference==0.5.31
135
+ unstructured.pytesseract==0.3.12
136
+ urllib3==2.0.5
137
+ uvicorn==0.23.2
138
+ uvloop==0.17.0
139
+ watchfiles==0.20.0
140
+ websockets==11.0.3
141
+ yarl==1.9.2