File size: 4,873 Bytes
dff2993 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import random
import json
class EloRank:
    """
    Maintain Elo ratings and win counts for a pool of models.

    Every model starts at ``initial_rating``; after each decisive match the
    winner's and loser's ratings are adjusted with the standard Elo update
    scaled by ``k_factor``.
    """

    def __init__(self, initial_rating=1000, k_factor=32):
        """
        Initialize the EloRank class.

        :param initial_rating: Elo rating assigned to each newly added model.
        :param k_factor: The K-factor that determines the sensitivity of rating changes.
        """
        self.ratings = {}  # model_id -> current rating
        self.initial_rating = initial_rating
        self.k_factor = k_factor
        self.wins = {}  # model_id -> number of matches won

    def add_model(self, model_id):
        """
        Register a model with the initial rating and zero wins.

        Adding a model that is already registered is a no-op: its accumulated
        rating and win count are preserved rather than reset.

        :param model_id: Unique (hashable) identifier for the model.
        """
        self.ratings.setdefault(model_id, self.initial_rating)
        self.wins.setdefault(model_id, 0)

    def record_match(self, winner, loser):
        """
        Update ratings and the winner's win count for a decisive match.

        Both expected scores are computed from the pre-match ratings, so the
        points the winner gains equal the points the loser loses.

        :param winner: Model ID of the winner.
        :param loser: Model ID of the loser.
        :raises KeyError: If either model has not been added.
        """
        rating_winner = self.ratings[winner]
        rating_loser = self.ratings[loser]
        expected_winner = self.expected_score(rating_winner, rating_loser)
        expected_loser = self.expected_score(rating_loser, rating_winner)
        self.ratings[winner] += self.k_factor * (1 - expected_winner)
        self.ratings[loser] += self.k_factor * (0 - expected_loser)
        # Update win count
        self.wins[winner] += 1

    def expected_score(self, rating_a, rating_b):
        """
        Return the expected score of A against B under the Elo logistic model.

        :param rating_a: Rating of model A.
        :param rating_b: Rating of model B.
        :return: Expected score of A, between 0 and 1 (0.5 for equal ratings).
        """
        return 1 / (1 + 10 ** ((rating_b - rating_a) / 400))

    def get_rating(self, model_id):
        """
        Get the current rating of a model.

        :param model_id: Unique identifier for the model.
        :return: Current rating of the model, or None if it was never added.
        """
        return self.ratings.get(model_id)

    def get_wins(self, model_id):
        """
        Get the number of wins of a model.

        :param model_id: Unique identifier for the model.
        :return: Number of wins of the model (0 if it was never added).
        """
        return self.wins.get(model_id, 0)

    def get_top_models(self, n=2):
        """
        Get the top N models by rating.

        Models with equal ratings keep the order in which they were added,
        because ``sorted`` is stable.

        :param n: Number of top models to retrieve.
        :return: List of model IDs of the top models, highest rating first.
        """
        return sorted(self.ratings, key=self.ratings.get, reverse=True)[:n]

    def sample_next_match(self):
        """
        Sample two distinct models for the next match.

        Each model is drawn with probability proportional to its current
        rating. Draws that pick the same model twice are rejected and retried,
        so a pair (a, b) is chosen with probability proportional to
        rating[a] * rating[b].

        :return: Tuple of two distinct model IDs for the next match.
        :raises ValueError: If any rating is negative, or if fewer than two
            models have a positive rating (no distinct pair could be drawn and
            the rejection loop would never terminate).
        """
        model_ids = list(self.ratings)
        weights = [self.ratings[model_id] for model_id in model_ids]
        if any(weight < 0 for weight in weights):
            raise ValueError("ratings must be non-negative to be used as sampling weights")
        if sum(1 for weight in weights if weight > 0) < 2:
            raise ValueError("need at least two models with a positive rating to sample a match")
        # random.choices normalises the weights itself; no manual division needed.
        next_match = random.choices(model_ids, weights=weights, k=2)
        while next_match[0] == next_match[1]:
            next_match = random.choices(model_ids, weights=weights, k=2)
        return tuple(next_match)

    def process_match_records(self, file_path):
        """
        Process match records from a JSON file and update ratings and win counts.

        The file must contain a list of objects with ``winner``, ``model_1``
        and ``model_2`` keys. Models not seen before are added on first
        appearance. A record whose ``winner`` matches neither model (for
        example, a value used to mark a draw) registers both models but
        leaves all ratings and win counts unchanged.

        :param file_path: Path to the JSON file containing match records.
        :raises KeyError: If a record lacks one of the required keys.
        """
        with open(file_path, 'r', encoding='utf-8') as file:
            match_records = json.load(file)
        for record in match_records:
            winner = record['winner']
            model_1 = record['model_1']
            model_2 = record['model_2']
            # add_model is a no-op for models that are already registered.
            self.add_model(model_1)
            self.add_model(model_2)
            # Record the match result
            if winner == model_1:
                self.record_match(model_1, model_2)
            elif winner == model_2:
                self.record_match(model_2, model_1)
# # Example Usage
# e = EloRank()
# e.add_model('model_A')
# e.add_model('model_B')
# e.add_model('model_C')
# e.record_match('model_A', 'model_B')
# print(e.get_rating('model_A')) # Should be greater than the initial rating
# print(e.get_rating('model_B')) # Should be less than the initial rating
# print(e.get_top_models(2)) # Get the top 2 models
# print(e.sample_next_match()) # Sample the next match based on ratings
# # Process match records from a JSON file
# e.process_match_records('match_records.json')
# print(e.get_wins('model_A')) # Get the number of wins for model_A
|