File size: 5,417 Bytes
30b7bc6 d19c70c 4dccf1d 892f774 f50f18c 5fef682 4dccf1d 30b7bc6 f50f18c 30b7bc6 f50f18c 4dccf1d f50f18c 4dccf1d d19c70c 4dccf1d 23a56a2 5fef682 4dccf1d 5fef682 892f774 5fef682 892f774 5fef682 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import os
import logging
from sqlalchemy import create_engine, Column, Integer, String, DateTime
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from datetime import datetime
import pandas as pd
import numpy as np
from datasets import load_dataset
from rating_systems import compute_elo, compute_bootstrap_elo, get_median_elo_from_bootstrap
def is_running_in_space():
    """Report whether the ``SPACE_ID`` environment variable is defined.

    Returns:
        bool: True when ``SPACE_ID`` is present in the environment (with any
        value, including an empty string), False otherwise.
    """
    # NOTE(review): SPACE_ID is presumably set by the Hugging Face Spaces
    # runtime -- confirm against the deployment environment.
    return os.environ.get("SPACE_ID") is not None
# Choose the SQLite file by environment: one database when SPACE_ID is set,
# a separate one for every other run.
DATABASE_URL = (
    "sqlite:///./data/hf-votes.db"
    if is_running_in_space()
    else "sqlite:///./data/local2.db"
)

# check_same_thread=False lifts sqlite3's rule that a connection may only be
# used from the thread that created it.
engine = create_engine(
    DATABASE_URL,
    connect_args={"check_same_thread": False},
)

# Session factory bound to the engine; flushes and commits are explicit.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Declarative base class from which the ORM models in this module derive.
Base = declarative_base()
# Database model
class Vote(Base):
    """ORM model for one row of the ``votes`` table.

    A row records a single comparison between two models: which models were
    compared, which side was recorded as the winner, who voted and when.
    """

    __tablename__ = "votes"
    # Integer primary key, indexed.
    id = Column(Integer, primary_key=True, index=True)
    # Identifier of the image the vote refers to; indexed for lookups.
    image_id = Column(String, index=True)
    # Names of the two models being compared.
    model_a = Column(String)
    model_b = Column(String)
    # Outcome of the comparison. The aggregation helpers in this module treat
    # "model_a", "model_b" and "tie" as the meaningful values.
    winner = Column(String)
    # Identifier of the voting user; indexed for lookups.
    user_id = Column(String, index=True)
    # NOTE(review): presumably the file paths of the two compared outputs --
    # confirm against the code that creates votes.
    fpath_a = Column(String)
    fpath_b = Column(String)
    # The callable default is invoked when a row is inserted without an
    # explicit timestamp; datetime.utcnow returns a naive (tz-less) UTC time.
    timestamp = Column(DateTime, default=datetime.utcnow)
# Module-level statement, so it executes on import: issues CREATE TABLE for
# the tables registered on Base's metadata, using the configured engine.
Base.metadata.create_all(bind=engine)
# Dependency for database session
def get_db():
    """Yield a database session and close it once the consumer is done.

    Yields:
        Session: a new session created from ``SessionLocal``.
    """
    # Leaving the ``with`` block closes the session, whether the consumer
    # finished normally or an exception was thrown into the generator.
    with SessionLocal() as session:
        yield session
def _parse_vote_timestamp(raw_timestamp):
    """Turn a dataset timestamp into a datetime, defaulting to the current UTC time."""
    if isinstance(raw_timestamp, str):
        try:
            return datetime.fromisoformat(raw_timestamp)
        except ValueError:
            # A single malformed value must not abort the whole import; it is
            # handled like a missing or non-string timestamp.
            logging.warning(
                "Unparseable vote timestamp %r; using current UTC time instead",
                raw_timestamp,
            )
    return datetime.utcnow()


def fill_database_once(dataset_name="bgsys/votes_datasets_test2"):
    """Seed the ``votes`` table from a dataset, but only when the table is empty.

    Every record of the dataset's ``train`` split becomes one ``Vote`` row.
    Missing text fields default to an empty string. A timestamp that is
    missing, not a string, or not parseable as ISO 8601 is replaced by the
    current UTC time. All rows are committed together after the loop.

    Args:
        dataset_name: name of the dataset passed to ``load_dataset``.

    Returns:
        None
    """
    text_fields = (
        "image_id",
        "model_a",
        "model_b",
        "winner",
        "user_id",
        "fpath_a",
        "fpath_b",
    )

    with SessionLocal() as db:
        # Any existing row means the table was already seeded (or has live votes).
        if db.query(Vote).first() is not None:
            logging.info("Database already filled, skipping dataset loading.")
            return

        dataset = load_dataset(dataset_name)
        for record in dataset["train"]:
            vote_data = {field: record.get(field, "") for field in text_fields}
            vote_data["timestamp"] = _parse_vote_timestamp(record.get("timestamp"))
            db.add(Vote(**vote_data))

        # Single commit: either the whole dataset is imported or nothing is.
        db.commit()
        logging.info("Database filled with data from Hugging Face dataset: %s", dataset_name)
def add_vote(vote_data):
    """Persist one vote and return a short summary of the stored row.

    Args:
        vote_data: mapping that is expanded as keyword arguments to ``Vote``,
            so its keys must be ``Vote`` attribute names.

    Returns:
        dict: the ``id``, ``user_id`` and ``timestamp`` of the stored vote.
    """
    with SessionLocal() as session:
        new_vote = Vote(**vote_data)
        session.add(new_vote)
        session.commit()
        # Reload the row so values assigned on insert (id, default timestamp)
        # are available on the instance.
        session.refresh(new_vote)

        vote_id = new_vote.id
        logging.info("Vote registered with ID: %s, using database: %s", vote_id, DATABASE_URL)
        summary = {
            "id": vote_id,
            "user_id": new_vote.user_id,
            "timestamp": new_vote.timestamp,
        }
    return summary
# Function to get all votes
def get_all_votes():
    """Fetch every row of the ``votes`` table.

    Returns:
        list: all ``Vote`` instances, in whatever order the query returns
        them (no explicit ordering is requested).
    """
    with SessionLocal() as session:
        return session.query(Vote).all()
# Function to compute Elo scores
def compute_elo_scores():
    """Compute bootstrap Elo statistics from the stored votes.

    Votes are loaded from the database and filtered: rows missing
    ``model_a``, ``model_b`` or ``winner`` are dropped, as are rows whose
    models are not in the accepted list or whose winner is not one of
    "model_a", "model_b" or "tie". The remaining votes are passed to
    ``compute_bootstrap_elo``.

    Returns:
        tuple: ``(median_elo_scores, model_rating_q025, model_rating_q975,
        variance)`` -- the median from ``get_median_elo_from_bootstrap`` plus
        the 2.5% quantile, 97.5% quantile and variance of the bootstrap
        result.
    """
    valid_models = ["Photoroom", "RemoveBG", "BRIA RMBG 2.0"]
    valid_winners = ["model_a", "model_b", "tie"]

    with SessionLocal() as db:
        votes = db.query(Vote).all()
        # Read the attributes while the session is still open.
        df = pd.DataFrame(
            {
                "model_a": [vote.model_a for vote in votes],
                "model_b": [vote.model_b for vote in votes],
                "winner": [vote.winner for vote in votes],
            }
        )

    init_size = df.shape[0]

    # Remove votes missing model_a, model_b or winner info
    df = df.dropna(subset=["model_a", "model_b", "winner"])

    # Vectorised validation. Unlike a row-wise ``apply``, this always yields a
    # boolean Series, including when there are no votes at all.
    is_valid = (
        df["model_a"].isin(valid_models)
        & df["model_b"].isin(valid_models)
        & df["winner"].isin(valid_winners)
    )
    df = df[is_valid]

    logging.info("Initial votes count: %d", init_size)
    logging.info("Votes count after validation: %d", df.shape[0])

    # Seed the random number generator for reproducibility
    # NOTE(review): this reseeds NumPy's global RNG; it only helps if
    # compute_bootstrap_elo draws from that global state -- confirm.
    np.random.seed(42)
    bootstrap_elo_scores = compute_bootstrap_elo(df)
    median_elo_scores = get_median_elo_from_bootstrap(bootstrap_elo_scores)

    # NOTE(review): presumably one column per model and one row per bootstrap
    # round, making these per-model statistics -- confirm in rating_systems.
    model_rating_q025 = bootstrap_elo_scores.quantile(0.025)
    model_rating_q975 = bootstrap_elo_scores.quantile(0.975)
    variance = bootstrap_elo_scores.var()

    return median_elo_scores, model_rating_q025, model_rating_q975, variance
# Function to compute the number of votes for each model
def compute_votes_per_model():
    """Count how many votes each model has won.

    Only votes whose ``winner`` is "model_a" or "model_b" are counted; any
    other value (for example a tie) is skipped.

    Returns:
        dict: winning model name mapped to its number of won votes. Models
        that never won do not appear.
    """
    wins = {}
    with SessionLocal() as session:
        for vote in session.query(Vote).all():
            if vote.winner not in ("model_a", "model_b"):
                continue
            winning_model = vote.model_a if vote.winner == "model_a" else vote.model_b
            wins[winning_model] = wins.get(winning_model, 0) + 1
    return wins
|