Upload folder using huggingface_hub
Browse files- app.py +1 -2
- commonsenseConstraint.py +735 -0
- eval.py +181 -0
- hardConstraint.py +266 -0
- requirements.txt +1 -2
app.py
CHANGED
@@ -2,7 +2,6 @@ import os
|
|
2 |
import sys
|
3 |
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "./leaderboard/evaluation")))
|
4 |
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "./leaderboard")))
|
5 |
-
print(sys.path)
|
6 |
os.chdir(os.path.dirname(os.path.abspath(__file__)))
|
7 |
import json
|
8 |
import datetime
|
@@ -19,7 +18,7 @@ from huggingface_hub import HfApi
|
|
19 |
# InfoStrings
|
20 |
# from scorer import question_scorer
|
21 |
from content import format_error, format_warning, format_log, TITLE, INTRODUCTION_TEXT, CITATION_BUTTON_LABEL, CITATION_BUTTON_TEXT, model_hyperlink
|
22 |
-
from
|
23 |
|
24 |
TOKEN = os.environ.get("TOKEN", None)
|
25 |
|
|
|
2 |
import sys
|
3 |
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "./leaderboard/evaluation")))
|
4 |
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "./leaderboard")))
|
|
|
5 |
os.chdir(os.path.dirname(os.path.abspath(__file__)))
|
6 |
import json
|
7 |
import datetime
|
|
|
18 |
# InfoStrings
|
19 |
# from scorer import question_scorer
|
20 |
from content import format_error, format_warning, format_log, TITLE, INTRODUCTION_TEXT, CITATION_BUTTON_LABEL, CITATION_BUTTON_TEXT, model_hyperlink
|
21 |
+
from eval import eval_score
|
22 |
|
23 |
TOKEN = os.environ.get("TOKEN", None)
|
24 |
|
commonsenseConstraint.py
ADDED
@@ -0,0 +1,735 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from annotation.src.utils import get_valid_name_city,extract_before_parenthesis,extract_numbers_from_filenames
|
2 |
+
from tools.flights.apis import Flights
|
3 |
+
from tools.accommodations.apis import Accommodations
|
4 |
+
from tools.restaurants.apis import Restaurants
|
5 |
+
from tools.googleDistanceMatrix.apis import GoogleDistanceMatrix
|
6 |
+
from tools.attractions.apis import Attractions
|
7 |
+
import math
|
8 |
+
import json
|
9 |
+
import re
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
from tqdm import tqdm
|
13 |
+
import argparse
|
14 |
+
|
15 |
+
# Make the parent of the launch directory importable, then switch the working
# directory to this file's folder so the relative data path below resolves.
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# Shared sandbox handles; the validators below query their ``.data`` tables
# (and ``googleDistanceMatrix.run_for_evaluation``).
flight = Flights()
accommodation = Accommodations()
restaurants = Restaurants()
googleDistanceMatrix = GoogleDistanceMatrix()
attractions = Attractions()

# Tab-separated two-column table, one record per line.
# NOTE(review): presumably "<city>\t<state>" -- the map is compared against
# question['dest'] further down; confirm against the data file.
# The file is read inside a ``with`` block so the handle is closed, and blank
# lines (e.g. a trailing newline) are skipped instead of breaking the unpack.
with open('../database/background/citySet_with_states.txt', 'r') as city_state_file:
    city_state_set = city_state_file.read().split('\n')
city_state_map = {x: y for x, y in [unit.split('\t') for unit in city_state_set if unit.strip()]}
|
26 |
+
|
27 |
+
|
28 |
+
def load_line_json_data(filename):
    """Load a JSON Lines file into a list of decoded objects.

    Args:
        filename: Path of a UTF-8 text file holding one JSON document per line.

    Returns:
        list: The decoded documents in file order. Blank or whitespace-only
        lines are skipped, so an empty file yields ``[]``.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data = []
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            # Skip blank lines instead of feeding '' to json.loads, which raises.
            if not line.strip():
                continue
            data.append(json.loads(line))
    return data
|
35 |
+
|
36 |
+
|
37 |
+
def count_consecutive_values(lst):
    """Run-length encode ``lst``.

    Args:
        lst: Sequence of values compared with ``==``.

    Returns:
        list: ``(value, run_length)`` tuples, one per maximal run of equal
        neighbouring values, in original order. An empty input gives ``[]``.
    """
    runs = []
    for value in lst:
        if runs and value == runs[-1][0]:
            # Same value as the run in progress: extend it.
            runs[-1][1] += 1
        else:
            # A different value (or the very first one) opens a new run.
            runs.append([value, 1])
    return [(value, length) for value, length in runs]
|
55 |
+
|
56 |
+
|
57 |
+
def transportation_match(text: str):
    """Classify a transportation description by keyword.

    The lookup is case-insensitive and the keywords are tried in a fixed order:
    'taxi', then 'self-driving', then 'flight'.

    Args:
        text: Free-form transportation description.

    Returns:
        'Taxi', 'Self-driving' or 'Flight' for the first keyword found, or
        None when the text contains none of them.
    """
    lowered = text.lower()
    keyword_labels = (
        ('taxi', 'Taxi'),
        ('self-driving', 'Self-driving'),
        ('flight', 'Flight'),
    )
    for keyword, label in keyword_labels:
        if keyword in lowered:
            return label
    return None
|
67 |
+
|
68 |
+
|
69 |
+
def extract_from_to(text: str):
    """Pull the two endpoints out of the first "from A to B" phrase in ``text``.

    ``A`` is the shortest text between "from" and the following " to ";
    ``B`` runs up to the next comma or, failing that, the end of the string.

    Args:
        text (str): The string to search.

    Returns:
        tuple: ``(A, B)`` when the phrase is found, otherwise ``(None, None)``.
    """
    found = re.search(r"from\s+(.+?)\s+to\s+([^,]+)(?=[,\s]|$)", text)
    if found is None:
        return (None, None)
    return found.groups()
|
82 |
+
|
83 |
+
|
84 |
+
|
85 |
+
def is_valid_city_sequence(city_list):
    """Check that a day-by-day city list forms a well-shaped itinerary.

    The list is scanned as runs of identical neighbouring cities. It is
    rejected when:
      * it has fewer than three entries;
      * a run that starts anywhere except the first or last position repeats
        a city whose run has already ended;
      * a run of length one sits strictly between the first and last position.

    Args:
        city_list (list): Cities in visiting order.

    Returns:
        bool: True if none of the rejection rules fire, False otherwise.
    """
    total = len(city_list)
    if total < 3:
        return False

    last_index = total - 1
    seen = set()
    start = 0
    while start < total:
        city = city_list[start]

        # Revisiting an earlier city is only tolerated at either end of the list.
        if city in seen and start not in (0, last_index):
            return False

        # Advance ``stop`` to one past the end of the current run.
        stop = start + 1
        while stop < total and city_list[stop] == city:
            stop += 1

        # A lone interior entry is not a real stay.
        if stop - start == 1 and 0 < start < last_index:
            return False

        seen.add(city)
        start = stop

    return True
|
125 |
+
|
126 |
+
|
127 |
+
|
128 |
+
def is_reasonalbe_visiting_city(question, tested_data):
    """Validate the route implied by each day's ``current_city`` entry.

    Only the first ``min(question['days'], len(tested_data))`` days are read.
    A day whose ``current_city`` contains 'from' contributes both endpoints of
    its "from A to B" phrase; any other day contributes a single city.

    Args:
        question: Query dict; reads 'days', 'org' and 'dest'.
        tested_data: List of per-day plan dicts; reads 'current_city'.

    Returns:
        tuple: ``(True, None)`` on success, otherwise ``(False, reason)``.
    """
    route = []
    day_count = min(question['days'], len(tested_data))

    for day_index in range(day_count):
        city_value = tested_data[day_index]['current_city']

        if 'from' not in city_value:
            route.append(extract_before_parenthesis(city_value))
            continue

        departure, arrival = extract_from_to(city_value)
        departure = extract_before_parenthesis(departure)
        arrival = extract_before_parenthesis(arrival)
        # The trip has to start from the origin named in the query.
        if day_index == 0 and departure != question['org']:
            return False, f"The first day's city should be {question['org']}."
        route.extend([departure, arrival])

    if route[0] != route[-1]:
        return False, "The trip should be a closed circle."

    if not is_valid_city_sequence(route):
        return False, "The city sequence is invalid."

    final_position = len(route) - 1
    for position, city in enumerate(route):
        if city not in city_state_map:
            return False, f"{city} is not a valid city."
        # For trips longer than three days every interior stop must map to the
        # destination given in the query.
        is_interior = position != 0 and position != final_position
        if is_interior and question['days'] > 3 and city_state_map[city] != question['dest']:
            return False, f"{city} is not in {question['dest']}."

    return True, None
|
161 |
+
|
162 |
+
|
163 |
+
def is_valid_restaurants(question, tested_data):
    """Reject plans that use the same restaurant entry more than once.

    Breakfast, lunch and dinner of the first
    ``min(question['days'], len(tested_data))`` days are scanned in that order.
    Missing, empty and '-' entries are ignored.

    Args:
        question: Query dict; reads 'days'.
        tested_data: List of per-day plan dicts.

    Returns:
        tuple: ``(True, None)`` if every entry is unique, otherwise
        ``(False, reason)`` for the first repeated one.
    """
    already_used = []

    for day_index in range(min(question['days'], len(tested_data))):
        unit = tested_data[day_index]

        for meal in ('breakfast', 'lunch', 'dinner'):
            if meal not in unit:
                continue
            choice = unit[meal]
            if not choice or choice == '-':
                continue

            if choice in already_used:
                # Only the lunch message names the offending restaurant.
                if meal == 'lunch':
                    return False, f"The restaurant in day {day_index+1} lunch {choice} is repeated."
                return False, f"The restaurant in day {day_index+1} {meal} is repeated."

            already_used.append(choice)

    return True, None
|
195 |
+
|
196 |
+
def is_valid_attractions(question, tested_data):
    """Reject plans that list the same attraction more than once.

    Each day's 'attraction' value is split on ';' and the final piece is
    discarded, so the text after the last ';' is never inspected.
    Missing, empty and '-' values are ignored.

    Args:
        question: Query dict; reads 'days'.
        tested_data: List of per-day plan dicts.

    Returns:
        tuple: ``(True, None)`` if no attraction repeats, otherwise
        ``(False, reason)`` naming the first repeated attraction.
    """
    visited = []
    day_count = min(question['days'], len(tested_data))

    for day_index in range(day_count):
        unit = tested_data[day_index]
        if 'attraction' not in unit:
            continue
        listing = unit['attraction']
        if not listing or listing == '-':
            continue

        pieces = listing.split(';')
        for attraction in pieces[:-1]:
            if attraction in visited:
                return False, f"The attraction '{attraction}' in day {day_index+1} is repeated."
            visited.append(attraction)

    return True, None
|
214 |
+
|
215 |
+
def is_valid_transportation(question, tested_data):
    """Check that the transport modes used across the plan do not conflict.

    Day 1 must have a non-empty transportation entry. The modes of all
    inspected days are then collected via ``transportation_match``;
    'Self-driving' may not be combined with either 'Flight' or 'Taxi'.

    Args:
        question: Query dict; reads 'days'.
        tested_data: List of per-day plan dicts; day 1 must carry a
            'transportation' key.

    Returns:
        tuple: ``(True, None)`` if the modes are compatible, otherwise
        ``(False, reason)``.
    """
    first_day = tested_data[0]['transportation']
    if not first_day or first_day == '-':
        return False, "The transportation in day 1 should not be empty."

    # Only membership matters for the conflict test, so duplicates are dropped.
    modes = {transportation_match(first_day)}

    for day_index in range(min(question['days'], len(tested_data))):
        unit = tested_data[day_index]
        if 'transportation' not in unit:
            continue
        entry = unit['transportation']
        if entry and entry != '-':
            modes.add(transportation_match(entry))

    if 'Self-driving' in modes and ('Flight' in modes or 'Taxi' in modes):
        return False, "The transportation is conflicting."

    return True, None
|
235 |
+
|
236 |
+
def is_valid_information_in_current_city(question, tested_data):
    """Check that each day's entries mention the city (or cities) of that day.

    For a travel day (``current_city`` contains 'from') the candidate cities
    are both endpoints of the "from A to B" phrase; otherwise the single
    current city. Matching is plain substring containment:
      * transportation must mention every candidate city;
      * each meal and each attraction must mention at least one of them;
      * accommodation must mention the last candidate city.
    Missing, empty and '-' fields are skipped. The text after the last ';' of
    the attraction field is not inspected.

    Args:
        question: Query dict; reads 'days'.
        tested_data: List of per-day plan dicts; reads 'current_city' plus the
            optional fields listed above.

    Returns:
        tuple: ``(True, None)`` on success, otherwise ``(False, reason)``.
    """
    def has_value(unit, key):
        # A field counts only when present, non-empty and not the '-' placeholder.
        return key in unit and unit[key] and unit[key] != '-'

    for i in range(min(question['days'], len(tested_data))):
        unit = tested_data[i]
        current_city = unit['current_city']

        if 'from' in current_city:
            city1, city2 = extract_from_to(current_city)
            final_city_list = [extract_before_parenthesis(city1), extract_before_parenthesis(city2)]
        else:
            # Must be a one-element list: a bare string here would make the
            # loops below iterate over single characters of the city name.
            final_city_list = [extract_before_parenthesis(current_city)]

        if has_value(unit, 'transportation'):
            if any(city not in unit['transportation'] for city in final_city_list):
                return False, f"The transportation in day {i+1} is invalid city choice."

        for meal in ('breakfast', 'lunch', 'dinner'):
            if has_value(unit, meal) and not any(city in unit[meal] for city in final_city_list):
                return False, f"The {meal} in day {i+1} is invalid city choice."

        if has_value(unit, 'attraction'):
            for attraction in unit['attraction'].split(';')[:-1]:
                if not any(city in attraction for city in final_city_list):
                    return False, f"The attraction in day {i+1} is invalid city choice."

        if has_value(unit, 'accommodation'):
            # On a travel day the night is spent in the arrival city.
            if final_city_list[-1] not in unit['accommodation']:
                return False, f"The accommodation in day {i+1} is invalid city choice."

    return True, None
|
321 |
+
|
322 |
+
# hallucination
|
323 |
+
def is_valid_information_in_sandbox(question, tested_data):
    """Check that every entry of the plan exists in the sandbox databases.

    For each inspected day:
      * a flight (transportation text containing 'flight number') must match a
        row of ``flight.data`` by flight number, origin and destination;
      * a self-driving or taxi leg must get a non-None 'cost' from
        ``googleDistanceMatrix.run_for_evaluation``;
      * meals, attractions and accommodation must match a row of the
        corresponding table by name (regex-escaped substring) and city.
    Missing, empty and '-' fields are skipped. The text after the last ';' of
    the attraction field is not inspected.

    Args:
        question: Query dict; reads 'days'.
        tested_data: List of per-day plan dicts.

    Returns:
        tuple: ``(True, None)`` on success, otherwise ``(False, reason)``.

    Raises:
        ValueError: If the endpoints of a flight leg cannot be parsed.
        IndexError: If a flight entry lacks the literal 'Flight Number: ' tag
            (the tag lookup is case-sensitive).
    """
    for i in range(min(question['days'], len(tested_data))):
        unit = tested_data[i]

        # Guarded like every other field so a missing key is skipped here.
        if 'transportation' in unit and unit['transportation'] and unit['transportation'] != '-':
            value = unit['transportation']
            lowered = value.lower()

            org_city, dest_city = extract_from_to(value)
            if org_city is None or dest_city is None:
                # Fall back to the endpoints written in the day's current_city.
                org_city, dest_city = extract_from_to(unit['current_city'])

            if 'flight number' in lowered:
                try:
                    org_city = extract_before_parenthesis(org_city)
                    dest_city = extract_before_parenthesis(dest_city)
                except TypeError as err:
                    raise ValueError("The transportation {} in day {} can not be parsed.".format(value, i+1)) from err

                flight_number = value.split('Flight Number: ')[1].split(',')[0]
                flights = flight.data
                found = flights[
                    (flights['Flight Number'] == flight_number)
                    & (flights['OriginCityName'] == org_city)
                    & (flights['DestCityName'] == dest_city)
                ]
                if len(found) < 1:
                    return False, f"The flight number in day {i+1} is invalid in the sandbox."

            elif 'self-driving' in lowered or 'taxi' in lowered:
                try:
                    org_city = extract_before_parenthesis(org_city)
                    dest_city = extract_before_parenthesis(dest_city)
                except TypeError:
                    # Best effort: query with placeholders instead of aborting.
                    org_city = '-'
                    dest_city = '-'
                    print("The transportation {} in day {} can not be parsed and '-' will be used instead.".format(value, i+1))

                mode = 'self-driving' if 'self-driving' in lowered else 'taxi'
                if googleDistanceMatrix.run_for_evaluation(org_city, dest_city, mode=mode)['cost'] is None:
                    return False, f"The {mode} in day {i+1} is invalid in the sandbox."

        for meal in ('breakfast', 'lunch', 'dinner'):
            if meal in unit and unit[meal] and unit[meal] != '-':
                if not _sandbox_has_entry(restaurants.data, 'Name', 'City', unit[meal]):
                    return False, f"The {meal} in day {i+1} is invalid in the sandbox."

        if 'attraction' in unit and unit['attraction'] and unit['attraction'] != '-':
            for attraction in unit['attraction'].split(';')[:-1]:
                if not _sandbox_has_entry(attractions.data, 'Name', 'City', attraction):
                    return False, f"The attraction {attraction} in day {i+1} is invalid in the sandbox."

        if 'accommodation' in unit and unit['accommodation'] and unit['accommodation'] != '-':
            if not _sandbox_has_entry(accommodation.data, 'NAME', 'city', unit['accommodation']):
                return False, f"The accommodation in day {i+1} is invalid in the sandbox."

    return True, None


def _sandbox_has_entry(frame, name_column, city_column, text):
    """Return True if ``frame`` has a row matching the name and city parsed from ``text``."""
    name, city = get_valid_name_city(text)
    # The name is escaped so it is matched literally, as a substring.
    matches = frame[
        (frame[name_column].astype(str).str.contains(re.escape(name)))
        & (frame[city_column] == city)
    ]
    return len(matches) >= 1
|
399 |
+
|
400 |
+
|
401 |
+
def is_valid_accommodaton(question, tested_data):
    """Check each stay against the accommodation's 'minimum nights' value.

    Consecutive days with the same accommodation string count as one stay.
    A stay is rejected only when exactly one row of ``accommodation.data``
    matches its name and city and the stay is shorter than that row's
    'minimum nights'. Stays recorded as '-' or '' are ignored.

    Args:
        question: Query dict; reads 'days'.
        tested_data: List of per-day plan dicts; every inspected day must
            carry an 'accommodation' key.

    Returns:
        tuple: ``(True, None)`` on success, otherwise ``(False, reason)``.
    """
    nightly = []
    for day_index in range(min(question['days'], len(tested_data))):
        unit = tested_data[day_index]
        if 'accommodation' not in unit:
            return False, "No Accommodation Info."
        nightly.append(unit['accommodation'])

    for stay in count_consecutive_values(nightly):
        if not stay or stay[0] in ['-', '']:
            continue

        name, city = get_valid_name_city(stay[0])
        listings = accommodation.data
        matched = listings[
            (listings['NAME'].astype(str).str.contains(re.escape(name)))
            & (listings['city'] == city)
        ]
        # The rule is applied only when the lookup is unambiguous.
        if len(matched) == 1 and stay[1] < matched.iloc[0]['minimum nights']:
            return False, f"The accommodation {stay[0]} do not obey the minumum nights rule."

    return True, None
|
425 |
+
|
426 |
+
def is_valid_visiting_city_number(question, tested_data):
    """Check that the plan visits the requested number of distinct cities.

    Cities are gathered from each inspected day's ``current_city`` (both
    endpoints on a "from A to B" day) and the origin city is excluded before
    counting.

    Args:
        question: Query dict; reads 'days', 'org' and 'visiting_city_number'.
        tested_data: List of per-day plan dicts; reads 'current_city'.

    Returns:
        tuple: ``(True, None)`` if the count matches, otherwise
        ``(False, reason)``.
    """
    visited = set()

    for day_index in range(min(question['days'], len(tested_data))):
        city_value = tested_data[day_index]['current_city']

        if 'from' not in city_value:
            visited.add(extract_before_parenthesis(city_value))
            continue

        departure, arrival = extract_from_to(city_value)
        departure = extract_before_parenthesis(departure)
        arrival = extract_before_parenthesis(arrival)
        # The trip has to start from the origin named in the query.
        if day_index == 0 and departure != question['org']:
            return False, f"The first day's city should be {question['org']}."
        visited.update((departure, arrival))

    # The origin is where the trip starts, not a visited city.
    visited.discard(question['org'])

    if len(visited) == question['visiting_city_number']:
        return True, None
    return False, f"The number of visiting cities should be {question['visiting_city_number']}."
|
453 |
+
|
454 |
+
def is_valid_days(question, tested_data):
    """Check that the plan has a filled-in entry for every requested day.

    A day counts as filled when its dict is not empty and its ``current_city``
    differs from the "no information needed" placeholder sentence.

    Args:
        question: Query dict; reads 'days'.
        tested_data: List of per-day plan dicts.

    Returns:
        tuple: ``(True, None)`` if the number of filled days equals
        ``question['days']``, otherwise ``(False, reason)``.
    """
    placeholder = "You don't need to fill in the information for this or later days."
    inspected = range(min(question['days'], len(tested_data)))
    filled = sum(
        1
        for idx in inspected
        if tested_data[idx] != {} and tested_data[idx]['current_city'] != placeholder
    )

    if filled == question['days']:
        return True, None
    return False, f"The number of days should be {question['days']}."
|
465 |
+
|
466 |
+
def is_not_absent(question, tested_data):
    """Check that the plan does not leave required information blank.

    The plan must first pass ``is_valid_days`` and
    ``is_valid_visiting_city_number``. Then, for each inspected day:
      * all six fields (transportation, breakfast, lunch, dinner, attraction,
        accommodation) must be present as keys;
      * a travel day ("from "/" to " in ``current_city``) needs transportation;
      * a non-travel day needs an attraction;
      * every day except the last needs accommodation;
      * a day without "from " in ``current_city`` needs all three meals.
    Finally, the share of non-empty values must be at least 50% of
    ``6 * question['days']``.

    Args:
        question: Query dict; reads 'days' (plus whatever the two
            prerequisite checks read).
        tested_data: List of per-day plan dicts.

    Returns:
        tuple: ``(True, None)`` on success, otherwise ``(False, reason)``.
    """
    needed_info = 6 * question['days']
    total_valid_info = 0
    blank = ['', '-']
    required_fields = (
        ('transportation', 'Transportation'),
        ('breakfast', 'Breakfast'),
        ('lunch', 'Lunch'),
        ('dinner', 'Dinner'),
        ('attraction', 'Attraction'),
        ('accommodation', 'Accommodation'),
    )

    if not is_valid_days(question, tested_data)[0]:
        return False, "Invalid Days"

    if not is_valid_visiting_city_number(question, tested_data)[0]:
        return False, "Invalid City Number"

    for i in range(min(question['days'], len(tested_data))):
        unit = tested_data[i]

        for field, label in required_fields:
            if field not in unit:
                return False, f"No {label} Info."

        current_city = unit['current_city']
        # ' to ' is matched with surrounding spaces in both tests, so a city
        # name that merely ends in "to" is not mistaken for a travel day.
        is_travel_day = 'from ' in current_city or ' to ' in current_city

        if is_travel_day and unit['transportation'] in blank:
            return False, f"No transportation in day {i+1} is not allowed."

        if not is_travel_day and unit['attraction'] in blank:
            return False, f"No attaction in day {i+1} is not allowed."

        if i != question['days'] - 1 and unit['accommodation'] in blank:
            return False, f"No accommodation in day {i+1} is not allowed."

        if (unit['breakfast'] in blank or unit['lunch'] in blank or unit['dinner'] in blank) and 'from ' not in current_city:
            return False, f"No meal in day {i+1} is not allowed."

        # NOTE(review): this counts every key of the day dict, not only the six
        # required fields, while needed_info assumes six per day.
        for key in unit:
            if unit[key] and unit[key] != '-':
                total_valid_info += 1

    if total_valid_info * 1.0 / needed_info < 0.5:
        return False, "The absent information is more than 50%."

    return True, None
|
519 |
+
|
520 |
+
|
521 |
+
def evaluation(query_data, tested_data):
    """Run every commonsense check and collect the individual verdicts.

    Args:
        query_data: Query dict handed unchanged to each check.
        tested_data: List of per-day plan dicts handed unchanged to each check.

    Returns:
        dict: Check name mapped to that check's ``(passed, reason)`` tuple, in
        the order the checks are run.
    """
    checks = (
        ('is_reasonalbe_visiting_city', is_reasonalbe_visiting_city),
        ('is_valid_restaurants', is_valid_restaurants),
        ('is_valid_attractions', is_valid_attractions),
        ('is_valid_accommodation', is_valid_accommodaton),
        ('is_valid_transportation', is_valid_transportation),
        ('is_valid_information_in_current_city', is_valid_information_in_current_city),
        ('is_valid_information_in_sandbox', is_valid_information_in_sandbox),
        ('is_not_absent', is_not_absent),
    )
    return {label: check(query_data, tested_data) for label, check in checks}
|
532 |
+
|
533 |
+
def boolean_evaluation(query_data, tested_data):
    """Tell whether a plan passes all commonsense checks.

    Every check is run first; the outcomes are then scanned in order and the
    message of the first failing check is printed.

    Args:
        query_data: The query record the plan answers.
        tested_data: The plan under test.

    Returns:
        bool: False as soon as one outcome's first element equals False,
        True otherwise.
    """
    checks = (
        ('is_reasonalbe_visiting_city', is_reasonalbe_visiting_city),
        ('is_valid_restaurants', is_valid_restaurants),
        ('is_valid_accommodation', is_valid_accommodaton),
        ('is_valid_attractions', is_valid_attractions),
        ('is_valid_transportation', is_valid_transportation),
        ('is_valid_information_in_current_city', is_valid_information_in_current_city),
        ('is_valid_information_in_sandbox', is_valid_information_in_sandbox),
        ('is_not_absent', is_not_absent),
    )
    outcomes = {label: check(query_data, tested_data) for label, check in checks}

    for outcome in outcomes.values():
        # Equality (not identity/truthiness) on purpose: a None status is not
        # a failure, while a numpy boolean False still is.
        if outcome[0] == False:
            print(outcome[1])
            return False
    return True
|
548 |
+
|
549 |
+
# if __name__ == '__main__':
|
550 |
+
# number_list = extract_numbers_from_filenames('/home/xj/toolAugEnv/code/toolConstraint/data/annotation/lrz')
|
551 |
+
# # json_data = json.load(open('/home/xj/toolAugEnv/code/toolConstraint/data/annotation/x/annotation_4.json'))
|
552 |
+
# query_data = load_line_json_data('/home/xj/toolAugEnv/code/toolConstraint/data/query/lrz.jsonl')
|
553 |
+
# for idx in number_list:
|
554 |
+
# json_data = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/data/annotation/lrz/annotation_{idx}.json'))
|
555 |
+
# print(str(idx), evaluation(query_data[idx-1], json_data))
|
556 |
+
# # json_data = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/results/turbo16k-turbo16k/plan_{idx}.json'))
|
557 |
+
# # query_data = load_line_json_data('/home/xj/toolAugEnv/code/toolConstraint/data/query/test.jsonl')[idx-1]
|
558 |
+
# # help me write all function name in this file, just the name
|
559 |
+
# #
|
560 |
+
# # list all function name in this file
|
561 |
+
# # ['is_reasonalbe_visiting_city', 'is_valiable_restaurants', 'is_valiable_attractions', 'is_valiable_transportation', 'is_valid_information_in_current_city', 'is_valid_information_in_sandbox']
|
562 |
+
# # print(is_valiable_restaurants(query_data, json_data))
|
563 |
+
|
564 |
+
# if __name__ == "__main__":
|
565 |
+
# user = 'zk'
|
566 |
+
# query_data_list = load_line_json_data(f'/home/xj/toolAugEnv/code/toolConstraint/data/query/{user}.jsonl')
|
567 |
+
# idx_number_list = extract_numbers_from_filenames(f'/home/xj/toolAugEnv/code/toolConstraint/data/annotation/{user}')
|
568 |
+
# commonsense_statistic= {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
|
569 |
+
# for idx in idx_number_list:
|
570 |
+
# print(idx)
|
571 |
+
# query_data = query_data_list[idx-1]
|
572 |
+
# generated_plan = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/results/turbo16k-turbo16k/{user}/plan_{idx}.json'))
|
573 |
+
# # generated_plan = generated_plan[:-1]
|
574 |
+
# if generated_plan[-1]['gpt-3.5-turbo-16k-result'] != 'Plan Fail':
|
575 |
+
# info_box = evaluation(query_data, generated_plan[-1]['gpt-3.5-turbo-16k-result'])
|
576 |
+
# generated_plan[-1]['toolAug-commonsense'] = info_box
|
577 |
+
# else:
|
578 |
+
# generated_plan[-1]['toolAug-commonsense'] = None
|
579 |
+
# info_box = None
|
580 |
+
# commonsense_statistic[query_data['level']][query_data['days']].append(info_box)
|
581 |
+
# with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/turbo16k-turbo16k/{user}/plan_{idx}.json','w') as f:
|
582 |
+
# json.dump(generated_plan,f)
|
583 |
+
|
584 |
+
# with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/turbo16k-turbo16k/{user}/commonsense_statistic.json','w') as f:
|
585 |
+
# json.dump(commonsense_statistic,f)
|
586 |
+
|
587 |
+
# if __name__ == "__main__":
|
588 |
+
# user = 'all'
|
589 |
+
# model_type = ['chatgpt','gpt4','greedy_search'][2]
|
590 |
+
# query_data_list = load_line_json_data(f'/home/xj/toolAugEnv/code/toolConstraint/data/query/{user}.jsonl')
|
591 |
+
# # idx_number_list = extract_numbers_from_filenames(f'/home/xj/toolAugEnv/code/toolConstraint/data/annotation/{user}')
|
592 |
+
# idx_number_list = [i for i in range(1,501)]
|
593 |
+
# commonsense_statistic= {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
|
594 |
+
|
595 |
+
# for idx in idx_number_list:
|
596 |
+
# print(idx)
|
597 |
+
# query_data = query_data_list[idx-1]
|
598 |
+
# generated_plan = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/results/pre2/{user}/plan_{idx}.json'))
|
599 |
+
# # generated_plan = generated_plan[:-1]
|
600 |
+
# if model_type == 'greedy_search':
|
601 |
+
# info_box = evaluation(query_data, generated_plan[-1][f'greedy_search_plan'])
|
602 |
+
# else:
|
603 |
+
# info_box = evaluation(query_data, generated_plan[-1][f'{model_type}_human_collected_info_results_parsed'])
|
604 |
+
# generated_plan[-1][f'{model_type}_with_human_collected_commonsense'] = info_box
|
605 |
+
# commonsense_statistic[query_data['level']][query_data['days']].append(info_box)
|
606 |
+
|
607 |
+
# with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/pre2/{user}/plan_{idx}.json','w') as f:
|
608 |
+
# json.dump(generated_plan,f)
|
609 |
+
|
610 |
+
# with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/pre2/{user}/{model_type}_with_human_collected_commonsense_statistic.json','w') as f:
|
611 |
+
# json.dump(commonsense_statistic,f)
|
612 |
+
|
613 |
+
|
614 |
+
# if __name__ == "__main__":
|
615 |
+
# user = 'all'
|
616 |
+
# query_data_list = load_line_json_data(f'/home/xj/toolAugEnv/code/toolConstraint/data/query/{user}.jsonl')
|
617 |
+
# idx_number_list = extract_numbers_from_filenames(f'/home/xj/toolAugEnv/code/toolConstraint/data/annotation/{user}')
|
618 |
+
# hardConstraint_statistic= {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
|
619 |
+
# not_satified = []
|
620 |
+
# for idx in tqdm(idx_number_list):
|
621 |
+
# # print(idx)
|
622 |
+
# query_data = query_data_list[idx-1]
|
623 |
+
# generated_plan = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/data/annotation/{user}/annotation_{idx}.json'))
|
624 |
+
|
625 |
+
# if not boolean_evaluation(query_data, generated_plan):
|
626 |
+
# not_satified.append(idx)
|
627 |
+
# print(idx)
|
628 |
+
# generated_plan = generated_plan[:-1]
|
629 |
+
# print(not_satified)
|
630 |
+
|
631 |
+
if __name__ == "__main__":
    # Offline driver: run the commonsense checks on the stored plan of every
    # query in one split, attach the outcome to the per-plan result file and
    # dump the outcomes grouped by level and trip length.
    set_type = ["train", 'dev', 'test'][0]
    base_dir = '/home/xj/toolAugEnv/code/toolConstraint'
    data_dir = f'{base_dir}/data/final_data/{set_type}'
    results_dir = f'{base_dir}/results/{set_type}'

    query_data_list = load_line_json_data(f'{data_dir}/query/query.jsonl')
    commonsense_statistic = {level: {day: [] for day in [3, 5, 7]} for level in ['easy', 'medium', 'hard']}
    # Never filled in this driver; still printed below so the output is unchanged.
    not_satified = []

    for idx in tqdm(range(1, len(query_data_list) + 1)):
        query_data = query_data_list[idx - 1]
        result_path = f'{results_dir}/plan_{idx}.json'

        # Context managers close each handle right away instead of leaking
        # one open file per iteration.
        with open(f'{data_dir}/plan/plan_{idx}.json') as f:
            generated_plan = json.load(f)
        try:
            with open(result_path) as f:
                store_plan = json.load(f)
        except FileNotFoundError:
            # No earlier result for this plan: start a fresh record list.
            store_plan = [{}]

        # Element 1 of the plan file is what gets evaluated
        # (presumably the annotated plan itself -- confirm against the data files).
        info_box = evaluation(query_data, generated_plan[1])
        store_plan[-1]['human_anno_commonsense_constraint'] = info_box
        with open(result_path, 'w') as f:
            json.dump(store_plan, f)
        commonsense_statistic[query_data['level']][query_data['days']].append(info_box)

    print(not_satified)
    with open(f'{results_dir}/human_anno_commonsense_constraint.json', 'w') as f:
        json.dump(commonsense_statistic, f)
|
658 |
+
|
659 |
+
# if __name__ == "__main__":
|
660 |
+
# user = 'all'
|
661 |
+
# model_type = ['chatgpt','gpt4'][1]
|
662 |
+
# query_data_list = load_line_json_data(f'/home/xj/toolAugEnv/code/toolConstraint/data/query/{user}.jsonl')
|
663 |
+
# # idx_number_list = extract_numbers_from_filenames(f'/home/xj/toolAugEnv/code/toolConstraint/data/annotation/{user}')
|
664 |
+
# idx_number_list = [i for i in range(1,501)]
|
665 |
+
# commonsense_statistic= {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
|
666 |
+
# cnt = 0
|
667 |
+
# for idx in idx_number_list:
|
668 |
+
# # print(idx)
|
669 |
+
# query_data = query_data_list[idx-1]
|
670 |
+
# generated_plan = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/results/pre/{user}/plan_{idx}.json'))[-1]['gpt4_human_collected_info_results_parsed']
|
671 |
+
# # generated_plan = generated_plan[:-1]
|
672 |
+
|
673 |
+
# if not boolean_evaluation(query_data, generated_plan):
|
674 |
+
# cnt += 1
|
675 |
+
# print(idx)
|
676 |
+
# print(cnt)
|
677 |
+
|
678 |
+
# if __name__ == "__main__":
|
679 |
+
# parser = argparse.ArgumentParser(description="")
|
680 |
+
# # model_type = ['gpt-3.5-turbo-1106','gpt-4-1106-preview','greedy_search','mistral-7B-32K','gemini2','mixtral','gpt-3.5-turbo-11062'][-1]
|
681 |
+
# # method = ['direct','cot','react','reflexion','tool-use'][-1]
|
682 |
+
# # set_type = ['dev','test'][0]
|
683 |
+
# parser.add_argument("--model_type", type=str, default="gpt-3.5-turbo-1106")
|
684 |
+
# parser.add_argument("--method", type=str, default="direct")
|
685 |
+
# parser.add_argument("--set_type", type=str, default="dev")
|
686 |
+
# args = parser.parse_args()
|
687 |
+
# directory = f'/home/xj/toolAugEnv/code/toolConstraint/data/final_data/{args.set_type}'
|
688 |
+
# query_data_list = load_line_json_data(os.path.join(directory, 'query/query.jsonl'))
|
689 |
+
# # idx_number_list = extract_numbers_from_filenames(f'/home/xj/toolAugEnv/code/toolConstraint/data/annotation/{user}')
|
690 |
+
# idx_number_list = [i for i in range(1,len(query_data_list)+1)]
|
691 |
+
# commonsense_statistic= {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
|
692 |
+
# deliver_cnt = 0
|
693 |
+
# if args.method == 'tool-use':
|
694 |
+
# suffix = ''
|
695 |
+
# else:
|
696 |
+
# suffix = '_with_human_info'
|
697 |
+
# for idx in tqdm(idx_number_list):
|
698 |
+
# # print(idx)
|
699 |
+
# query_data = query_data_list[idx-1]
|
700 |
+
# generated_plan = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/results/{args.set_type}/plan_{idx}.json'))
|
701 |
+
# # generated_plan = generated_plan[:-1]
|
702 |
+
# if args.model_type == 'greedy_search':
|
703 |
+
# info_box = evaluation(query_data, generated_plan[-1][f'greedy_search_plan'])
|
704 |
+
# else:
|
705 |
+
# if args.method == 'tool-use':
|
706 |
+
# suffix2 = ''
|
707 |
+
# else:
|
708 |
+
# suffix2 = '_collected'
|
709 |
+
# if generated_plan[-1][f'{args.model_type}_{args.method}{suffix2}_info_results'] and generated_plan[-1][f'{args.model_type}_{args.method}{suffix2}_info_results']!='Max Token Length Exceeded.':
|
710 |
+
# try:
|
711 |
+
# info_box = evaluation(query_data, generated_plan[-1][f'{args.model_type}_{args.method}{suffix}_results_parsed'])
|
712 |
+
# except KeyError:
|
713 |
+
# info_box = None
|
714 |
+
# generated_plan[-1][f'{args.model_type}_{args.method}{suffix2}_info_results'] = ""
|
715 |
+
# except IndexError:
|
716 |
+
# info_box = None
|
717 |
+
# generated_plan[-1][f'{args.model_type}_{args.method}{suffix2}_info_results'] = ""
|
718 |
+
# else:
|
719 |
+
# info_box = None
|
720 |
+
# if info_box:
|
721 |
+
# deliver_cnt += 1
|
722 |
+
# generated_plan[-1][f'{args.model_type}_{args.method}{suffix}_commonsense_constraint'] = info_box
|
723 |
+
# commonsense_statistic[query_data['level']][query_data['days']].append(info_box)
|
724 |
+
|
725 |
+
# with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/{args.set_type}/plan_{idx}.json','w') as f:
|
726 |
+
# json.dump(generated_plan,f)
|
727 |
+
|
728 |
+
# with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/{args.set_type}/{args.model_type}_{args.method}{suffix}_commonsense_constraint.json','w') as f:
|
729 |
+
# json.dump(commonsense_statistic,f)
|
730 |
+
|
731 |
+
# if args.set_type == 'dev':
|
732 |
+
# print(f"Model:{args.model_type} Method:{args.method} Set: {args.set_type} \nDeliver Rate: {deliver_cnt/180}" )
|
733 |
+
# elif args.set_type == 'test':
|
734 |
+
# print(f"Model:{args.model_type} Method:{args.method} Set: {args.set_type} \nDeliver Rate: {deliver_cnt/1000}" )
|
735 |
+
|
eval.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from commonsenseConstraint import evaluation as commonsense_eval
|
2 |
+
from hardConstraint import evaluation as hard_eval
|
3 |
+
import json
|
4 |
+
from tqdm import tqdm
|
5 |
+
from datasets import load_dataset
|
6 |
+
|
7 |
+
|
8 |
+
def load_line_json_data(filename):
    """Load a JSON-lines file into a list of decoded records.

    Leading and trailing whitespace of the file is ignored, so a file that is
    empty (or holds only blank lines) yields an empty list instead of failing
    on an empty record.

    Args:
        filename: Path of a UTF-8 text file with one JSON document per line.

    Returns:
        list: The decoded records, in file order.

    Raises:
        json.JSONDecodeError: If a line is not valid JSON. This includes a
            blank line between records, which is deliberately not skipped so
            that records cannot silently shift position.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        content = f.read().strip()
    if not content:
        return []
    return [json.loads(line) for line in content.split('\n')]
|
15 |
+
|
16 |
+
def count_true_false(data):
    """Return ``(n_true, n_false)``: how many items of *data* equal True / False.

    *data* may be any sequence offering ``count`` (a list or a tuple).
    """
    return data.count(True), data.count(False)
|
21 |
+
|
22 |
+
def statistics(commonsense_statistic):
    """Aggregate per-plan check outcomes into true/false tallies.

    Args:
        commonsense_statistic: Mapping level -> day -> list of per-plan
            outcome dicts (check name -> outcome); falsy entries are ignored.

    Returns:
        dict: level -> day -> check name -> ``{"true": int, "false": int}``.
        Every level/day of the input is present, possibly with no checks.
    """
    summary = {}
    for level, per_day in commonsense_statistic.items():
        summary[level] = {day: {} for day in per_day}

    for level, per_day in commonsense_statistic.items():
        for day, boxes in per_day.items():
            bucket = summary[level][day]
            for box in boxes:
                if not box:
                    # No outcome recorded for this plan.
                    continue
                for name, outcome in box.items():
                    n_true, n_false = count_true_false(outcome)
                    tally = bucket.setdefault(name, {"true": 0, "false": 0})
                    tally["true"] += n_true
                    tally["false"] += n_false

    return summary
|
38 |
+
|
39 |
+
|
40 |
+
def eval_score(validation_or_test: str, file_path: str, TOKEN):
    """Score a submission file against the queries of one TravelBenchEval split.

    Args:
        validation_or_test: Split to evaluate, ``'validation'`` or ``'test'``.
        file_path: JSON-lines file with one submitted record per line, in the
            same order as the split's queries. Each record is read through
            its ``'plan'`` entry; a falsy plan counts as not delivered.
        TOKEN: Access token forwarded to ``load_dataset``.

    Returns:
        dict: ``'Delivery Rate'``, micro and macro pass rates for commonsense
        and hard constraints, and ``'Final Pass Rate'``.

    Raises:
        ValueError: If ``validation_or_test`` names an unsupported split.
    """
    # Fixed denominators per split. 'commonsense' equals plans * 8 (the number
    # of commonsense checks); 'hard' is presumably the number of applicable
    # hard constraints in the split -- confirm against the dataset.
    split_totals = {
        'validation': {'plans': 180, 'commonsense': 1440, 'hard': 420},
        'test': {'plans': 1000, 'commonsense': 8000, 'hard': 2290},
    }
    if validation_or_test not in split_totals:
        raise ValueError(
            f"validation_or_test must be 'validation' or 'test', got {validation_or_test!r}"
        )
    totals = split_totals[validation_or_test]

    query_data_list = list(
        load_dataset('osunlp/TravelBenchEval', validation_or_test, token=TOKEN)[validation_or_test]
    )
    levels = ['easy', 'medium', 'hard']
    days = [3, 5, 7]
    hardConstraint_statistic = {level: {day: [] for day in days} for level in levels}
    commonsenseConstraint_statistic = {level: {day: [] for day in days} for level in levels}
    tested_plans = load_line_json_data(file_path)
    delivery_cnt = 0
    plan_constraint_store = []

    for idx in tqdm(range(0, len(query_data_list))):
        query_data = query_data_list[idx]
        tested_plan = tested_plans[idx]
        # SECURITY: eval() runs arbitrary code, and tested_plan comes from the
        # submitted file. If these strings are always Python literals,
        # ast.literal_eval is the safe replacement; flagged here rather than
        # changed because the accepted input formats are not visible.
        if isinstance(query_data, str):
            query_data = eval(query_data)
            # Keep the parsed form so later reads of the list see a dict.
            query_data_list[idx] = query_data
        if isinstance(tested_plan, str):
            tested_plan = eval(tested_plan)
        if isinstance(query_data['local_constraint'], str):
            query_data['local_constraint'] = eval(query_data['local_constraint'])

        if tested_plan['plan']:
            delivery_cnt += 1
            commonsense_info_box = commonsense_eval(query_data, tested_plan['plan'])
        else:
            commonsense_info_box = None

        # Hard constraints are only checked for plans that are complete and
        # use information found in the sandbox.
        if (commonsense_info_box
                and commonsense_info_box['is_not_absent'][0]
                and commonsense_info_box['is_valid_information_in_sandbox'][0]):
            hard_info_box = hard_eval(query_data, tested_plan['plan'])
        else:
            hard_info_box = None

        plan_constraint_store.append(
            {'commonsense_constraint': commonsense_info_box, 'hard_constraint': hard_info_box}
        )
        commonsenseConstraint_statistic[query_data['level']][query_data['days']].append(commonsense_info_box)
        hardConstraint_statistic[query_data['level']][query_data['days']].append(hard_info_box)

    processed_statistics = {
        'commonsense': statistics(commonsenseConstraint_statistic),
        'hard': statistics(hardConstraint_statistic),
    }
    key_dict = {
        'commonsense': ['is_valid_information_in_current_city', 'is_valid_information_in_sandbox',
                        'is_reasonalbe_visiting_city', 'is_valid_restaurants', 'is_valid_transportation',
                        'is_valid_attractions', 'is_valid_accommodation', 'is_not_absent'],
        'hard': ['valid_cost', 'valid_room_rule', 'valid_cuisine', 'valid_room_type', 'valid_transportation'],
    }

    # Micro level: number of individual checks that passed, over all plans.
    micro_pass = {'commonsense': 0, 'hard': 0}
    for constraint, constraint_statistic in processed_statistics.items():
        for per_level in constraint_statistic.values():
            for per_day in per_level.values():
                for check_name in key_dict[constraint]:
                    if check_name in per_day:
                        micro_pass[constraint] += per_day[check_name]['true']

    # Macro level: number of plans passing every check of a kind.
    final_all_cnt = 0
    final_commonsense_cnt = 0
    final_hardConstraint_cnt = 0
    for record in plan_constraint_store:
        commonsense_box = record['commonsense_constraint']
        hard_box = record['hard_constraint']
        # A plan without a hard-constraint result is not counted in any macro
        # figure, exactly as an undelivered plan.
        if not commonsense_box or hard_box is None:
            continue

        # A None status means "not applicable" and never fails a plan.
        final_commonsense_pass = all(
            outcome[0] is None or outcome[0] for outcome in commonsense_box.values()
        )
        # Equality comparison kept so numpy booleans are handled like bools.
        final_hardConstraint_pass = not any(
            outcome[0] is not None and outcome[0] == False for outcome in hard_box.values()
        )

        if final_commonsense_pass:
            final_commonsense_cnt += 1
        if final_hardConstraint_pass:
            final_hardConstraint_cnt += 1
        if final_commonsense_pass and final_hardConstraint_pass:
            final_all_cnt += 1

    return {
        'Delivery Rate': delivery_cnt / totals['plans'],
        'Commonsense Constraint Micro Pass Rate': micro_pass['commonsense'] / totals['commonsense'],
        'Commonsense Constraint Macro Pass Rate': final_commonsense_cnt / totals['plans'],
        'Hard Constraint Micro Pass Rate': micro_pass['hard'] / totals['hard'],
        'Hard Constraint Macro Pass Rate': final_hardConstraint_cnt / totals['plans'],
        'Final Pass Rate': final_all_cnt / totals['plans'],
    }
|
181 |
+
|
hardConstraint.py
ADDED
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from annotation.src.utils import get_valid_name_city,extract_before_parenthesis,extract_numbers_from_filenames
|
2 |
+
from tools.flights.apis import Flights
|
3 |
+
from tools.accommodations.apis import Accommodations
|
4 |
+
from tools.restaurants.apis import Restaurants
|
5 |
+
from tools.googleDistanceMatrix.apis import GoogleDistanceMatrix
|
6 |
+
from tools.attractions.apis import Attractions
|
7 |
+
import math
|
8 |
+
import json
|
9 |
+
import re
|
10 |
+
import numpy as np
|
11 |
+
import os
|
12 |
+
import sys
|
13 |
+
from tqdm import tqdm
|
14 |
+
import argparse
|
15 |
+
|
16 |
+
# Add the parent of the current working directory to the import search path.
# NOTE(review): the project imports at the top of this file (annotation.*, tools.*)
# run before this line, so it cannot help resolve them -- confirm the intended order.
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
|
17 |
+
# Import-time side effect: switch the working directory to this file's folder.
os.chdir(os.path.dirname(os.path.abspath(__file__)))
|
18 |
+
|
19 |
+
|
20 |
+
# Module-level handles on the project's data tools, created once at import time.
# The functions below filter flight.data, restaurants.data and accommodation.data
# with boolean masks and ask googleDistanceMatrix.run_for_evaluation() for
# self-driving / taxi costs. `attractions` is not referenced in the part of the
# file reviewed here.
flight = Flights()
|
21 |
+
accommodation = Accommodations()
|
22 |
+
restaurants = Restaurants()
|
23 |
+
googleDistanceMatrix = GoogleDistanceMatrix()
|
24 |
+
attractions = Attractions()
|
25 |
+
|
26 |
+
|
27 |
+
def load_line_json_data(filename):
    """Load a JSON-lines file into a list of decoded records.

    Leading and trailing whitespace of the file is ignored, so a file that is
    empty (or holds only blank lines) yields an empty list instead of failing
    on an empty record.

    Args:
        filename: Path of a UTF-8 text file with one JSON document per line.

    Returns:
        list: The decoded records, in file order.

    Raises:
        json.JSONDecodeError: If a line is not valid JSON. This includes a
            blank line between records, which is deliberately not skipped so
            that records cannot silently shift position.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        content = f.read().strip()
    if not content:
        return []
    return [json.loads(line) for line in content.split('\n')]
|
34 |
+
|
35 |
+
|
36 |
+
def convert_bool_values(item):
    """Return a copy of *item* in which every numpy boolean is a Python bool.

    Dicts, lists and tuples are walked recursively (values only, for dicts)
    and rebuilt as plain ``dict`` / ``list`` / ``tuple``; anything else is
    returned unchanged.
    """
    if isinstance(item, np.bool_):
        return bool(item)
    if isinstance(item, dict):
        return {name: convert_bool_values(entry) for name, entry in item.items()}
    if isinstance(item, list):
        return list(map(convert_bool_values, item))
    if isinstance(item, tuple):
        return tuple(map(convert_bool_values, item))
    return item
|
52 |
+
|
53 |
+
|
54 |
+
|
55 |
+
|
56 |
+
def extract_from_to(text: str):
    """
    Pull the origin and destination out of a "from A to B" phrase.

    B runs up to the next comma or the end of the string.

    Args:
    - text (str): The string to search.

    Returns:
    - tuple: ('A', 'B') for the first match, or (None, None) when the
      phrase does not occur.
    """
    found = re.search(r"from\s+(.+?)\s+to\s+([^,]+)(?=[,\s]|$)", text)
    if found is None:
        return (None, None)
    return found.groups()
|
69 |
+
|
70 |
+
|
71 |
+
def get_total_cost(question, tested_data):
    """Add up the cost of a plan: transportation, meals and accommodation.

    Only the first ``question['days']`` entries of the plan are considered.
    Items that cannot be found in the tool data add nothing to the total.

    Args:
        question: Query record; reads ``'days'`` and ``'people_number'``.
        tested_data: List of per-day plan dicts with the keys
            ``'transportation'``, ``'current_city'``, ``'breakfast'``,
            ``'lunch'``, ``'dinner'`` and ``'accommodation'``; an empty value
            or ``'-'`` means "nothing planned".

    Returns:
        The summed cost (``0`` when nothing could be priced).
    """
    people_number = question['people_number']
    total_cost = 0
    for i in range(min(question['days'], len(tested_data))):
        unit = tested_data[i]

        # transportation
        value = unit['transportation']
        if value and value != '-':
            org_city, dest_city = extract_from_to(value)
            if org_city is None or dest_city is None:
                # Fall back to the route written in the day's city field.
                org_city, dest_city = extract_from_to(unit['current_city'])

            if org_city is not None and dest_city is not None:
                mode = value.lower()
                if 'flight number' in mode:
                    # Same case-insensitive match as the test above, so a
                    # differently-cased label no longer raises IndexError.
                    found = re.search(r'flight number: ([^,]*)', value, re.IGNORECASE)
                    if found:
                        res = flight.data[flight.data['Flight Number'] == found.group(1)]
                        if len(res) > 0:
                            # Flights are priced per traveller.
                            total_cost += res['Price'].values[0] * people_number
                elif 'self-driving' in mode:
                    cost = googleDistanceMatrix.run_for_evaluation(org_city, dest_city, 'self-driving')['cost']
                    # One car per group of up to five people.
                    total_cost += cost * math.ceil(people_number * 1.0 / 5)
                elif 'taxi' in mode:
                    cost = googleDistanceMatrix.run_for_evaluation(org_city, dest_city, 'taxi')['cost']
                    # One taxi per group of up to four people.
                    total_cost += cost * math.ceil(people_number * 1.0 / 4)

        # breakfast, lunch, dinner: priced per traveller
        for meal in ('breakfast', 'lunch', 'dinner'):
            entry = unit[meal]
            if entry and entry != '-':
                name, city = get_valid_name_city(entry)
                res = restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]
                if len(res) > 0:
                    total_cost += res['Average Cost'].values[0] * people_number

        # accommodation: as many rooms as needed for the maximum occupancy
        if unit['accommodation'] and unit['accommodation'] != '-':
            name, city = get_valid_name_city(unit['accommodation'])
            res = accommodation.data[(accommodation.data['NAME'].astype(str).str.contains(re.escape(name))) & (accommodation.data['city'] == city)]
            if len(res) > 0:
                total_cost += res['price'].values[0] * math.ceil(people_number * 1.0 / res['maximum occupancy'].values[0])

    return total_cost
|
129 |
+
|
130 |
+
|
131 |
+
def is_valid_room_rule(question, tested_data):
    """Check the planned accommodations against the requested house rule.

    Args:
        question: Query record; reads ``question['local_constraint']['house rule']``
            and ``question['days']``.
        tested_data: List of per-day plan dicts; reads each day's
            ``'accommodation'`` (empty or ``'-'`` means none planned).

    Returns:
        tuple: ``(None, None)`` when the query sets no house rule;
        ``(False, message)`` when a planned accommodation found in the data
        lists a rule forbidding what the query asks for;
        ``(True, None)`` otherwise.
    """
    house_rule = question['local_constraint']['house rule']
    if house_rule is None:
        return None, None

    # Phrase in a listing's house rules that contradicts each requested rule.
    # The check used to compare against the misspelling 'parities' only, so a
    # 'parties' requirement could never fail; the misspelling stays as an alias.
    forbidden_phrases = {
        'smoking': 'No smoking',
        'parties': 'No parties',
        'parities': 'No parties',
        'children under 10': 'No children under 10',
        'visitors': 'No visitors',
        'pets': 'No pets',
    }
    forbidden = forbidden_phrases.get(house_rule)

    for i in range(min(question['days'], len(tested_data))):
        unit = tested_data[i]
        if unit['accommodation'] and unit['accommodation'] != '-':
            name, city = get_valid_name_city(unit['accommodation'])
            res = accommodation.data[(accommodation.data['NAME'].astype(str).str.contains(re.escape(name))) & (accommodation.data['city'] == city)]
            if len(res) > 0 and forbidden is not None:
                if forbidden in str(res['house_rules'].values[0]):
                    return False, f"The house rule should be {house_rule}."

    return True, None
|
155 |
+
|
156 |
+
|
157 |
+
|
158 |
+
def is_valid_cuisine(question, tested_data):
    """Check that every requested cuisine is served by some planned restaurant.

    Args:
        question: Query dict. Reads ``question['local_constraint']['cuisine']``
            (an iterable of cuisine names, or a falsy value for "no
            constraint"), ``question['days']`` and ``question['org']``.
        tested_data: Per-day plan units (indexable); each unit is read through
            its ``'breakfast'``, ``'lunch'`` and ``'dinner'`` keys.

    Returns:
        ``(None, None)`` when there is no cuisine constraint,
        ``(True, None)`` when every requested cuisine was found, otherwise
        ``(False, message)`` naming the first requested cuisine (in request
        order) that was not found.
    """
    required = question['local_constraint']['cuisine']
    if not required:
        return None, None

    found = set()
    # Only the first `days` entries of the plan are inspected.
    for i in range(min(question['days'], len(tested_data))):
        unit = tested_data[i]
        # Each meal is judged on its own: a meal eaten in the origin city is
        # ignored, but it no longer causes the remaining meals of that day to
        # be skipped (the original `continue` jumped to the next day).
        for meal_key in ('breakfast', 'lunch', 'dinner'):
            found |= _requested_cuisines_for_meal(unit[meal_key], question)

    # Checking membership per requested cuisine (instead of comparing lengths)
    # keeps the result a tuple even when the request lists a cuisine twice.
    for cuisine in required:
        if cuisine not in found:
            return False, f"The cuisine {cuisine} is not satisfied."
    return True, None


def _requested_cuisines_for_meal(meal, question):
    """Return the requested cuisines offered by the restaurant of one meal."""
    if not meal or meal == '-':
        return set()

    # presumably splits "name, city" -- helper is defined elsewhere in the file
    name, city = get_valid_name_city(meal)
    # Restaurants in the origin city do not count toward the constraint.
    if city == question['org']:
        return set()

    table = restaurants.data
    # re.escape: the name is matched as a literal substring, not a regex.
    res = table[
        (table['Name'].astype(str).str.contains(re.escape(name)))
        & (table['City'] == city)
    ]
    if len(res) == 0:
        return set()

    # str() guards against a non-string cell (e.g. NaN), which would make the
    # substring test raise TypeError.
    served = str(res.iloc[0]['Cuisines'])
    return {
        cuisine
        for cuisine in question['local_constraint']['cuisine']
        if cuisine in served
    }
|
204 |
+
|
205 |
+
|
206 |
+
def is_valid_transportation(question, tested_data):
    """Check that the plan never uses a transportation mode the query forbids.

    Args:
        question: Query dict. Reads
            ``question['local_constraint']['transportation']`` (``None``,
            ``'no flight'`` or ``'no self-driving'``) and ``question['days']``.
        tested_data: Per-day plan units (indexable); each unit is read through
            its ``'transportation'`` key.

    Returns:
        ``(None, None)`` when there is no transportation constraint,
        ``(False, message)`` for the first day whose transportation text
        contains the forbidden mode, and ``(True, None)`` otherwise.
    """
    constraint = question['local_constraint']['transportation']
    if constraint is None:
        return None, None

    # Constraint value -> marker whose presence in the day's transportation
    # text means the constraint is violated.
    forbidden_marker = {
        'no flight': 'Flight',
        'no self-driving': 'Self-driving',
    }.get(constraint)

    # Only the first `days` entries of the plan are inspected.
    for i in range(min(question['days'], len(tested_data))):
        value = tested_data[i]['transportation']
        if not value or value == '-':
            continue
        if forbidden_marker is not None and forbidden_marker in value:
            # The constraint already carries its own negation ("no flight"),
            # so the message says "should be", matching the sibling
            # validators, instead of the double negative
            # "should not be no flight".
            return False, f"The transportation should be {constraint}."

    return True, None
|
219 |
+
|
220 |
+
|
221 |
+
def is_valid_room_type(question, tested_data):
    """Check every booked accommodation against the requested room type.

    Reads ``question['local_constraint']['room type']`` and
    ``question['days']``; each plan unit is read through its
    ``'accommodation'`` key.

    Returns ``(None, None)`` when no room type is requested,
    ``(False, message)`` for the first accommodation whose ``'room type'``
    column conflicts with the request, and ``(True, None)`` otherwise.
    Accommodations not found in ``accommodation.data`` are skipped.
    """
    wanted = question['local_constraint']['room type']
    if wanted is None:
        return None, None

    # Requests that demand one exact label in the data's 'room type' column.
    # "not shared room" is the only negative request and is handled apart.
    exact_label = {
        'shared room': 'Shared room',
        'private room': 'Private room',
        'entire room': 'Entire home/apt',
    }

    day_count = min(question['days'], len(tested_data))
    for day in range(day_count):
        stay = tested_data[day]['accommodation']
        if not stay or stay == '-':
            continue

        # presumably splits "name, city" -- helper is defined elsewhere in the file
        name, city = get_valid_name_city(stay)
        catalog = accommodation.data
        # re.escape: the name is matched as a literal substring, not a regex.
        name_hit = catalog['NAME'].astype(str).str.contains(re.escape(name))
        city_hit = catalog['city'] == city
        res = catalog[name_hit & city_hit]
        if len(res) == 0:
            continue

        if wanted == 'not shared room':
            violated = res['room type'].values[0] == 'Shared room'
        elif wanted in exact_label:
            violated = res['room type'].values[0] != exact_label[wanted]
        else:
            # Unrecognised requests are never reported as violations.
            violated = False

        if violated:
            return False, f"The room type should be {wanted}."

    return True, None
|
243 |
+
|
244 |
+
|
245 |
+
def evaluation(query_data, tested_data):
    """Run every hard-constraint check and collect the results by name.

    Returns a dict mapping ``'valid_cuisine'``, ``'valid_room_rule'``,
    ``'valid_transportation'``, ``'valid_room_type'`` and ``'valid_cost'`` to
    a 2-tuple whose first element is the verdict. The cost entry is
    ``(total cost <= query_data['budget'], None)``.
    """
    # Checkers run in this order; each takes (query_data, tested_data).
    checkers = (
        ('valid_cuisine', is_valid_cuisine),
        ('valid_room_rule', is_valid_room_rule),
        ('valid_transportation', is_valid_transportation),
        ('valid_room_type', is_valid_room_type),
    )
    report = {label: check(query_data, tested_data) for label, check in checkers}

    spent = get_total_cost(query_data, tested_data)
    # bool(...) normalises the comparison result to a plain Python bool.
    report['valid_cost'] = (bool(spent <= query_data['budget']), None)
    return report
|
253 |
+
|
254 |
+
def boolean_evaluation(query_data, tested_data):
    """Return True unless one of the hard-constraint checks fails.

    Runs the same checks as ``evaluation``. The name of the first failing
    check is printed before ``False`` is returned.
    """
    for check_name, outcome in evaluation(query_data, tested_data).items():
        # Compared with == False on purpose: a verdict of None (constraint not
        # present in the query) must not count as a failure.
        if outcome[0] == False:
            print(check_name)
            return False
    return True
|
266 |
+
|
requirements.txt
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
datasets==2.16.1
|
2 |
gradio==3.50.2
|
3 |
-
huggingface-hub==0.20.2
|
4 |
-
APScheduler==3.10.1
|
|
|
1 |
datasets==2.16.1
|
2 |
gradio==3.50.2
|
3 |
+
huggingface-hub==0.20.2
|
|