Ashmi Banerjee commited on
Commit
4b722ec
·
1 Parent(s): 2fdacb0

first draft, with the gcp bucket

Browse files
.gitignore CHANGED
@@ -182,4 +182,8 @@ gradio_cached_examples/
182
  models/__pycache__/
183
  setup/__pycache__/
184
  setup/gcp_default_creds.json
185
- .config/*.json
 
 
 
 
 
182
  models/__pycache__/
183
  setup/__pycache__/
184
  setup/gcp_default_creds.json
185
+ .config/*.json
186
+ gradio_cached_examples/
187
+ src/gradio_cached_examples/
188
+ database/
189
+ european-city-data/
requirements.txt CHANGED
@@ -1,10 +1,12 @@
1
  sentence-transformers
2
  gradio
3
  gradio_client
4
- pymongo==4.6.2
5
  python-dotenv
6
  google-cloud-aiplatform
7
  google-cloud
8
- vertexai==1.43.0
9
  huggingface_hub
10
- certifi==2021.5.30
 
 
 
 
1
  sentence-transformers
2
  gradio
3
  gradio_client
 
4
  python-dotenv
5
  google-cloud-aiplatform
6
  google-cloud
7
+ vertexai==1.64.0
8
  huggingface_hub
9
+ certifi==2024.7.4
10
+ python-dotenv==1.0.1
11
+ bitsandbytes
12
+ accelerate
src/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.vectordb.vectordb import *
2
+ from src.vectordb.helpers import *
3
+ from src.vectordb.lancedb_init import *
4
+
5
+ from src.sustainability.s_fairness import *
6
+ from src.information_retrieval.info_retrieval import *
7
+ from src.augmentation.prompt_generation import *
8
+ from src.text_generation.model_init import *
9
+ from src.text_generation.text_generation import *
10
+
11
+ from src.data_directories import *
src/app.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ import gradio as gr
3
+ import os, sys
4
+ sys.path.append("./src")
5
+ print(os.getcwd())
6
+ from src.pipeline import pipeline
7
+
8
+
9
+ def clear():
10
+ return None, None, None
11
+
12
+
13
+ def generate_text(query_text, model_name: Optional[str], is_sustainable: Optional[bool], tokens: Optional[int] = 1024,
14
+ temp: Optional[float] = 0.49):
15
+ if is_sustainable:
16
+ sustainability = 1
17
+ else:
18
+ sustainability = 0
19
+ pipeline_response = pipeline(
20
+ query=query_text,
21
+ model_name=model_name,
22
+ sustainability= sustainability
23
+ )
24
+ return pipeline_response
25
+
26
+
27
+ examples = [["I'm planning a vacation to France. Can you suggest a one-week itinerary including must-visit places and "
28
+ "local cuisines to try?", "GPT-4"],
29
+ ["I want to explore off-the-beaten-path destinations in Europe, any suggestions?", "Gemini-1.0-pro"],
30
+ ["Suggest some cities that can be visited from London and are very rich in history and culture.",
31
+ "Gemini-1.0-pro"],
32
+ ]
33
+
34
+ with gr.Blocks() as demo:
35
+ gr.HTML("""<center><h1 style='font-size:xx-large;'>🇪🇺 Euro City Recommender using Gemini & Gemma 🇪🇺</h1><br><h3>Gemini
36
+ & Gemma Sprints 2024 submissions by Ashmi Banerjee. </h3></center> <br><p>We're testing the compatibility of
37
+ Retrieval Augmented Generation (RAG) implementations with Google's <b>Gemma-2b-it</b> & <b>Gemini 1.0 Pro</b>
38
+ models through HuggingFace and VertexAI, respectively, to generate travel recommendations. This early version (read
39
+ quick and dirty implementation) aims to see if functionalities work smoothly. It relies on Wikipedia abstracts
40
+ from 160 European cities to provide answers to your questions. Please be kind with it, as it's a work in progress!
41
+ </p> <br>Google Cloud credits are provided for this project. </p>
42
+ """)
43
+
44
+ with gr.Group():
45
+ query = gr.Textbox(label="Query", placeholder="Ask for your city recommendation here!")
46
+ sustainable = gr.Checkbox(label="Sustainable", info="If you want sustainable recommendations for "
47
+ "hidden gems?")
48
+ model = gr.Dropdown(
49
+ ["GPT-4", "Gemini-1.0-pro"], label="Model", info="Select your model. Will add more "
50
+ "models "
51
+ "later!",
52
+ )
53
+ output = gr.Textbox(label="Generated Results", lines=4)
54
+
55
+ with gr.Accordion("Settings", open=False):
56
+ max_new_tokens = gr.Slider(label="Max new tokens", value=1024, minimum=0, maximum=8192, step=64,
57
+ interactive=True,
58
+ visible=True, info="The maximum number of output tokens")
59
+ temperature = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.49,
60
+ interactive=True,
61
+ visible=True, info="The value used to module the logits distribution")
62
+ with gr.Group():
63
+ with gr.Row():
64
+ submit_btn = gr.Button("Submit", variant="primary")
65
+ clear_btn = gr.Button("Clear", variant="secondary")
66
+ cancel_btn = gr.Button("Cancel", variant="stop")
67
+ submit_btn.click(generate_text, inputs=[query, model, sustainable], outputs=[output])
68
+ clear_btn.click(clear, inputs=[], outputs=[query, model, output])
69
+ cancel_btn.click(clear, inputs=[], outputs=[query, model, output])
70
+
71
+ gr.Markdown("## Examples")
72
+ gr.Examples(
73
+ examples, inputs=[query, model], label="Examples", fn=generate_text, outputs=[output],
74
+ cache_examples=True,
75
+ )
76
+
77
+ if __name__ == "__main__":
78
+ demo.launch(show_api=False)
src/augmentation/__init__.py ADDED
File without changes
src/augmentation/prompt_generation.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from information_retrieval import info_retrieval as ir
2
+ import logging
3
+
4
+ logger = logging.getLogger(__name__)
5
+ logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
6
+
7
+
8
+ def generate_prompt(query, context, template=None):
9
+ """
10
+ Function that generates the prompt given the user query and retrieved context. A specific prompt template will be
11
+ used if provided, otherwise the default base_prompt template is used.
12
+
13
+ Args:
14
+ - query: str
15
+ - context: list[dict]
16
+ - template: str
17
+
18
+ """
19
+
20
+ if template:
21
+ SYS_PROMPT = template
22
+ else:
23
+ SYS_PROMPT = """You are an AI recommendation system. Your task is to recommend cities in Europe for travel
24
+ based on the user's question. You should use the provided contexts to suggest a list of the 3 best cities
25
+ that are best suited to the user's question, as well as the month of travel. If the user has already provided
26
+ the month of travel in the question, use the same month; otherwise, provide the ideal month of travel. Each
27
+ recommendation should also contain an explanation of why it is being recommended, based on the context. Your
28
+ answer must begin with "I recommend " followed by the city name and why you recommended it. Your answers are
29
+ correct, high-quality, and written by a domain expert. If the provided context does not contain the answer,
30
+ simply state, "The provided context does not have the answer." """
31
+
32
+ USER_PROMPT = """ Question: {} Which city do you recommend and why?
33
+
34
+ Context: Here are the options: {}
35
+
36
+ Answer:
37
+
38
+ """
39
+
40
+ formatted_prompt = f"{USER_PROMPT.format(query, context)}"
41
+ messages = [
42
+ {"role": "system", "content": SYS_PROMPT},
43
+ {"role": "user", "content": formatted_prompt}
44
+ ]
45
+
46
+ return messages
47
+
48
+
49
+ def format_context(context):
50
+ """
51
+ Function that formats the retrieved context in a way that is easy for the LLM to understand.
52
+
53
+ Args:
54
+ - context: list[dict]; retrieved context
55
+
56
+ """
57
+
58
+ formatted_context = ''
59
+
60
+ for i, (city, info) in enumerate(context.items()):
61
+
62
+ text = f"Option {i + 1}: {city} is a city in {info['country']}."
63
+ info_text = f"Here is some information about the city. {info['text']}"
64
+
65
+ attractions_text = "Here are some attractions: "
66
+ att_flag = 0
67
+ restaurants_text = "Here are some places to eat/drink: "
68
+ rest_flag = 0
69
+
70
+ hotels_text = "Here are some hotels: "
71
+ hotel_flag = 0
72
+
73
+ if len(info['listings']):
74
+ for listing in info['listings']:
75
+ if listing['type'] in ['see', 'do', 'go', 'view']:
76
+ att_flag = 1
77
+ attractions_text += f"{listing['name']} ({listing['description']}), "
78
+ elif listing['type'] in ['eat', 'drink']:
79
+ rest_flag = 1
80
+ restaurants_text += f"{listing['name']} ({listing['description']}), "
81
+ else:
82
+ hotel_flag = 1
83
+ hotels_text += f"{listing['name']} ({listing['description']}), "
84
+
85
+ # If we add sustainability in the end then it could get truncated because of context window
86
+ if "sustainability" in info:
87
+ if info['sustainability']['month'] == 'No data available':
88
+ sfairness_text = "This city has no sustainability (or s-fairness) score available."
89
+
90
+ else:
91
+ sfairness_text = f"The sustainability (or s-fairness) score for {city} in {info['sustainability']['month']} is {info['sustainability']['s-fairness']}. \n "
92
+
93
+ text += sfairness_text
94
+
95
+ text += info_text
96
+
97
+ if att_flag:
98
+ text += f"\n{attractions_text}"
99
+
100
+ if rest_flag:
101
+ text += f"\n{restaurants_text}"
102
+
103
+ if hotel_flag:
104
+ text += f"\n{hotels_text}"
105
+
106
+ formatted_context += text + "\n\n "
107
+
108
+ return formatted_context
109
+
110
+
111
+ def augment_prompt(query, context, **params):
112
+ """
113
+ Function that accepts the user query as input, obtains relevant documents and augments the prompt with the
114
+ retrieved context, which can be passed to the LLM.
115
+
116
+ Args: - query: str - context: retrieved context, must be formatted otherwise the LLM cannot understand the nested
117
+ dictionaries! - sustainability: bool; if true, then the prompt is appended to instruct the LLM to use s-fairness
118
+ scores while generating the answer - params: key-value parameters to be passed to the get_context function; sets
119
+ the limit of results and whether to rerank the results
120
+
121
+ """
122
+
123
+ # what about the cities without s-fairness scores? i.e. they don't have seasonality data
124
+
125
+ prompt_with_sustainability = """You are an AI recommendation system. Your task is to recommend cities in Europe
126
+ for travel based on the user's question. You should use the provided contexts to suggest the city that is best
127
+ suited to the user's question. You recommend a list of the top 3 most sustainable cities to the user, as well as
128
+ the best month of travel. Each recommendation should also contain an explanation of why it is being recommended,
129
+ on sustainability grounds based on the context. The context contains a sustainability score for each city,
130
+ also known as the s-fairness score, along with the ideal month of travel. A lower s-fairness score indicates that
131
+ the city is a better destination for the month provided. A city without a sustainability score should not be
132
+ considered. You should only consider the s-fairness score while choosing the best city. However, your answer
133
+ should not contain the numeric score itself or any mention of the sustainability score. Your answer must begin
134
+ with "I recommend " followed by the city names and why you recommended it. Your answers are correct, high-quality,
135
+ and written by a domain expert. If the provided context does not contain the answer, simply state, "The provided
136
+ context does not have the answer. """
137
+
138
+ # format context
139
+ formatted_context = format_context(context)
140
+
141
+ if "sustainability" in params["params"] and params["params"]["sustainability"]:
142
+ prompt = generate_prompt(query, formatted_context, prompt_with_sustainability)
143
+ else:
144
+ prompt = generate_prompt(query, formatted_context)
145
+
146
+ return prompt
147
+
148
+
149
+ def test():
150
+ context_params = {
151
+ 'limit': 3,
152
+ 'reranking': 0,
153
+ 'sustainability': 0
154
+ }
155
+
156
+ query = "Suggest some places to visit during winter. I like hiking, nature and the mountains and I enjoy skiing " \
157
+ "in winter. "
158
+
159
+ # without sustainability
160
+ context = ir.get_context(query, **context_params)
161
+ # formatted_context = format_context(context)
162
+
163
+ without_sfairness = augment_prompt(
164
+ query=query,
165
+ context=context,
166
+ params=context_params
167
+ )
168
+
169
+ # with sustainability
170
+ context_params.update({'sustainability': 1})
171
+ s_context = ir.get_context(query, **context_params)
172
+ # s_formatted_context = format_context(s_context)
173
+
174
+ with_sfairness = augment_prompt(
175
+ query=query,
176
+ context=s_context,
177
+ params=context_params
178
+ )
179
+
180
+ return with_sfairness
181
+
182
+
183
+ if __name__ == "__main__":
184
+ prompt = test()
185
+ print(prompt)
src/data_directories.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ current = os.path.dirname(os.path.realpath(''))
4
+ parent = os.path.dirname(current)
5
+
6
+ data_parent_dir = "../../european-city-data/"
7
+ data_dir = data_parent_dir + "data-sources/"
8
+ wikivoyage_docs_dir = data_dir + "wikivoyage/"
9
+ wikivoyage_listings_dir = wikivoyage_docs_dir + "listings/"
10
+ database_dir = "../../database/wikivoyage/"
11
+ seasonality_dir = data_parent_dir + "computed/seasonality/"
12
+ popularity_dir = data_parent_dir + "computed/popularity/"
13
+ cities_csv = data_parent_dir + "city_abstracts_embeddings.csv"
14
+ prompts_dir = data_parent_dir + "rag-sustainability/prompts/"
15
+ results_dir = data_parent_dir + "rag-sustainability/results/"
src/information_retrieval/__init__.py ADDED
File without changes
src/information_retrieval/info_retrieval.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import re
3
+ import os
4
+ import json
5
+ sys.path.append("../")
6
+ from src.vectordb import vectordb
7
+ from src.sustainability import s_fairness
8
+ import logging
9
+
10
+ logger = logging.getLogger(__name__)
11
+ logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
12
+
13
+
14
+ def get_travel_months(query):
15
+ """
16
+
17
+ Function to parse the user's query and search if month of travel has been provided by the user.
18
+
19
+ Args:
20
+ - query: str
21
+
22
+ """
23
+ months = [
24
+ "January", "February", "March", "April", "May", "June",
25
+ "July", "August", "September", "October", "November", "December"
26
+ ]
27
+
28
+ seasons = {
29
+ "spring": ["March", "April", "May"],
30
+ "summer": ["June", "July", "August"],
31
+ "fall": ["September", "October", "November"],
32
+ "autumn": ["September", "October", "November"],
33
+ "winter": ["December", "January", "February"]
34
+ }
35
+
36
+ months_in_query = []
37
+
38
+ for month in months:
39
+ if re.search(r'\b' + month + r'\b', query, re.IGNORECASE):
40
+ months_in_query.append(month)
41
+
42
+ # Check for seasons in the query
43
+ for season, season_months in seasons.items():
44
+ if re.search(r'\b' + season + r'\b', query, re.IGNORECASE):
45
+ months_in_query += season_months
46
+
47
+ # Return None if neither months nor seasons are found
48
+ return months_in_query
49
+
50
+
51
+ def get_wikivoyage_context(query, limit=10, reranking=0):
52
+ """
53
+
54
+ Function to retrieve the relevant documents and listings from the wikivoyage database. Works in two steps:
55
+ (i) the relevant cities are returned by the wikivoyage_docs table and (ii) then passed on to the wikivoyage listings database to retrieve further information.
56
+ The user can pass a limit of how many results the search should return as well as whether to perform reranking (uses a CrossEncoderReranker)
57
+
58
+ Args:
59
+ - query: str
60
+ - limit: int
61
+ - reranking: bool
62
+
63
+ """
64
+
65
+ # limit = params['limit']
66
+ # reranking = params['reranking']
67
+
68
+ docs = vectordb.search_wikivoyage_docs(query, limit, reranking)
69
+ logger.info("Finished getting chunked wikivoyage docs.")
70
+
71
+ results = {}
72
+ for doc in docs:
73
+ results[doc['city']] = {key: value for key, value in doc.items() if key != 'city'}
74
+ results[doc['city']]['listings'] = []
75
+
76
+ cities = [result['city'] for result in docs]
77
+
78
+ listings = vectordb.search_wikivoyage_listings(query, cities, limit, reranking)
79
+ logger.info("Finished getting wikivoyage listings.")
80
+ # logger.info(type(docs), type(listings))
81
+
82
+ for listing in listings:
83
+ # logger.info(listing['city'])
84
+ results[listing['city']]['listings'].append({
85
+ 'type': listing['type'],
86
+ 'name': listing['title'],
87
+ 'description': listing['description']
88
+ })
89
+
90
+ logger.info("Returning retrieval results.")
91
+ return results
92
+
93
+
94
+ def get_sustainability_scores(query, destinations):
95
+ """
96
+
97
+ Function to get the s-fairness scores for each destination for the given month (or the ideal month of travel if the user hasn't provided a month).
98
+ If multiple months are provided (or season), then the month with the minimum s-fairness score is chosen for the city.
99
+
100
+ Args:
101
+ - query: str
102
+ - destinations: list
103
+
104
+ """
105
+
106
+ result = [] # list of dicts of the format {city: <city>, month: <month>, }
107
+ city_scores = {}
108
+
109
+ months = get_travel_months(query)
110
+ logger.info("Finished parsing query for months.")
111
+
112
+ for city in destinations:
113
+ if city not in city_scores:
114
+ city_scores[city] = []
115
+
116
+ if not months: # no month(s) or seasons provided by the user
117
+ city_scores[city].append(s_fairness.compute_sfairness_score(city))
118
+ else:
119
+ for month in months:
120
+ city_scores[city].append(s_fairness.compute_sfairness_score(city, month))
121
+
122
+ logger.info("Finished getting s-fairness scores.")
123
+
124
+ for city, scores in city_scores.items():
125
+
126
+ no_result = 0
127
+ for score in scores:
128
+ if not score['month']:
129
+ no_result = 1
130
+ result.append({
131
+ 'city': city,
132
+ 'month': 'No data available',
133
+ 's-fairness': 'No data available'
134
+ })
135
+ break
136
+
137
+ if not no_result:
138
+ min_score = min(scores, key=lambda x: x['s-fairness'])
139
+ result.append({
140
+ 'city': city,
141
+ 'month': min_score['month'],
142
+ 's-fairness': min_score['s-fairness']
143
+ })
144
+
145
+ logger.info("Returning s-fairness results.")
146
+ return result
147
+
148
+
149
+ def get_cities(context):
150
+ """
151
+ Only to be used for testing! Function that returns a list of cities with their s-fairness scores, provided the retrieved context
152
+
153
+ Args:
154
+ - context: dict
155
+
156
+ """
157
+
158
+ recommended_cities = []
159
+
160
+ for city, info in context.items():
161
+ city_info = {
162
+ 'city': city,
163
+ 'country': info['country']
164
+ }
165
+
166
+ if "sustainability" in info:
167
+ city_info['month'] = info['sustainability']['month']
168
+ city_info['s-fairness'] = info['sustainability']['s-fairness']
169
+
170
+ recommended_cities.append(city_info)
171
+
172
+ if "sustainability" in info:
173
+ def get_s_fairness_value(item):
174
+ s_fairness = item['s-fairness']
175
+ if s_fairness == 'No data available':
176
+ return float('inf') # Assign a high value for "No data available"
177
+ return s_fairness
178
+
179
+ # Sort the list using the custom key
180
+ sorted_cities = sorted(recommended_cities, key=get_s_fairness_value)
181
+ return sorted_cities
182
+
183
+ else:
184
+ return recommended_cities
185
+
186
+
187
+ def get_context(query, **params):
188
+ """
189
+
190
+ Function that returns all the context: from the database, as well as the respective s-fairness scores for the
191
+ destinations. The default does not consider S-Fairness scores, i.e. to append sustainability scores, a non-zero
192
+ parameter "sustainability" needs to be explicitly passed to params.
193
+
194
+ Args:
195
+ - query: str
196
+ - params: dict; contains value of the limit and reranking (and sustainability)
197
+
198
+ """
199
+
200
+ limit = 3
201
+ reranking = 1
202
+
203
+ if 'limit' in params:
204
+ limit = params['limit']
205
+
206
+ if 'reranking' in params:
207
+ reranking = params['reranking']
208
+
209
+ wikivoyage_context = get_wikivoyage_context(query, limit, reranking)
210
+ recommended_cities = wikivoyage_context.keys()
211
+
212
+ if 'sustainability' in params and params['sustainability']:
213
+ s_fairness_scores = get_sustainability_scores(query, recommended_cities)
214
+
215
+ for score in s_fairness_scores:
216
+ wikivoyage_context[score['city']]['sustainability'] = {
217
+ 'month': score['month'],
218
+ 's-fairness': score['s-fairness']
219
+ }
220
+
221
+ return wikivoyage_context
222
+
223
+
224
+ def test():
225
+ queries = []
226
+ query = "Suggest some places to visit during winter. I like hiking, nature and the mountains and I enjoy skiing " \
227
+ "in winter. "
228
+
229
+ context = None
230
+
231
+ try:
232
+ context = get_context(query, sustainability=1)
233
+ # cities = get_cities(context)
234
+ # print(cities)
235
+ except FileNotFoundError as e:
236
+ try:
237
+ vectordb.create_wikivoyage_docs_db_and_add_data()
238
+ vectordb.create_wikivoyage_listings_db_and_add_data()
239
+
240
+ try:
241
+ context = get_context(query, sustainability=1)
242
+ # cities = get_cities(context)
243
+ # print(cities)
244
+ except Exception as e:
245
+ exc_type, exc_obj, exc_tb = sys.exc_info()
246
+ fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
247
+ logger.error(f"Error while getting context: {e}, {(exc_type, fname, exc_tb.tb_lineno)}")
248
+
249
+ except Exception as e:
250
+ logger.error(f"Error while creating DB: {e}")
251
+
252
+ except Exception as e:
253
+ exc_type, exc_obj, exc_tb = sys.exc_info()
254
+ fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
255
+ logger.error(f"Error while getting context: {e}, {(exc_type, fname, exc_tb.tb_lineno)}")
256
+
257
+ file_path = os.path.join(os.getcwd(), "test_results", "test_result.json")
258
+ with open(file_path, 'w') as file:
259
+ json.dump(context, file)
260
+
261
+ return context
262
+
263
+
264
+ if __name__ == "__main__":
265
+ context = test()
266
+
267
+ print(context)
src/pipeline.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Main file to execute the TRS Pipeline.
3
+ """
4
+ import sys
5
+ from augmentation import prompt_generation as pg
6
+ from information_retrieval import info_retrieval as ir
7
+ from text_generation.models import (
8
+ Llama3,
9
+ Mistral,
10
+ Gemma2,
11
+ Llama3Point1,
12
+ Llama3Instruct,
13
+ MistralInstruct,
14
+ Llama3Point1Instruct,
15
+ Phi3SmallInstruct,
16
+ GPT4,
17
+ Gemini,
18
+ )
19
+ from text_generation import text_generation as tg
20
+ import logging
21
+
22
+ logger = logging.getLogger(__name__)
23
+ logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
24
+
25
+ TEST_DIR = "../tests/"
26
+ MODELS = {
27
+ 'GPT-4': GPT4,
28
+ 'Llama3': Llama3,
29
+ 'Mistral': Mistral,
30
+ 'Gemma2': Gemma2,
31
+ 'Llama3.1': Llama3Point1,
32
+ 'Llama3-Instruct': Llama3Instruct,
33
+ 'Mistral-Instruct': MistralInstruct,
34
+ 'Llama3.1-Instruct': Llama3Point1Instruct,
35
+ 'Phi3-Instruct': Phi3SmallInstruct,
36
+ 'Gemini-1.0-pro': Gemini,
37
+ }
38
+
39
+
40
+ def pipeline(query: str, model_name: str, test: int = 0, **params):
41
+ """
42
+
43
+ Executes the entire RAG pipeline, provided the query and model class name.
44
+
45
+ Args:
46
+ - query: str
47
+ - model_name: string, one of the following: Llama3, Mistral, Gemma2, Llama3Point1
48
+ - test: whether the pipeline is running a test
49
+ - params:
50
+ - limit (number of results to be retained)
51
+ - reranking (binary, whether to rerank results using ColBERT or not)
52
+ - sustainability
53
+
54
+
55
+ """
56
+
57
+ model = MODELS[model_name]
58
+
59
+ context_params = {
60
+ 'limit': 5,
61
+ 'reranking': 0,
62
+ 'sustainability': 0,
63
+ }
64
+
65
+ if 'limit' in params:
66
+ context_params['limit'] = params['limit']
67
+
68
+ if 'reranking' in params:
69
+ context_params['reranking'] = params['reranking']
70
+
71
+ if 'sustainability' in params:
72
+ context_params['sustainability'] = params['sustainability']
73
+
74
+ logger.info("Retrieving context..")
75
+ try:
76
+ context = ir.get_context(query=query, **context_params)
77
+ if test:
78
+ retrieved_cities = ir.get_cities(context)
79
+ else:
80
+ retrieved_cities = None
81
+ except Exception as e:
82
+ exc_type, exc_obj, exc_tb = sys.exc_info()
83
+ logger.error(f"Error at line {exc_tb.tb_lineno} while trying to get context: {e}")
84
+ return None
85
+
86
+ logger.info("Retrieved context, augmenting prompt..")
87
+ try:
88
+ prompt = pg.augment_prompt(
89
+ query=query,
90
+ context=context,
91
+ params=context_params
92
+ )
93
+ except Exception as e:
94
+ exc_type, exc_obj, exc_tb = sys.exc_info()
95
+ logger.error(f"Error at line {exc_tb.tb_lineno} while trying to augment prompt: {e}")
96
+ return None
97
+
98
+ # return prompt
99
+
100
+ logger.info(f"Augmented prompt, initializing {model} and generating response..")
101
+ try:
102
+ response = tg.generate_response(model, prompt)
103
+ except Exception as e:
104
+ exc_type, exc_obj, exc_tb = sys.exc_info()
105
+ logger.info(f"Error at line {exc_tb.tb_lineno} while generating response: {e}")
106
+ return None
107
+
108
+ if test:
109
+ return retrieved_cities, prompt[1]['content'], response
110
+
111
+ else:
112
+ return response
113
+
114
+
115
+ if __name__ == "__main__":
116
+ # sample_query = "I'm planning a trip in the summer and I love art, history, and visiting museums. Can you
117
+ # suggest " \ "some " \ "European cities? "
118
+ sample_query = "I'm planning a trip in July and enjoy beaches, nightlife, and vibrant cities. Recommend some " \
119
+ "cities. "
120
+ model = "GPT-4"
121
+
122
+ pipeline_response = pipeline(
123
+ query=sample_query,
124
+ model_name=model,
125
+ sustainability=1
126
+ )
127
+
128
+ print(pipeline_response)
src/sustainability/__init__.py ADDED
File without changes
src/sustainability/s_fairness.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import pandas as pd
4
+ import numpy as np
5
+ import logging
6
+
7
+ logger = logging.getLogger(__name__)
8
+ logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
9
+
10
+ SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
11
+ sys.path.append(os.path.dirname(SCRIPT_DIR))
12
+
13
+ from data_directories import *
14
+
15
+
16
+ def get_popularity(destination):
17
+ """
18
+
19
+ Returns the popularity score for a particular destination.
20
+
21
+ Args:
22
+ - destination: str
23
+
24
+ """
25
+
26
+ parent_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
27
+
28
+ if "src" in os.getcwd() and os.path.exists(os.path.join(parent_path, "european-city-data")):
29
+ popularity_path = popularity_dir.replace("../../", "../")
30
+ else:
31
+ popularity_path = popularity_dir
32
+
33
+ popularity_df = pd.read_csv(popularity_path + "popularity_scores.csv")
34
+
35
+ if not len(popularity_df[popularity_df['city'] == destination]):
36
+ print(f"{destination} does not have popularity data")
37
+ return None
38
+
39
+ return popularity_df[popularity_df['city'] == destination]['weighted_pop_score'].item()
40
+
41
+
42
+ def get_seasonality(destination, month=None):
43
+ """
44
+
45
+ Returns the seasonality score for a particular destination for a particular month. If no month is provided then
46
+ the best month, i.e. month of lowest seasonality is returned.
47
+
48
+ Args:
49
+ - destination: str
50
+ - month: str (default: None)
51
+
52
+ """
53
+ parent_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
54
+
55
+ if "src" in os.getcwd() and os.path.exists(os.path.join(parent_path, "european-city-data")):
56
+ seasonality_path = seasonality_dir.replace("../../", "../")
57
+ else:
58
+ seasonality_path = seasonality_dir
59
+ seasonality_df = pd.read_csv(seasonality_path + "seasonality_scores.csv")
60
+
61
+ # Check if city is present in dataframe
62
+ if not len(seasonality_df[seasonality_df['city'] == destination]):
63
+ logger.info(f"{destination} does not have seasonality data for {month}")
64
+ return None, None
65
+
66
+ if month:
67
+ m = month.capitalize()[:3]
68
+ else:
69
+ seasonality_df['lowest_col'] = seasonality_df.loc[:, seasonality_df.columns != 'city'].idxmin(axis="columns")
70
+ m = seasonality_df[seasonality_df['city'] == destination]['lowest_col'].item()
71
+
72
+ # print(destination, m, seasonality_df[seasonality_df['city'] == destination][m])
73
+
74
+ return m, seasonality_df[seasonality_df['city'] == destination][m].item()
75
+
76
+
77
+ def compute_sfairness_score(destination, month=None):
78
+ """
79
+
80
+ Returns the s-fairness score for a particular destination city and (optional) month. If the destination doesn't
81
+ have popularity or seasonality scores, then the function returns None.
82
+
83
+ Args:
84
+ - destination: str
85
+ - month: str (default: None)
86
+
87
+ """
88
+ seasonality = get_seasonality(destination, month)
89
+ month = seasonality[0]
90
+ popularity = get_popularity(destination)
91
+ emissions = 0
92
+
93
+ # RECHECK
94
+ if seasonality[1] is not None and popularity is not None:
95
+ s_fairness = round(0.281 * emissions + 0.334 * popularity + 0.385 * seasonality[1], 3)
96
+ return {
97
+ 'month': month,
98
+ 's-fairness': s_fairness
99
+ }
100
+ # elif popularity is not None: # => seasonality is None
101
+ # s_fairness = 0.281 * emissions + 0.334 * popularity
102
+ # elif seasonality[1] is not None: # => popularity is None
103
+ # s_fairness = 0.281 * emissions + 0.385 * seasonality[1]
104
+ # else: # => both are non
105
+ # s_fairness = 0.281 * emissions
106
+ else:
107
+ return {
108
+ 'month': None,
109
+ 's-fairness': None
110
+ }
111
+
112
+
113
+ if __name__ == "__main__":
114
+ print(compute_sfairness_score("Paris", "Oct"))
src/text_generation/__init__.py ADDED
File without changes
src/text_generation/model_init.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
3
+ from vertexai.generative_models import GenerativeModel
4
+
5
+ from dotenv import load_dotenv
6
+ from anthropic import AnthropicVertex
7
+ import os
8
+ from openai import OpenAI
9
+ from src.text_generation.vertexai_setup import initialize_vertexai_params
10
+
11
+ load_dotenv()
12
+ if "OPENAI_API_KEY" in os.environ:
13
+ OAI_API_KEY = os.environ["OPENAI_API_KEY"]
14
+ if "VERTEXAI_PROJECTID" in os.environ:
15
+ VERTEXAI_PROJECT = os.environ["VERTEXAI_PROJECTID"]
16
+
17
+
18
+ class LLMBaseClass:
19
+ """
20
+ Base Class for text generation - user needs to provide the HF model ID while instantiating the class after which
21
+ the generate method can be called to generate responses
22
+
23
+ """
24
+
25
+ def __init__(self, model_id) -> None:
26
+
27
+ match (model_id[0].lower()):
28
+ case "gpt-4o-mini": # for open AI models
29
+ self.api_key = OAI_API_KEY
30
+ self.model = OpenAI(api_key=self.api_key)
31
+ case "claude-3-5-sonnet@20240620": # for Claude through vertexAI
32
+ self.api_key = None
33
+ self.model = AnthropicVertex(region="europe-west1", project_id=VERTEXAI_PROJECT)
34
+ case "gemini-1.0-pro":
35
+ self.api_key = None
36
+ self.model = GenerativeModel(model_id[0].lower())
37
+ case _: # for HF models
38
+ self.api_key = None
39
+ self.tokenizer = AutoTokenizer.from_pretrained(model_id)
40
+ self.tokenizer.pad_token = self.tokenizer.eos_token
41
+
42
+ self.tokenizer.chat_template = "{%- for message in messages %}{%- if message['role'] == 'user' %}{{- " \
43
+ "bos_token + '[INST] ' + message['content'].strip() + ' [/INST]' }}{%- " \
44
+ "elif " \
45
+ "message['role'] == 'system' %}{{- '<<SYS>>\\n' + message[" \
46
+ "'content'].strip() + " \
47
+ "'\\n<</SYS>>\\n\\n' }}{%- elif message['role'] == 'assistant' %}{{- '[" \
48
+ "ASST] ' " \
49
+ "+ message['content'] + ' [/ASST]' + eos_token }}{%- endif %}{%- " \
50
+ "endfor %} "
51
+ # Initialize quantization to use less GPU
52
+ if torch.cuda.is_available():
53
+ bnb_config = BitsAndBytesConfig(
54
+ load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4",
55
+ bnb_4bit_compute_dtype=torch.bfloat16
56
+ )
57
+ else:
58
+ bnb_config = None
59
+ self.model = AutoModelForCausalLM.from_pretrained(
60
+ model_id,
61
+ torch_dtype=torch.bfloat16,
62
+ device_map="auto",
63
+ quantization_config=bnb_config,
64
+ )
65
+
66
+ self.model.generation_config.pad_token_id = self.tokenizer.pad_token_id
67
+
68
+ self.terminators = [
69
+ self.tokenizer.eos_token_id,
70
+ self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
71
+ ]
72
+
73
+ def generate(self, messages):
74
+ match (self.model_id[0].lower()):
75
+ case "gpt-4o-mini":
76
+ completion = self.model.chat.completions.create(
77
+ model=self.model_id[0],
78
+ messages=messages,
79
+ temperature=0.6,
80
+ top_p=0.9,
81
+ )
82
+ # Return the generated content from the API response
83
+ return completion.choices[0].message.content
84
+ case "claude-3-5-sonnet@20240620" | "gemini-1.0-pro":
85
+ initialize_vertexai_params()
86
+ if "claude" in self.model_id[0].lower():
87
+ message = self.model.messages.create(
88
+ max_tokens=1024,
89
+ model=self.model_id[0],
90
+ messages=[
91
+ {
92
+ "role": "user",
93
+ "content": messages[0]["content"],
94
+ }
95
+ ],
96
+ )
97
+ return message.content[0].text
98
+ else:
99
+ response = self.model.generate_content(messages[0]["content"])
100
+ return response
101
+ case _:
102
+ input_ids = self.tokenizer.apply_chat_template(
103
+ conversation=messages,
104
+ add_generation_prompt=True,
105
+ return_tensors="pt",
106
+ padding=True
107
+ ).to(self.model.device)
108
+
109
+ outputs = self.model.generate(
110
+ input_ids,
111
+ max_new_tokens=1024,
112
+ # eos_token_id=self.terminators,
113
+ pad_token_id=self.tokenizer.eos_token_id,
114
+ do_sample=True,
115
+ temperature=0.6,
116
+ top_p=0.9,
117
+ )
118
+ response = outputs[0][input_ids.shape[-1]:]
119
+
120
+ return self.tokenizer.decode(response, skip_special_tokens=True)
121
+
122
+
123
+ # database/wikivoyage/wikivoyage_listings.lance/data/e2940f51-d754-4b54-a688-004bdb8e7aa2.lance
src/text_generation/models.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.text_generation.model_init import LLMBaseClass
2
+
3
+
4
+ class Llama3(LLMBaseClass):
5
+ """
6
+ Initializes a Llama3 model object
7
+ """
8
+
9
+ def __init__(self) -> None:
10
+ self.model_id = "meta-llama/Meta-Llama-3-8B"
11
+
12
+ super().__init__(self.model_id)
13
+
14
+
15
+ class Mistral(LLMBaseClass):
16
+ """
17
+ Initializes a Mistral model object
18
+ """
19
+
20
+ def __init__(self) -> None:
21
+ self.model_id = "mistralai/Mistral-7B-v0.3"
22
+ super().__init__(self.model_id)
23
+
24
+
25
+ class Gemma2(LLMBaseClass):
26
+ """
27
+ Initializes a Gemma2 model object
28
+ """
29
+
30
+ def __init__(self) -> None:
31
+ self.model_id = "google/gemma-2-9b"
32
+ super().__init__(self.model_id)
33
+
34
+
35
+ class Llama3Point1(LLMBaseClass):
36
+ """
37
+ Initializes a Llama 3.1 object
38
+ """
39
+
40
+ def __init__(self) -> None:
41
+ self.model_id = "meta-llama/Meta-Llama-3.1-8B"
42
+ super().__init__(self.model_id)
43
+
44
+
45
+ class Llama3Instruct(LLMBaseClass):
46
+ """
47
+ Initializes a Llama 3 Instruct object
48
+ """
49
+
50
+ def __init__(self) -> None:
51
+ self.model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
52
+ super().__init__(self.model_id)
53
+
54
+
55
+ class MistralInstruct(LLMBaseClass):
56
+ """
57
+ Initializes a Mistral Instruct object
58
+ """
59
+
60
+ def __init__(self) -> None:
61
+ self.model_id = "mistralai/Mistral-7B-Instruct-v0.1"
62
+ super().__init__(self.model_id)
63
+
64
+
65
+ class Llama3Point1Instruct(LLMBaseClass):
66
+ """
67
+ Initializes a Llama 3.1 Instruct object
68
+ """
69
+
70
+ def __init__(self) -> None:
71
+ self.model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
72
+ super().__init__(self.model_id)
73
+
74
+
75
+ class Phi3SmallInstruct(LLMBaseClass):
76
+ """
77
+ Initializes a Phi3-Small-Instruct object
78
+ """
79
+
80
+ def __init__(self) -> None:
81
+ self.model_id = "microsoft/Phi-3-small-128k-instruct"
82
+ super().__init__(self.model_id)
83
+
84
+
85
+ class GPT4(LLMBaseClass):
86
+ """
87
+ Initializes a GPT-4 Instruct object
88
+ """
89
+
90
+ def __init__(self) -> None:
91
+ self.model_id = "gpt-4o-mini",
92
+ super().__init__(self.model_id)
93
+
94
+
95
+ class Claude3Point5Sonnet(LLMBaseClass):
96
+ """
97
+ Initializes a Claude3.5 Instruct object
98
+ """
99
+
100
+ def __init__(self) -> None:
101
+ self.model_id = "claude-3-5-sonnet@20240620",
102
+ super().__init__(self.model_id)
103
+
104
+
105
+ class Gemini(LLMBaseClass):
106
+ """
107
+ Initializes a Gemini Instruct object
108
+ """
109
+
110
+ def __init__(self) -> None:
111
+ self.model_id = "gemini-1.0-pro",
112
+ super().__init__(self.model_id)
src/text_generation/text_generation.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from augmentation import prompt_generation as pg
2
+ from information_retrieval import info_retrieval as ir
3
+ from src.text_generation.models import (
4
+ Llama3,
5
+ Mistral,
6
+ Gemma2,
7
+ Llama3Point1,
8
+ GPT4,
9
+ Claude3Point5Sonnet,
10
+ )
11
+ import logging
12
+
13
+ logger = logging.getLogger(__name__)
14
+ logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
15
+
16
+
17
+ def generate_response(model, prompt):
18
+ """
19
+
20
+ Function that initializes the LLM class and calls the generate function.
21
+
22
+ Args:
23
+ - messages: list; contains the system and user prompt
24
+ - model: class; the class of the llm to be initialized
25
+
26
+ """
27
+
28
+ logger.info(f"Initializing LLM configuration for {model}")
29
+ llm = model()
30
+
31
+ logger.info("Generating response")
32
+ try:
33
+ response = llm.generate(prompt)
34
+ except Exception as e:
35
+ logger.error(f"Error while generating response for {model}: {e}")
36
+ response = 'ERROR'
37
+
38
+ return response
39
+
40
+
41
+ def test(model):
42
+ context_params = {
43
+ 'limit': 3,
44
+ 'reranking': 0
45
+ }
46
+ # model = Llama3Point1
47
+
48
+ query = "Suggest some places to visit during winter. I like hiking, nature and the mountains and I enjoy skiing " \
49
+ "in winter. "
50
+
51
+ # without sustainability
52
+ logger.info("Retrieving context..")
53
+ try:
54
+ context = ir.get_context(query=query, **context_params)
55
+ except Exception as e:
56
+ logger.error(f"Error while trying to get context: {e}")
57
+ return None
58
+
59
+ logger.info("Retrieved context, augmenting prompt (without sustainability)..")
60
+ try:
61
+ without_sfairness = pg.augment_prompt(
62
+ query=query,
63
+ context=context,
64
+ sustainability=0,
65
+ params=context_params
66
+ )
67
+ except Exception as e:
68
+ logger.error(f"Error while trying to augment prompt: {e}")
69
+ return None
70
+
71
+ # return without_sfairness
72
+
73
+ logger.info(f"Augmented prompt, initializing {model} and generating response..")
74
+ try:
75
+ response = generate_response(model, without_sfairness)
76
+ except Exception as e:
77
+ logger.info(f"Error while generating response: {e}")
78
+ return None
79
+
80
+ return response
81
+
82
+
83
+ if __name__ == "__main__":
84
+ response = test(Claude3Point5Sonnet)
85
+ print(response)
src/text_generation/vertexai_setup.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ from dotenv import load_dotenv
3
+ from google.oauth2 import service_account
4
+ import vertexai
5
+ import os
6
+ import json
7
+ import base64
8
+
9
+ load_dotenv()
10
+ if "VERTEXAI_PROJECTID" in os.environ:
11
+ VERTEXAI_PROJECT = os.environ["VERTEXAI_PROJECTID"]
12
+
13
+
14
+ def decode_service_key():
15
+ encoded_key = os.environ["GOOGLE_APPLICATION_CREDENTIALS"]
16
+ original_service_key = json.loads(base64.b64decode(encoded_key).decode('utf-8'))
17
+ if original_service_key:
18
+ return original_service_key
19
+ return None
20
+
21
+
22
+ def initialize_vertexai_params(location: Optional[str] = "us-central1"):
23
+
24
+ creds_file_name = os.getcwd() + "/.config/application_default_credentials.json"
25
+ print(creds_file_name)
26
+ if not os.path.exists(os.path.dirname(creds_file_name)):
27
+ credentials = decode_service_key()
28
+ with open(creds_file_name, 'w', encoding='utf-8') as file:
29
+ json.dump(credentials, file, ensure_ascii=False, indent=4)
30
+
31
+ os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = creds_file_name
32
+
33
+ service_account.Credentials.from_service_account_file(
34
+ filename=os.environ["GOOGLE_APPLICATION_CREDENTIALS"],
35
+ scopes=["https://www.googleapis.com/auth/cloud-platform"],
36
+ )
37
+ vertexai.init(project=VERTEXAI_PROJECT, location=location)
src/vectordb/__init__.py ADDED
File without changes
src/vectordb/create_db.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.vectordb.vectordb import *
2
+ import logging
3
+
4
+ logger = logging.getLogger(__name__)
5
+ logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
6
+
7
+
8
+ def run():
9
+ logging.info("Creating database for Wikivoyage Documents")
10
+ try:
11
+ create_wikivoyage_docs_db_and_add_data()
12
+ except Exception as e:
13
+ logger.error(f"Error for Wikivoyage Documents: {e}")
14
+
15
+ logging.info("Creating database for Wikivoyage Listings")
16
+ try:
17
+ create_wikivoyage_listings_db_and_add_data()
18
+ print("Completed")
19
+ except Exception as e:
20
+ logger.error(f"Error for Wikivoyage Listings: {e}")
21
+
22
+
23
+ if __name__ == "__main__":
24
+ run()
src/vectordb/helpers.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import os
3
+ import re
4
+ from sentence_transformers import SentenceTransformer
5
+ import sys
6
+
7
+ SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
8
+ sys.path.append(os.path.dirname(SCRIPT_DIR))
9
+
10
+ from data_directories import *
11
+
12
+
13
+ def create_chunks(city, country, text):
14
+ """
15
+
16
+ Helper function that creates chunks given paragraph(s) of text based on implicit sections in the text.
17
+
18
+ Args:
19
+ - city: str
20
+ - country: str
21
+ - text: str; document that needs to be chunked
22
+
23
+ """
24
+
25
+ for i, line in enumerate(text):
26
+ if line[0] == "\n":
27
+ del text[i]
28
+
29
+ index = 0
30
+ chunks = []
31
+ pattern = re.compile("==")
32
+ ignore = re.compile("===")
33
+ section = 'Introduction'
34
+ for i, line in enumerate(text):
35
+ if pattern.search(line) and not ignore.search(line):
36
+ chunk = ''.join(text[index:i])
37
+ chunks.append({
38
+ 'city': city,
39
+ 'country': country,
40
+ 'section': section,
41
+ 'text': chunk,
42
+ # 'vector': f'city: {city}, country: {country}, section: {section}, text: {chunk}'
43
+ })
44
+ index = i + 1
45
+ section = re.sub(pattern, '', line).strip()
46
+
47
+ df = pd.DataFrame(chunks)
48
+ return df
49
+
50
+
51
+ def read_docs():
52
+ """
53
+
54
+ Helper function that reads all of the Wikivoyage documents containing information about the city.
55
+
56
+ """
57
+
58
+ df = pd.DataFrame()
59
+ cities = pd.read_csv(cities_csv)
60
+ for file_name in os.listdir(wikivoyage_docs_dir + "cleaned/"):
61
+ city = file_name.split(".")[0]
62
+ # print(city)
63
+ country = cities[cities['city'] == city]['country'].item()
64
+ with open(wikivoyage_docs_dir + "cleaned/" + file_name) as file:
65
+ text = file.readlines()
66
+ chunk_df = create_chunks(city, country, text)
67
+ df = pd.concat([df, chunk_df])
68
+
69
+ return df
70
+
71
+
72
+ def read_listings():
73
+ """
74
+
75
+ Helper function that reads the Wikivoyage listings csv containing tabular information about 144 cities.
76
+
77
+ """
78
+ df = pd.read_csv(wikivoyage_listings_dir + "wikivoyage-listings-cleaned.csv")
79
+ cities = pd.read_csv(cities_csv)
80
+
81
+ def find_country(city):
82
+ return cities[cities['city'] == city]['country'].values[0]
83
+
84
+ df['country'] = df['city'].apply(find_country)
85
+
86
+ return df
87
+
88
+
89
+ def preprocess_df(df):
90
+ """
91
+
92
+ Helper function that preprocesses the dataframe containing chunks of text and removes hyperlinks and strips the \n from the text.
93
+
94
+ Args:
95
+ - df: dataframe
96
+
97
+ """
98
+ section_counts = df['section'].value_counts()
99
+ sections_to_keep = section_counts[section_counts > 150].index
100
+ filtered_df = df[df['section'].isin(sections_to_keep)]
101
+
102
+ def preprocess_text(s):
103
+ s = re.sub(r'http\S+', '', s)
104
+ s = re.sub(r'=+', '', s)
105
+ s = s.strip()
106
+ return s
107
+
108
+ filtered_df['text'] = filtered_df['text'].apply(preprocess_text)
109
+
110
+ return filtered_df
111
+
112
+
113
+ def compute_wv_docs_embeddings(df):
114
+ """
115
+
116
+ Helper function that computes embeddings for the text. The all-MiniLM-L6-v2 embedding model is used.
117
+
118
+ Args:
119
+ - df: dataframe
120
+
121
+ """
122
+ model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
123
+ vector_dimension = model.get_sentence_embedding_dimension()
124
+
125
+ print("Computing embeddings")
126
+ embeddings = []
127
+ for i, row in df.iterrows():
128
+ emb = model.encode(row['combined'], show_progress_bar=True).tolist()
129
+ embeddings.append(emb)
130
+
131
+ print("Finished computing embeddings for wikivoyage documents.")
132
+ df['vector'] = embeddings
133
+ # df.to_csv(wv_embeddings + "wikivoyage-listings-embeddings.csv")
134
+ # print("Finished saving file.")
135
+ return df
136
+
137
+
138
+ def embed_query(query):
139
+ """
140
+
141
+ Helper function that returns the embedded query.
142
+
143
+ Args:
144
+ - query: str
145
+
146
+ """
147
+ model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
148
+ # vector_dimension = model.get_sentence_embedding_dimension()
149
+ embedding = model.encode(query).tolist()
150
+ return embedding
src/vectordb/lancedb_init.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from lancedb.embeddings import get_registry
4
+ from lancedb.pydantic import LanceModel, Vector
5
+
6
+ SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
7
+ sys.path.append(os.path.dirname(SCRIPT_DIR))
8
+
9
+
10
+ model = get_registry().get("sentence-transformers").create()
11
+
12
+
13
+ class WikivoyageDocuments(LanceModel):
14
+ """
15
+
16
+ Schema definition for the Wikivoyage Documents table.
17
+
18
+ """
19
+ city: str = model.SourceField()
20
+ country: str = model.SourceField()
21
+ section: str = model.SourceField()
22
+ text: str = model.SourceField()
23
+ vector: Vector(model.ndims()) = model.VectorField()
24
+
25
+
26
+ class WikivoyageListings(LanceModel):
27
+ """
28
+
29
+ Schema definition for the Wikivoyage Listings table.
30
+
31
+ """
32
+ city: str = model.SourceField()
33
+ type: str = model.SourceField()
34
+ title: str = model.SourceField()
35
+ description: str = model.SourceField()
36
+ country: str = model.SourceField()
37
+ vector: Vector(model.ndims()) = model.VectorField()
src/vectordb/vectordb.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from src import *
2
+ from src.vectordb.helpers import *
3
+ from src.vectordb.lancedb_init import *
4
+ import logging
5
+ import os
6
+ import lancedb
7
+ from lancedb.rerankers import ColbertReranker
8
+
9
+ import sys
10
+ logger = logging.getLogger(__name__)
11
+ from typing import Optional
12
+
13
+ # db = lancedb.connect("/tmp/db")
14
+
15
+ def create_wikivoyage_docs_db_and_add_data():
16
+ """
17
+
18
+ Creates wikivoyage documents table and ingests data
19
+
20
+ """
21
+ uri = database_dir
22
+ current_dir = os.path.split(os.getcwd())[1]
23
+
24
+ if "src" or "tests" in current_dir: # hacky way to get the correct path
25
+ uri = uri.replace("../../", "../")
26
+
27
+ db = lancedb.connect(uri)
28
+ logger.info("Connected to DB. Reading data now...")
29
+ df = read_docs()
30
+ filtered_df = preprocess_df(df)
31
+ logger.info("Finished reading data, attempting to create table and ingest the data...")
32
+
33
+ db.drop_table("wikivoyage_documents", ignore_missing=True)
34
+ table = db.create_table("wikivoyage_documents", schema=WikivoyageDocuments)
35
+
36
+ table.add(filtered_df.to_dict('records'))
37
+ logger.info("Completed ingestion.")
38
+
39
+
40
+ def create_wikivoyage_listings_db_and_add_data():
41
+ """
42
+
43
+ Creates wikivoyage listings table and ingests data
44
+
45
+ """
46
+ uri = database_dir
47
+ current_dir = os.path.split(os.getcwd())[1]
48
+
49
+ if "src" or "tests" in current_dir: # hacky way to get the correct path
50
+ uri = uri.replace("../../", "../")
51
+
52
+ db = lancedb.connect(uri)
53
+ logger.info("Connected to DB. Reading data now...")
54
+ df = read_listings()
55
+ logger.info("Finished reading data, attempting to create table and ingest the data...")
56
+ # filtered_df = preprocess_df(df)
57
+
58
+ db.drop_table("wikivoyage_listings", ignore_missing=True)
59
+ table = db.create_table("wikivoyage_listings", schema=WikivoyageListings)
60
+
61
+ table.add(df.astype('str').to_dict('records'))
62
+ logger.info("Completed ingestion.")
63
+
64
+
65
+ def search_wikivoyage_docs(query: str, limit: int = 10, reranking: int = 0, run_local: Optional[bool] = False):
66
+ """
67
+
68
+ Function to search the wikivoyage database an return most relevant chunked docs.
69
+
70
+ Args:
71
+ - query: str
72
+ - limit: number of results (default is 10)
73
+ - reranking: bool (0 or 1), if activated, CrossEncoderReranker is used.
74
+
75
+ """
76
+ if run_local:
77
+ uri = database_dir
78
+ current_dir = os.path.split(os.getcwd())[1]
79
+
80
+ if "src" or "tests" in current_dir: # hacky way to get the correct path
81
+ uri = uri.replace("../../", "../")
82
+ else:
83
+ uri = os.environ["BUCKET_NAME"]
84
+ # print(uri)
85
+ try:
86
+ db = lancedb.connect(uri)
87
+ except Exception as e:
88
+ logger.error(f"Error while connecting to DB: {e}")
89
+
90
+ logger.info("Connected to Wikivoyage DB.")
91
+
92
+ # query_embedding = embed_query(query)
93
+ table = db.open_table("wikivoyage_documents")
94
+
95
+ if reranking:
96
+ try:
97
+ reranker = ColbertReranker(column='text')
98
+ results = table.search(query) \
99
+ .metric('cosine') \
100
+ .rerank(reranker=reranker) \
101
+ .limit(limit) \
102
+ .to_list()
103
+ except Exception as e:
104
+ exc_type, exc_obj, exc_tb = sys.exc_info()
105
+ fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
106
+ logger.error(f"Error while getting context: {e}, {(exc_type, fname, exc_tb.tb_lineno)}")
107
+
108
+ else:
109
+ try:
110
+ results = table.search(query) \
111
+ .limit(limit) \
112
+ .metric('cosine') \
113
+ .to_list()
114
+ except Exception as e:
115
+ exc_type, exc_obj, exc_tb = sys.exc_info()
116
+ fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
117
+ logger.error(f"Error while getting context: {e}, {(exc_type, fname, exc_tb.tb_lineno)}")
118
+
119
+ logger.info("Found the most relevant documents.")
120
+ city_lists = [{"city": r['city'], "country": r['country'], "section": r['section'], "text": r['text']} for r in
121
+ results]
122
+
123
+ # context = [f"city: {r['city']}, country: {r['country']}, name: {r['title']}, description: {r['description']}"
124
+ # for r in results]
125
+
126
+ return city_lists
127
+
128
+
129
+ def search_wikivoyage_listings(query, cities, limit=10, reranking=0):
130
+ """
131
+
132
+ Function to search the wikivoyage database an return most relevant listings, post-filtered by the recommended
133
+ cities provided by wikivoyage_documents table.
134
+
135
+ Args:
136
+ - query: str
137
+ - cities: list
138
+ - limit: number of results (default is 10)
139
+ - reranking: bool (0 or 1), if activated, CrossEncoderReranker is used.
140
+
141
+ """
142
+ uri = database_dir
143
+ current_dir = os.path.split(os.getcwd())[1]
144
+
145
+ if "src" or "tests" in current_dir: # hacky way to get the correct path
146
+ uri = uri.replace("../../", "../")
147
+
148
+ db = lancedb.connect(uri)
149
+ logger.info("Connected to Wikivoyage Listings DB.")
150
+
151
+ table = db.open_table("wikivoyage_listings")
152
+
153
+ cities_filter = f"city IN {tuple(cities)}"
154
+
155
+ if reranking:
156
+ try:
157
+ reranker = ColbertReranker(column='description')
158
+ results = table.search(query) \
159
+ .where(cities_filter) \
160
+ .metric('cosine') \
161
+ .rerank(reranker=reranker) \
162
+ .limit(limit) \
163
+ .to_list()
164
+
165
+ except Exception as e:
166
+ exc_type, exc_obj, exc_tb = sys.exc_info()
167
+ fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
168
+ logger.error(f"Error while getting context: {e}, {(exc_type, fname, exc_tb.tb_lineno)}")
169
+
170
+ else:
171
+ try:
172
+ results = table.search(query) \
173
+ .where(cities_filter) \
174
+ .metric('cosine') \
175
+ .limit(limit) \
176
+ .to_list()
177
+ except Exception as e:
178
+ exc_type, exc_obj, exc_tb = sys.exc_info()
179
+ fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
180
+ logger.error(f"Error while getting context: {e}, {(exc_type, fname, exc_tb.tb_lineno)}")
181
+
182
+ logger.info("Found the most relevant documents.")
183
+ city_listings = [{"city": r['city'], "country": r['country'], "type": r['type'], "title": r['title'],
184
+ "description": r['description']} for r in results]
185
+
186
+ return city_listings