{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluate Embedding Similarity Metrics\n",
    "\n",
    "Embed pairs of phrases with the OpenAI embeddings API and compare them by **cosine distance**, $1 - \\cos\\theta$. Lower distance means more similar: a pair counts as similar when its distance is below a threshold (0.2 by default)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "import openai"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cosine_distance(a, b):\n",
    "    \"\"\"Calculate the cosine distance between two vectors.\n",
    "\n",
    "    Cosine distance is ``1 - cos(theta)``, where ``theta`` is the angle\n",
    "    between the vectors: 0.0 when they point the same way, 1.0 when they\n",
    "    are orthogonal and 2.0 when they point in opposite directions.\n",
    "\n",
    "    Parameters:\n",
    "        a (array-like): First input vector (list or 1-D numpy array).\n",
    "        b (array-like): Second input vector, same length as ``a``.\n",
    "\n",
    "    Returns:\n",
    "        float: Cosine distance between a and b, in the range [0.0, 2.0].\n",
    "        NaN if either vector has zero magnitude (the angle is undefined).\n",
    "\n",
    "    Warns:\n",
    "        UserWarning: If either magnitude is within 1e-12 of zero.\n",
    "    \"\"\"\n",
    "    a = np.asarray(a, dtype=float)\n",
    "    b = np.asarray(b, dtype=float)\n",
    "\n",
    "    # Calculate dot product and magnitudes of the input arrays\n",
    "    dot = np.dot(a, b)\n",
    "    a_mag = np.linalg.norm(a)\n",
    "    b_mag = np.linalg.norm(b)\n",
    "\n",
    "    # When comparing against 0, only the absolute tolerance has any effect.\n",
    "    if np.isclose(a_mag, 0.0, rtol=0.0, atol=1e-12):\n",
    "        warnings.warn(f\"a_mag is very small: {a_mag}\")\n",
    "    if np.isclose(b_mag, 0.0, rtol=0.0, atol=1e-12):\n",
    "        warnings.warn(f\"b_mag is very small: {b_mag}\")\n",
    "\n",
    "    denom = a_mag * b_mag\n",
    "    if denom == 0.0:\n",
    "        # Undefined angle: return NaN explicitly instead of evaluating 0/0,\n",
    "        # which yields NaN plus a numpy RuntimeWarning.\n",
    "        return float(\"nan\")\n",
    "\n",
    "    # Rounding can push the cosine slightly outside [-1, 1]; clip it so\n",
    "    # identical vectors give 0.0 rather than a tiny negative distance.\n",
    "    cos_sim = np.clip(dot / denom, -1.0, 1.0)\n",
    "    return float(1.0 - cos_sim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def semantically_similar(string1, string2, threshold=0.2,\n",
    "                         engine=\"text-similarity-davinci-001\"):\n",
    "    \"\"\"Decide whether two strings are semantically similar via embeddings.\n",
    "\n",
    "    Both strings are embedded in one API call and compared with\n",
    "    ``cosine_distance``. Lower distance means more similar, so the pair\n",
    "    counts as similar when the distance is strictly below ``threshold``.\n",
    "\n",
    "    Parameters:\n",
    "        string1 (str): First text to compare.\n",
    "        string2 (str): Second text to compare.\n",
    "        threshold (float): Exclusive upper bound on the cosine distance\n",
    "            for the pair to count as similar. Defaults to 0.2. The right\n",
    "            cut-off depends on the model, so re-tune it if ``engine``\n",
    "            changes.\n",
    "        engine (str): Embedding model passed to ``openai.Embedding.create``.\n",
    "\n",
    "    Returns:\n",
    "        bool: True if the cosine distance is below ``threshold``; False\n",
    "        otherwise, including when the distance is NaN.\n",
    "    \"\"\"\n",
    "    # NOTE(review): `openai.Embedding.create(engine=...)` is the pre-1.0\n",
    "    # openai-python interface, and OpenAI has deprecated the first-generation\n",
    "    # `text-similarity-*` models -- confirm both are still available.\n",
    "    response = openai.Embedding.create(\n",
    "        input=[string1, string2],\n",
    "        engine=engine,\n",
    "    )\n",
    "    embedding_a = response['data'][0]['embedding']\n",
    "    embedding_b = response['data'][1]['embedding']\n",
    "    distance = cosine_distance(embedding_a, embedding_b)\n",
    "    # Print the raw value so the threshold can be tuned; it is a *distance*\n",
    "    # (0 = same direction), not a similarity score.\n",
    "    print(f\"cosine distance: {distance}\")\n",
    "\n",
    "    return distance < threshold"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "similarity: 0.22501948669661986\n",
      "similarity: 0.2318907843871436\n",
      "similarity: 0.12933868208210475\n",
      "similarity: 0.10699853725782704\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(True,)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
"output_type": "execute_result"
    }
   ],
   "source": [
    "candidates = [\"water supply\", \"solar energy\", \"defend a country\", \"win a battle\"]\n",
    "{candidate: semantically_similar(\"fight a war\", candidate) for candidate in candidates}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "similarity: 0.2496415604648079\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "semantically_similar(\"the sky is blue\", \"I like to eat\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "similarity: 0.10193029028713485\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "semantically_similar(\"the cat meows\", \"the feline animal says\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "similarity: 0.19759407795526762\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "semantically_similar(\"what is the best way to win a war?\", \"strategizing a war\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "similarity: 0.1949772795717004\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "semantically_similar(\"what is the best way to win a war?\", \"fight a war\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chain",
   "language": "python",
   "name":
"python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }