Spaces:
Runtime error
Runtime error
import os | |
import torch | |
import warnings | |
import numpy as np | |
import torch.nn.functional as F | |
from enum import Enum | |
from copy import deepcopy | |
from sklearn.utils.extmath import softmax | |
from sentence_transformers import SentenceTransformer | |
from utils import * | |
# Environment setup for HF docker image
# Model files are cached in a local directory next to the app.
CACHE_DIR = './cache'
# exist_ok=True reuses an existing cache dir instead of raising FileExistsError
os.makedirs(CACHE_DIR, exist_ok=True)

# Both variables are set because it is not known which one the installed
# library versions honour. They are assigned after the libraries were already
# imported at the top of this file, so the cache location is ALSO passed
# explicitly to the model constructor below instead of relying on either
# variable being picked up.
os.environ['SENTENCE_TRANSFORMERS_HOME'] = CACHE_DIR
os.environ['TRANSFORMERS_CACHE'] = CACHE_DIR

# Load model
encodingModel = SentenceTransformer(
    'sentence-transformers/all-mpnet-base-v2',
    cache_folder=CACHE_DIR,
)

# Create embeddings from example texts
# Left wing
with open('./manifesto-left.txt', 'r') as f:
    LeftWingStr = f.read()

# Right wing
with open('./manifesto-right.txt', 'r') as f:
    RightWingStr = f.read()

# Each pair is [reference text, embedding of that text]
LWPair = [LeftWingStr, encodingModel.encode(LeftWingStr)]
RWPair = [RightWingStr, encodingModel.encode(RightWingStr)]

# Target distribution within documents: [p(left), p(right)]
TARGET_DISTRIBUTION = [0.5, 0.5]

# Controls the weight of the initial relevance score (0: ignore initial score, 1: only uses initial score)
LAMBDA = 0.5
# The different modes our ranking algorithm can run in | |
class RankingModes(Enum):
    """Modes the ranking algorithm can run in.

    Each member's value is the lowercase string that names the mode.
    """

    # Fairness of a post is the greedy score assigned to it by diversify()
    DIVERSIFY = "diversify"
    # Fairness of a post is its own fairScore() against the target distribution
    NEUTRALISE = "neutralise"
def fairScore(prob_scores: list, target: list) -> float:
    """Score how closely ``prob_scores`` matches ``target``.

    Returns one minus ``get_jsd_distance(prob_scores, target)``, so a smaller
    distance yields a higher score.
    """
    return 1 - get_jsd_distance(prob_scores, target)
def diversify(candidates: list, candidates_representation: dict, target: list, score_fn=None) -> dict:
    """Greedily pick candidates so a running blend of their representations tracks ``target``.

    In each round every remaining candidate is tried: its representation is
    averaged with the current accumulator (``(accumulator + representation) / 2``)
    and the result is scored against ``target``. The candidate with the highest
    score is selected (the first one wins a tie), its trial blend becomes the
    new accumulator, and the round's score is recorded for it.

    Inputs:
        candidates (list): Item dicts; each must have an 'id' key.
        candidates_representation (dict): Maps item id to that item's
            representation vector, with the same length as ``target``.
        target (list): The distribution the running blend should approach.
        score_fn (callable, optional): ``score_fn(blend, target)`` returning a
            score where higher is better. Defaults to ``fairScore``.

    Returns:
        (dict): Maps item id to the score it achieved in the round it was
            selected. Insertion order is selection order.

    Raises:
        ValueError: If in some round no candidate produces a score greater
            than -inf (e.g. every score is NaN).
    """
    if score_fn is None:
        score_fn = fairScore

    accumulator = np.zeros(len(target))
    # Work on a shallow copy so the caller's list is left untouched
    remaining = candidates.copy()
    diversified = {}

    while remaining:
        best_position = None
        best_score = -np.inf
        best_accumulator = accumulator
        for position, candidate in enumerate(remaining):
            representation = candidates_representation[candidate['id']]
            trial_accumulator = np.add(accumulator, representation) / 2
            score = score_fn(trial_accumulator, target)
            if score > best_score:
                best_position = position
                best_score = score
                # Keep the winner's blend separately: the shared accumulator
                # must not change while the other candidates are still being
                # compared, and must end up as the blend of the candidate
                # that is actually selected.
                best_accumulator = trial_accumulator

        if best_position is None:
            raise ValueError(
                "no remaining candidate produced a comparable score (scores may be NaN)"
            )

        best_candidate = remaining.pop(best_position)
        accumulator = best_accumulator
        diversified[best_candidate['id']] = best_score

    return diversified
def textParser(item: dict) -> str:
    """Concatenate the textual fields of a post into a single string.

    The fields 'title', 'description' and 'text' are used when present. They
    are emitted in the order in which they appear in ``item``, each followed
    by a single space (so the result ends with a space).

    Inputs:
        item (dict): A post dict that may contain any of the text fields.

    Returns:
        (str): The concatenated text.

    Raises:
        AssertionError: If none of the text fields contributed anything.
    """
    textkeys = ('title', 'description', 'text')
    # Fields whose value is None are skipped; concatenating None to a string
    # would raise a TypeError.
    parts = [
        item[key]
        for key in item
        if key in textkeys and item[key] is not None
    ]
    output_text = ''.join(f'{part} ' for part in parts)
    # Raised explicitly (rather than via `assert`) so the check is not
    # stripped when Python runs with -O; the exception type is unchanged.
    if output_text == '':
        raise AssertionError('text string is empty')
    return output_text
# TODO: this function still needs test coverage
def rankingfunc(inputJSON: dict, k: int = 10, mode: str = RankingModes.DIVERSIFY, debug: bool = False) -> dict:
    '''
    Rank a set of social media posts using our ranking algorithm

    Inputs:
        inputJSON (dict): JSON dict from the web browser plugin, following the
            provided competition spec at https://github.com/HumanCompatibleAI/ranking-challenge
        k (int): We only mess with the ranking of the first k items in the feed, to avoid
            unduly reducing engagement.
        mode (str or RankingModes): The ranker algorithm mode. Either a RankingModes
            member or its string value ('diversify' or 'neutralise').
        debug (bool): If set, prints extra debugging info while ranking

    Returns:
        (dict): JSON dict of re-ranked and new post IDs, following the
            provided competition spec at https://github.com/HumanCompatibleAI/ranking-challenge

    Raises:
        AssertionError: If k is not positive, or mode is not a valid ranking mode.

    Side effects:
        Sets the 'score' key on each of the first k item dicts of inputJSON['items'].
        Emits a warning when fewer than k items were passed.
    '''
    assert k > 0, f"k must be a positive integer greater than 0, but was {k}"

    # A plain string never compares equal to an enum member, so convert the
    # documented string values to their RankingModes member up front.
    # RankingModes(member) returns the member itself.
    try:
        mode = RankingModes(mode)
    except ValueError:
        pass
    assert isinstance(mode, RankingModes), f"mode must be in {RankingModes}, but was {mode}"

    # Extract text documents and get embeddings
    candidates = inputJSON['items']
    if len(candidates) < k:
        warnings.warn(f"k truncated from {k} to {len(candidates)} due to only that many posts being passed")
        k = min(k, len(candidates))
    if debug:
        print("Reranking top ", k)

    tail = candidates[k:]
    candidates = candidates[:k]

    # Nothing to re-rank: return the untouched order without running the model
    if not candidates:
        return {
            "ranked_ids": [item['id'] for item in tail],
            "new_items": []
        }

    # NOTE: textParser() used to raise a TypeError when called here (it indexed
    # `itemkeys[key]`), so only the 'text' field of each post is embedded for now.
    #texts = [textParser(x) for x in candidates]
    texts = [x['text'] for x in candidates]
    # Per earlier debugging notes: persona embeddings have shape (768,) and
    # the post embeddings have shape (number of posts, 768).
    embeddings = torch.tensor(encodingModel.encode(texts))

    # Compute cosine ranks based on similarity to political personas
    lw_cs = F.cosine_similarity(torch.tensor(LWPair[1]), embeddings)
    rw_cs = F.cosine_similarity(torch.tensor(RWPair[1]), embeddings)

    # Per-post [left, right] representation: cosine similarities rescaled
    # from [-1, 1] to [0, 1], then passed through a softmax. Computed once
    # and reused both for diversify() and in the scoring loop below.
    representations = [
        F.softmax(torch.stack([(left + 1.0) * 0.5, (right + 1.0) * 0.5]), dim=0)
        for left, right in zip(lw_cs, rw_cs)
    ]

    #convert to be accessible by post id:
    ids = [x['id'] for x in candidates]
    candidates_representation = dict(zip(ids, representations))

    if debug:
        print("Left Wing: ", lw_cs)
        print("Right Wing: ", rw_cs)

    # Obtain initial ranking scores from platform
    initial_scores = decayFunction(candidates)
    if debug:
        print(initial_scores)

    diversity_scores = {}
    if mode == RankingModes.DIVERSIFY:
        diversity_scores = diversify(candidates, candidates_representation, TARGET_DISTRIBUTION)

    for index, candidate in enumerate(candidates):
        # Higher values are better for relevance and fairness
        relevance = initial_scores[index]
        source = representations[index]

        if mode == RankingModes.DIVERSIFY:
            # Diversification:
            fairness = diversity_scores[candidate['id']]
        elif mode == RankingModes.NEUTRALISE:
            # Neutralization:
            fairness = fairScore(source, TARGET_DISTRIBUTION)
        else:
            # Only reachable when the mode assertion above is disabled (python -O)
            raise ValueError(f"Unknown ranking algorithm mode: {mode}")

        new_score = linearCombination(relevance, fairness, LAMBDA)
        candidate['score'] = new_score

        if debug:
            print("index:", index)
            print("left/right scores:", source)
            print("Fairness score:", fairness)
            print("Previous score:", relevance)
            print("New score:", new_score)

    # Reverse sort because higher is better for relevance and fairness
    reranked = sorted(candidates, key=lambda x: x['score'], reverse=True)
    final_ranking = [item['id'] for item in reranked]
    # Posts beyond the first k keep their original relative order
    final_ranking.extend(item['id'] for item in tail)

    # TODO ajs 15/Apr/2024 Find a way to source high-quality out-of-feed posts, then incorporate them into the fusion algorithm
    output_results = {
        "ranked_ids": final_ranking,
        "new_items": []
    }
    return output_results