Spaces:
Sleeping
Sleeping
import datetime as dt
import os
from urllib.parse import quote_plus

import streamlit as st
from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi
from streamlit.logger import get_logger

from app_config import DB_SCHEMA, DB_COMPLETIONS, DB_CONVOS, DB_BATTLES, DB_ERRORS
# MongoDB connection settings, read from the process environment at import
# time. A missing variable raises KeyError immediately (in the order
# MONGO_URL, MONGO_USR, MONGO_PWD), so misconfiguration fails fast.
DB_URL, DB_USR, DB_PWD = (
    os.environ[_var] for _var in ('MONGO_URL', 'MONGO_USR', 'MONGO_PWD')
)

# Streamlit-provided logger for this module.
logger = get_logger(__name__)
def get_db_client():
    """Create a MongoDB client for the configured deployment and verify it.

    Builds a ``mongodb+srv://`` URI from the module-level ``DB_USR``,
    ``DB_PWD`` and ``DB_URL`` settings, opens a ``MongoClient`` pinned to
    Stable API version ``'1'``, and sends a ``ping`` command to the ``admin``
    database to confirm the server is reachable.

    Returns:
        The connected ``MongoClient`` on success, or ``None`` if the ping
        fails. On failure the error is logged with its traceback and the
        client is closed so its connection resources are released.
    """
    # Credentials must be percent-escaped (RFC 3986): PyMongo rejects or
    # misparses raw ':', '%', '/' and decodes '+' as a space in userinfo.
    # NOTE(review): assumes MONGO_USR / MONGO_PWD hold the *raw* values,
    # not values that were already URL-encoded — confirm for the deployment.
    user = quote_plus(DB_USR)
    password = quote_plus(DB_PWD)
    uri = f"mongodb+srv://{user}:{password}@{DB_URL}/?retryWrites=true&w=majority"

    client = MongoClient(uri, server_api=ServerApi('1'))
    try:
        # The ping forces a round trip, so connectivity/auth errors surface here.
        client.admin.command('ping')
    except Exception:
        # Best-effort connection: callers receive None rather than an
        # exception, but the failure is logged in full and the client
        # is closed instead of being leaked.
        logger.exception("DBUTILS: could not ping the MongoDB deployment")
        client.close()
        return None

    logger.info("DBUTILS: Pinged your deployment. You successfully connected to MongoDB!")
    return client
def new_convo(client, issue, language, username, is_comparison, model_one, model_two=None):
    """Record a new conversation and remember its id in the Streamlit session.

    Inserts one document into the ``DB_CONVOS`` collection of the
    ``DB_SCHEMA`` database. The document holds a UTC start timestamp plus the
    given issue, language, username, comparison flag and model name(s). The
    generated ``_id`` is logged and stored in
    ``st.session_state['convo_id']``.

    Args:
        client: MongoDB client (indexed by database name, then collection).
        issue, language, username: descriptive fields stored verbatim.
        is_comparison: whether the conversation compares two models.
        model_one: first (or only) model identifier.
        model_two: second model identifier; ``None`` when not comparing.

    Returns:
        None. The new id is only exposed via ``st.session_state``.
    """
    started_at = dt.datetime.now(tz=dt.timezone.utc)
    record = dict(
        start_timestamp=started_at,
        issue=issue,
        language=language,
        username=username,
        is_comparison=is_comparison,
        model_one=model_one,
        model_two=model_two,
    )

    convo_collection = client[DB_SCHEMA][DB_CONVOS]
    new_id = convo_collection.insert_one(record).inserted_id

    logger.info(f"DBUTILS: new convo id is {new_id}")
    st.session_state['convo_id'] = new_id
def new_comparison(client, prompt_timestamp, completion_timestamp,
                   chat_history, prompt, completionA, completionB,
                   source="webapp", subset=None):
    """Store a pair of model completions for one prompt and track its id.

    Inserts one document into the ``DB_COMPLETIONS`` collection of the
    ``DB_SCHEMA`` database. The document is linked to the current
    conversation via ``st.session_state['convo_id']``, which must already
    be set, otherwise a ``KeyError`` is raised before anything is written.
    The generated ``_id`` is logged and stored in
    ``st.session_state['comparison_id']``.

    Args:
        client: MongoDB client (indexed by database name, then collection).
        prompt_timestamp, completion_timestamp: stored verbatim.
        chat_history: prior turns, stored verbatim.
        prompt: the prompt both models answered.
        completionA, completionB: outputs of model one and model two.
        source: origin tag; defaults to ``"webapp"``.
        subset: optional subset tag.

    Returns:
        None. The new id is only exposed via ``st.session_state``.
    """
    # Both models are recorded with the same fixed sampling arguments.
    # NOTE(review): presumably mirrors the temperature used at generation
    # time — confirm against the caller.
    sampling_args = {'temperature': 0.8}

    document = {
        "prompt_timestamp": prompt_timestamp,
        "completion_timestamp": completion_timestamp,
        "source": source,
        "subset": subset,
        # Separate copies so each stored sub-document is an independent dict.
        "model_one_args": dict(sampling_args),
        "model_two_args": dict(sampling_args),
        "convo_id": st.session_state['convo_id'],
        "chat_history": chat_history,
        "prompt": prompt,
        # The "compeltion" spelling is part of the stored schema; existing
        # documents (and presumably their readers) use it, so it must not
        # be corrected without a data migration.
        "compeltion_model_one": completionA,
        "compeltion_model_two": completionB,
    }

    completions = client[DB_SCHEMA][DB_COMPLETIONS]
    new_id = completions.insert_one(document).inserted_id

    logger.info(f"DBUTILS: new comparison id is {new_id}")
    st.session_state['comparison_id'] = new_id
def new_battle_result(client, comparison_id, convo_id, username, model_one, model_two, winner):
    """Persist a user's verdict on a model-vs-model comparison.

    Inserts one document into the ``DB_BATTLES`` collection of the
    ``DB_SCHEMA`` database. It holds a UTC timestamp, the comparison and
    conversation ids, the judging user, both model identifiers and the
    chosen winner. The generated ``_id`` is logged.

    Returns:
        None.
    """
    judged_at = dt.datetime.now(tz=dt.timezone.utc)
    verdict = dict(
        battle_timestamp=judged_at,
        comparison_id=comparison_id,
        convo_id=convo_id,
        username=username,
        model_one=model_one,
        model_two=model_two,
        winner=winner,
    )

    battles_collection = client[DB_SCHEMA][DB_BATTLES]
    outcome = battles_collection.insert_one(verdict)
    logger.info(f"DBUTILS: new battle id is {outcome.inserted_id}")
def new_completion_error(client, comparison_id, username, model):
    """Log a failed completion for a model into the errors collection.

    Inserts one document into the ``DB_ERRORS`` collection of the
    ``DB_SCHEMA`` database. It holds a UTC timestamp, the related
    comparison id, the user and the model identifier. The generated
    ``_id`` is logged.

    Returns:
        None.
    """
    failure_record = {
        "error_timestamp": dt.datetime.now(tz=dt.timezone.utc),
        "comparison_id": comparison_id,
        "username": username,
        "model": model,
    }
    insert_result = client[DB_SCHEMA][DB_ERRORS].insert_one(failure_record)
    logger.info(f"DBUTILS: new error id is {insert_result.inserted_id}")
def get_non_assesed_comparison(client, username):
    """Return the next comparison that *username* has not judged yet.

    Runs an aggregation over the ``DB_COMPLETIONS`` collection of the
    ``DB_SCHEMA`` database.

    Steps:
      1. Attach this user's battles (from ``DB_BATTLES``) to each
         comparison, and keep only comparisons with none.
      2. Flag ``is_manual`` (``source == "manual"``) and ``is_eval``
         (``subset == "eval"``) as 1/0, and set ``priority`` to their sum.
      3. Sort by ``priority`` descending, then ``prompt_timestamp`` and
         ``convo_id`` ascending, and keep the first document.
      4. Attach the matching conversation document(s) from ``DB_CONVOS``
         as ``convo_info``.

    Returns:
        A list with zero or one comparison document. Each returned document
        carries the extra fields ``battles`` (empty), ``is_manual``,
        ``is_eval``, ``priority`` and ``convo_info``.
    """
    from bson.son import SON

    pipeline = [
        # Join only this user's battles for each comparison. Combining
        # localField/foreignField with a sub-pipeline requires MongoDB 5.0+.
        {'$lookup': {
            'from': DB_BATTLES,
            'localField': '_id',
            'foreignField': 'comparison_id',
            'pipeline': [
                {'$match': {'username': username}},
            ],
            'as': 'battles',
        }},
        # Keep comparisons the user has not assessed.
        {'$match': {
            'battles': {'$size': 0},
        }},
        {'$addFields': {
            'is_manual': {'$cond': [{'$eq': ['$source', 'manual']}, 1, 0]},
            'is_eval': {'$cond': [{'$eq': ['$subset', 'eval']}, 1, 0]},
        }},
        # Must be a separate stage. Expressions inside one $addFields are
        # evaluated against that stage's *input* document, so they cannot
        # see fields created in the same stage. The operands must also be
        # "$"-prefixed field paths: bare strings are non-numeric and would
        # be ignored, making priority always 0.
        {'$addFields': {
            'priority': {'$add': ['$is_manual', '$is_eval']},
        }},
        {'$sort': SON([
            ('priority', -1),
            ('prompt_timestamp', 1),
            ('convo_id', 1),
        ])},
        {'$limit': 1},
        # Join the conversation only for the selected document. Nothing
        # above depends on convo_info, so doing this after $limit avoids
        # a lookup per comparison without changing the result.
        {'$lookup': {
            'from': DB_CONVOS,
            'localField': 'convo_id',
            'foreignField': '_id',
            'as': 'convo_info',
        }},
    ]

    completions = client[DB_SCHEMA][DB_COMPLETIONS]
    return list(completions.aggregate(pipeline))