{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "intro-md",
   "metadata": {},
   "source": [
    "# Manifesto-similarity ranker (prototype)\n",
    "\n",
    "Encodes feed item texts with a SentenceTransformer model and ranks them by cosine similarity to either a left-wing or a right-wing manifesto text, chosen per user id. `rankingfunc` at the end takes a request payload (dict) and returns a copy with its `items` reordered."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7a99e5ef-7aee-4ec0-8099-82dac93d4614",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from copy import deepcopy\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import requests\n",
    "from numpy.linalg import norm\n",
    "from scipy.stats import rankdata\n",
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "# sample data\n",
    "from sample_data import BASIC_EXAMPLE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "01db2542-4293-4df5-91dd-6165e2180f05",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_NAME = 'sentence-transformers/all-mpnet-base-v2'\n",
    "encodingModel = SentenceTransformer(MODEL_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "87892ace-2df7-4bc6-9569-b8b99b13b744",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the left-wing and right-wing manifesto texts that items are ranked against.\n",
    "# Set the MANIFESTO_DIR environment variable to read them from another directory;\n",
    "# the default is the original hard-coded location.\n",
    "MANIFESTO_DIR = Path(os.environ.get(\n",
    "    'MANIFESTO_DIR',\n",
    "    '/mnt/c/Users/hew7/Documents/Git/ChaiProsocialRankingChallenge/flask-test',\n",
    "))\n",
    "\n",
    "# left wing\n",
    "with open(MANIFESTO_DIR / 'manifesto-left.txt', 'r', encoding='utf-8') as f:\n",
    "    LeftWingStr = f.read()\n",
    "\n",
    "# right wing\n",
    "with open(MANIFESTO_DIR / 'manifesto-right.txt', 'r', encoding='utf-8') as f:\n",
    "    RightWingStr = f.read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e7cf9ca9-402e-4ba2-be33-f9356a0c6b9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# [text, embedding] pair for the left-wing manifesto\n",
    "LWPair = [LeftWingStr, encodingModel.encode(LeftWingStr)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4537352f-efa6-487f-a3bb-4f10bf439190",
   "metadata": {},
   "outputs": [],
   "source": [
    "# [text, embedding] pair for the right-wing manifesto\n",
    "RWPair = [RightWingStr, encodingModel.encode(RightWingStr)]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "sample-md",
   "metadata": {},
   "source": [
    "## Sample items"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ae4e52ea-cbb3-462b-8b34-58c83bfb6dbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pull the item texts out of the sample request payload\n",
    "example_texts = [x['text'] for x in BASIC_EXAMPLE['items']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "772da89d-e4d6-419c-a4ab-3f0b023ac068",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['this is the worst thing I have ever seen!',\n",
       " 'this is amazing!',\n",
       " 'this thing is ok.']"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "example_texts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ce6ffea3-a5b2-42b9-9895-82cc18431272",
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings = encodingModel.encode(example_texts)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ranking-md",
   "metadata": {},
   "source": [
    "## Cosine-similarity ranking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "83f809cf-9616-4f1d-8cf1-e1e1c15fefe6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cosineSim(x, y) -> float:\n",
    "    '''\n",
    "    Cosine similarity between two equal-length vectors (e.g. two sentence embeddings).\n",
    "\n",
    "    Returns a Python float. If either vector has zero norm the similarity is undefined;\n",
    "    0.0 is returned instead of NaN so that the ranking built on top stays well-defined.\n",
    "    '''\n",
    "    xArray = np.asarray(x)\n",
    "    yArray = np.asarray(y)\n",
    "    denominator = norm(xArray) * norm(yArray)\n",
    "    if denominator == 0:\n",
    "        return 0.0\n",
    "    return float(np.dot(xArray, yArray) / denominator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "363c66d4-295b-4eda-9771-d1688260619e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cosineRank(lhs, rhs) -> np.ndarray:\n",
    "    '''\n",
    "    Rank every candidate embedding in rhs by its cosine similarity to lhs.\n",
    "\n",
    "    Returns a float array aligned with rhs: entry i is the rank of rhs[i], from 0 for\n",
    "    the least similar candidate up to len(rhs) - 1 for the most similar. Tied\n",
    "    candidates share the average of their ranks (scipy.stats.rankdata default), so\n",
    "    entries can be fractional, e.g. 0.5.\n",
    "    '''\n",
    "    similarities = [cosineSim(lhs, candidate) for candidate in rhs]\n",
    "    return rankdata(similarities) - 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "41a04657-8020-4d27-a48f-28e4bd5795b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sort_text_cosine(LHSEmbedding, RHSEmbeddingList, RHSTextList) -> list:\n",
    "    '''\n",
    "    Return RHSTextList reordered by ascending cosine similarity of its embeddings\n",
    "    (RHSEmbeddingList, aligned with RHSTextList) to LHSEmbedding.\n",
    "\n",
    "    The least similar text comes first, matching the order used by rankingfunc.\n",
    "    NOTE(review): confirm that least-similar-first is the intended direction.\n",
    "    Tied texts keep their original relative order.\n",
    "    '''\n",
    "    result_order = cosineRank(LHSEmbedding, RHSEmbeddingList)\n",
    "    # result_order[i] is the rank OF item i, not the index of the i-th item, so the\n",
    "    # ranks cannot be used as indices; sort the item indices by rank instead.\n",
    "    sorted_indices = sorted(range(len(RHSTextList)), key=lambda i: result_order[i])\n",
    "    return [RHSTextList[i] for i in sorted_indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "36561978-bf59-4d9d-b6be-b78307fccd0c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['this is amazing!',\n",
       " 'this is the worst thing I have ever seen!',\n",
       " 'this thing is ok.']"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sort_text_cosine(LWPair[1], embeddings, example_texts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7e52c93e-135a-4ee5-b41e-1d7d0308f7d0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['this is the worst thing I have ever seen!',\n",
       " 'this is amazing!',\n",
       " 'this thing is ok.']"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sort_text_cosine(RWPair[1], embeddings, example_texts)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "api-md",
   "metadata": {},
   "source": [
    "## End-to-end ranking of a request payload"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "f2a3bf61-5614-4502-8a39-2691761bb12e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rankingfunc(inputJSON: dict, left_wing_user_ids=('193a9e01-8849-4e1f-a42a-a859fa7f2ad3',)) -> dict:\n",
    "    '''\n",
    "    Rerank the items of a request payload by cosine similarity to a manifesto text.\n",
    "\n",
    "    Users whose session user_id is in left_wing_user_ids are ranked against the\n",
    "    left-wing manifesto embedding (LWPair); all other users against the right-wing\n",
    "    one (RWPair). Relies on the notebook-level encodingModel, LWPair and RWPair.\n",
    "    Items are ordered by ascending similarity rank (least similar first); tied items\n",
    "    keep their input order. Tested on the example payload from sample_data.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    inputJSON : dict\n",
    "        Payload with inputJSON['session']['user_id'] and a list inputJSON['items']\n",
    "        whose elements carry a 'text' key. It is not modified.\n",
    "    left_wing_user_ids : collection of str\n",
    "        User ids to rank against the left-wing text.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        A deep copy of inputJSON with identical structure and 'items' reordered.\n",
    "    '''\n",
    "    LHS = LWPair if inputJSON['session']['user_id'] in left_wing_user_ids else RWPair\n",
    "\n",
    "    # Work on a copy so the caller's payload and its item dicts are never mutated\n",
    "    # (no temporary 'rank' key, no aliasing between input and output items).\n",
    "    output_dict = deepcopy(inputJSON)\n",
    "    candidates = output_dict['items']\n",
    "    if not candidates:\n",
    "        return output_dict\n",
    "\n",
    "    item_embeddings = encodingModel.encode([item['text'] for item in candidates])\n",
    "    item_rank = cosineRank(LHS[1], item_embeddings)\n",
    "    # stable sort of indices by rank, so tied items keep their input order\n",
    "    sorted_indices = sorted(range(len(candidates)), key=lambda i: item_rank[i])\n",
    "    output_dict['items'] = [candidates[i] for i in sorted_indices]\n",
    "    return output_dict"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ranking-challenge",
   "language": "python",
   "name": "ranking-challenge"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}