from pymongo import MongoClient
import datetime
import os
from urllib.parse import quote_plus

# Per-task history-length thresholds consulted by get_code(): for these task
# ids the lookup may be redirected to the "mirror" id 3000 - taskid.
_TASK_THRESHOLDS = {1001: 6, 1002: 2, 1003: 4, 1004: 2}


def _mongo_client(suffix=""):
    """Create a MongoClient from the ``mongodb_*{suffix}`` environment variables.

    Reads ``mongodb_username{suffix}``, ``mongodb_pw{suffix}`` and
    ``mongodb_cluster_url{suffix}``.

    Raises:
        KeyError: if any of the three environment variables is unset.
    """
    # PyMongo requires the username/password inside a URI to be percent-escaped;
    # otherwise characters such as '@', ':', '/' or '%' break URI parsing.
    # NOTE(review): assumes the env vars hold the raw (unescaped) credentials.
    username = quote_plus(os.environ[f"mongodb_username{suffix}"])
    password = quote_plus(os.environ[f"mongodb_pw{suffix}"])
    cluster_url = os.environ[f"mongodb_cluster_url{suffix}"]
    return MongoClient(
        f"mongodb+srv://{username}:{password}@{cluster_url}"
        "/?retryWrites=true&w=majority"
    )


class ResponseDb:
    """Store and read dialogue turns in the ``vqa-game`` MongoDB collection."""

    def __init__(self):
        """Connect using the mongodb_username / mongodb_pw / mongodb_cluster_url env vars.

        Raises:
            KeyError: if any of those environment variables is unset.
        """
        self.client = _mongo_client()
        self.db = self.client['vqa-game']
        self.collection = self.db['vqa-game']

    def add(self, dialogue_id, task_id, turn, question, response):
        """Insert one dialogue turn, timestamped with the current UTC time.

        Returns:
            The ``_id`` of the inserted document.
        """
        document = {
            "dialogue_id": dialogue_id,
            "task_id": task_id,
            "turn": turn,
            "question": question,
            "response": response,
            # PyMongo interprets naive datetimes as UTC, so a naive local
            # now() would be stored as the wrong instant on non-UTC hosts.
            "datetime": datetime.datetime.now(datetime.timezone.utc),
        }
        return self.collection.insert_one(document).inserted_id

    def get(self):
        """Return a cursor over every document in the collection."""
        return self.collection.find()

    def close(self):
        """Close the underlying MongoClient and release its connections."""
        self.client.close()


def get_code(taskid, history, top_pred):
    """Return the ``code`` field of the ``vqa-codes`` document for a task.

    For task ids listed in ``_TASK_THRESHOLDS`` the lookup uses
    ``3000 - taskid`` instead, unless ``len(history)`` is at most that task's
    threshold and ``top_pred == 0``. Other task ids are looked up as-is.

    Args:
        taskid: task identifier; anything ``int()`` accepts.
        history: sized collection; only its length is used, and only for
            task ids in ``_TASK_THRESHOLDS``.
        top_pred: compared against ``0``.

    Raises:
        KeyError: if a ``mongodb_*_2`` env var is unset, or the matched
            document has no ``code`` field.
        IndexError: if no document matches the resolved task id.
    """
    taskid = int(taskid)
    threshold = _TASK_THRESHOLDS.get(taskid)
    if threshold is not None and not (len(history) <= threshold and top_pred == 0):
        taskid = 3000 - taskid

    # Open a short-lived client and always close it; the original leaked one
    # client (and its connection pool) per call.
    client = _mongo_client("_2")
    try:
        doc = client['vqa-codes']['vqa-codes'].find_one({"taskid": taskid})
    finally:
        client.close()

    if doc is None:
        # Same exception type as the original list(...)[0] lookup, clearer message.
        raise IndexError(f"no vqa-codes document with taskid={taskid}")
    return doc['code']