# reranker-v1 / reranker.py
# Aaron Snoswell
# Merge pull request #8 from WillSH97/add-title-desc-text
# e4f2e91 unverified
import os
import torch
import warnings
import numpy as np
import torch.nn.functional as F
from enum import Enum
from copy import deepcopy
from sklearn.utils.extmath import softmax
from sentence_transformers import SentenceTransformer
from utils import *
# Environment setup for HF docker image: make sure a local model cache
# directory exists. An existing directory is reused; anything else occupying
# the path is reported instead of being silently ignored.
os.makedirs('./cache', exist_ok=True)

# Guessing which environ var is correct, so both are pointed at the cache dir
os.environ['SENTENCE_TRANSFORMERS_HOME'] = './cache'
os.environ['TRANSFORMERS_CACHE'] = './cache'

# Load model
encodingModel = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

# Reference texts for the two political personas. The encoding is given
# explicitly so the text (and therefore its embedding) does not depend on the
# locale of the machine running the code.
# NOTE(review): assumes the manifesto files are stored as UTF-8 - confirm.
# Left wing
with open('./manifesto-left.txt', 'r', encoding='utf-8') as f:
    LeftWingStr = f.read()

# Right wing
with open('./manifesto-right.txt', 'r', encoding='utf-8') as f:
    RightWingStr = f.read()

# Create embeddings from example texts; each pair is [text, embedding]
LWPair = [LeftWingStr, encodingModel.encode(LeftWingStr)]
RWPair = [RightWingStr, encodingModel.encode(RightWingStr)]

# Target distribution within documents: [p(left), p(right)]
TARGET_DISTRIBUTION = [0.5, 0.5]

# Controls the weight of the initial relevance score (0: ignore initial score, 1: only uses initial score)
LAMBDA = 0.5
class RankingModes(Enum):
    """The different modes our ranking algorithm can run in."""

    # Fairness comes from the greedy diversification pass over the posts
    DIVERSIFY = "diversify"
    # Fairness comes from scoring each post on its own against the target
    NEUTRALISE = "neutralise"
def fairScore(prob_scores: list, target: list) -> float:
    """Score how closely ``prob_scores`` matches ``target``.

    The result is one minus the value ``get_jsd_distance`` reports for the
    two inputs, so a smaller distance gives a higher score.
    """
    return 1 - get_jsd_distance(prob_scores, target)
def diversify(candidates: list, candidates_representation: dict, target: list) -> dict:
    """Greedily pick posts so the running mix of representations tracks ``target``.

    On every round each remaining post is tentatively blended into the running
    accumulator (the sum of accumulator and representation, halved) and scored
    with ``fairScore`` against ``target``. The best-scoring post is selected
    and only *its* blend is carried into the next round.

    Args:
        candidates: Post dicts; each must have an ``'id'`` key.
        candidates_representation: Maps post id to that post's representation
            vector, which must have the same length as ``target``.
        target: The distribution the accumulated representation should match.

    Returns:
        Dict mapping post id to the fairness score the post had when it was
        selected. Insertion order is selection order.
    """
    accumulator = np.zeros(len(target))
    remaining = candidates.copy()
    diversified = {}

    while remaining:
        best_position = None
        best_score = -np.inf
        best_accumulator = accumulator

        for position, candidate in enumerate(remaining):
            representation = candidates_representation[candidate['id']]
            # Every candidate is evaluated against the same accumulator; it
            # must not change while the scan is still in progress.
            temp_accumulator = np.add(accumulator, representation) / 2
            score = fairScore(temp_accumulator, target)
            # The first candidate is always accepted so that a NaN score
            # cannot leave the round without a selection.
            if best_position is None or score > best_score:
                best_position = position
                best_score = score
                best_accumulator = temp_accumulator

        # Commit the blend that belongs to the winning candidate
        accumulator = best_accumulator
        best_candidate = remaining.pop(best_position)
        diversified[best_candidate['id']] = best_score

    return diversified
def textParser(item: dict) -> str:
    """Concatenate the textual fields of a post into a single string.

    The ``title``, ``description`` and ``text`` fields are used when present,
    in the order they appear in ``item``, and each value is followed by one
    space. Fields whose value is ``None`` are skipped; any other non-string
    value is converted with ``str``.

    Args:
        item: A post dictionary.

    Returns:
        The concatenated text, including the trailing space.

    Raises:
        ValueError: If none of the text fields provides a value.
    """
    textkeys = ('title', 'description', 'text')
    parts = [
        str(value)
        for key, value in item.items()
        if key in textkeys and value is not None
    ]
    if not parts:
        raise ValueError('text string is empty')
    return ''.join(f'{part} ' for part in parts)
################ TEST THIS!!!!!!!!!!!!!!!!!!!!!!!
def rankingfunc(inputJSON: dict, k: int = 10, mode: str = RankingModes.DIVERSIFY, debug: bool = False) -> dict:
    '''
    Rank a set of social media posts using our ranking algorithm

    Inputs:
        inputJSON (dict): JSON dict from the web browser plugin, following the
            provided competition spec at https://github.com/HumanCompatibleAI/ranking-challenge
        k (int): We only mess with the ranking of the first k items in the feed, to avoid
            unduly reducing engagement.
        mode (str): The ranker algorithm mode. Either a RankingModes member or its
            string value ('diversify' or 'neutralise').
        debug (bool): If set, prints extra debugging info while ranking

    Returns:
        (dict): JSON dict of re-ranked and new post IDs, following the
            provided competition spec at https://github.com/HumanCompatibleAI/ranking-challenge

    Raises:
        ValueError: If k is not positive or mode is not a known ranking mode.

    Side effects:
        Each of the first k items of inputJSON['items'] gets a 'score' key
        holding its combined relevance/fairness score.
    '''
    if k <= 0:
        raise ValueError(f"k must be a positive integer greater than 0, but was {k}")

    # Accept both enum members and their string values. A membership test
    # against the Enum class rejects plain strings on older Python versions.
    try:
        mode = RankingModes(mode)
    except ValueError:
        valid_modes = [m.value for m in RankingModes]
        raise ValueError(f"mode must be one of {valid_modes}, but was {mode}") from None

    # Extract text documents and get embeddings
    candidates = inputJSON['items']
    if len(candidates) < k:
        warnings.warn(f"k truncated from {k} to {len(candidates)} due to only that many posts being passed")
        k = len(candidates)
    if debug:
        print("Reranking top ", k)

    tail = candidates[k:]
    candidates = candidates[:k]

    # Nothing to re-rank: skip the encoder and similarity computations entirely
    if not candidates:
        return {
            "ranked_ids": [item['id'] for item in tail],
            "new_items": []
        }

    # NOTE: only the 'text' field is embedded. The textParser-based variant
    # (title + description + text) was disabled after it raised a TypeError.
    # TODO: re-enable `texts = [textParser(x) for x in candidates]` once verified.
    texts = [x['text'] for x in candidates]
    embeddings = torch.tensor(encodingModel.encode(texts))

    # Compute cosine ranks based on similarity to political personas.
    # Persona embeddings are 1-D, post embeddings are (number of posts, dim).
    lw_cs = F.cosine_similarity(torch.tensor(LWPair[1]), embeddings)
    rw_cs = F.cosine_similarity(torch.tensor(RWPair[1]), embeddings)

    # Per-post [left, right] representation: similarities are shifted from
    # [-1, 1] to [0, 1], then passed through a softmax.
    representations = [
        F.softmax(torch.stack([(left + 1.0) * 0.5, (right + 1.0) * 0.5]), dim=0)
        for left, right in zip(lw_cs, rw_cs)
    ]

    # Convert to be accessible by post id
    ids = [x['id'] for x in candidates]
    candidates_representation = dict(zip(ids, representations))

    if debug:
        print("Left Wing: ", lw_cs)
        print("Right Wing: ", rw_cs)

    # Obtain initial ranking scores from platform
    initial_scores = decayFunction(candidates)
    if debug:
        print(initial_scores)

    diversity_scores = {}
    if mode is RankingModes.DIVERSIFY:
        diversity_scores = diversify(candidates, candidates_representation, TARGET_DISTRIBUTION)

    for index, candidate in enumerate(candidates):
        # Higher values are better for relevance and fairness
        relevance = initial_scores[index]
        source = representations[index]

        if mode is RankingModes.DIVERSIFY:
            # Diversification:
            fairness = diversity_scores[candidate['id']]
        elif mode is RankingModes.NEUTRALISE:
            # Neutralization:
            fairness = fairScore(source, TARGET_DISTRIBUTION)
        else:
            # Guards against a mode being added to the enum but not handled here
            raise ValueError(f"Unknown ranking algorithm mode: {mode}")

        new_score = linearCombination(relevance, fairness, LAMBDA)
        candidate['score'] = new_score

        if debug:
            print("index:", index)
            print("left/right scores:", source)
            print("Fairness score:", fairness)
            print("Previous score:", relevance)
            print("New score:", new_score)

    # Reverse sort because higher is better for relevance and fairness
    reranked = sorted(candidates, key=lambda x: x['score'], reverse=True)

    # Posts beyond the first k keep their original order after the re-ranked ones
    final_ranking = [item['id'] for item in reranked]
    final_ranking.extend(item['id'] for item in tail)

    # TODO ajs 15/Apr/2024 Find a way to source high-quality out-of-feed posts, then incorporate them into the fusion algorithm
    output_results = {
        "ranked_ids": final_ranking,
        "new_items": []
    }
    return output_results