{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "m_Oa5BjYS_UF"
      },
      "source": [
        "# Semantic search with FAISS (TensorFlow)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sjgafp3ZS_UH"
      },
      "outputs": [],
      "source": [
        "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
        "%pip install datasets evaluate transformers[sentencepiece]\n",
        "%pip install faiss-cpu"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UWljEBCWS_UR"
      },
      "outputs": [],
      "source": [
        "import pandas as pd\n",
        "from datasets import load_from_disk\n",
        "from transformers import AutoTokenizer, TFAutoModel\n",
        "\n",
        "# NOTE(review): not referenced anywhere else in this notebook; the entries look\n",
        "# like condition categories rather than drug names -- confirm before relying on it.\n",
        "Drugs = ['Acne', 'Adhd', 'Allergies', 'Anaemia', 'Angina', 'Appetite',\n",
        "         'Arthritis', 'Constipation', 'Contraception', 'Dandruff',\n",
        "         'Diabetes', 'Digestion', 'Fever', 'Fungal', 'General', 'Glaucoma',\n",
        "         'Gout', 'Haematopoiesis', 'Haemorrhoid', 'Hyperpigmentation',\n",
        "         'Hypertension', 'Hyperthyroidism', 'Hypnosis', 'Hypothyroidism',\n",
        "         'Infection', 'Migraine', 'Osteoporosis', 'Pain', 'Psychosis',\n",
        "         'Schizophrenia', 'Supplement', 'Thrombolysis', 'Viral', 'Wound']\n",
        "\n",
        "model_ckpt = \"sentence-transformers/multi-qa-mpnet-base-dot-v1\"\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
        "# from_pt=True asks transformers to load the checkpoint from PyTorch weights.\n",
        "model = TFAutoModel.from_pretrained(model_ckpt, from_pt=True)\n",
        "\n",
        "\n",
        "def cls_pooling(model_output):\n",
        "    \"\"\"Return position 0 along the second axis of ``last_hidden_state``.\n",
        "\n",
        "    NOTE(review): presumably this is the [CLS] token vector of every sequence\n",
        "    in the batch -- confirm against the model's output layout.\n",
        "    \"\"\"\n",
        "    return model_output.last_hidden_state[:, 0]\n",
        "\n",
        "\n",
        "def get_embeddings(text_list):\n",
        "    \"\"\"Embed a list of texts with the module-level tokenizer and model.\n",
        "\n",
        "    Args:\n",
        "        text_list: texts to embed, passed to the tokenizer as one batch.\n",
        "\n",
        "    Returns:\n",
        "        The result of ``cls_pooling`` applied to the model output (the\n",
        "        tokenizer is asked for TensorFlow tensors via ``return_tensors=\"tf\"``).\n",
        "    \"\"\"\n",
        "    encoded_input = tokenizer(\n",
        "        text_list, padding=True, truncation=True, return_tensors=\"tf\"\n",
        "    )\n",
        "    model_output = model(**encoded_input)\n",
        "    return cls_pooling(model_output)\n",
        "\n",
        "\n",
        "# NOTE(review): this path lives under /content/drive, so it presumably needs\n",
        "# Google Drive to be mounted before this cell runs -- confirm.\n",
        "embeddings_dataset = load_from_disk(\"/content/drive/MyDrive/Drugs\")\n",
        "embeddings_dataset.add_faiss_index(column=\"embeddings\")\n",
        "\n",
        "\n",
        "def recommendations(question, k=5):\n",
        "    \"\"\"Look up the entries of ``embeddings_dataset`` nearest to a question.\n",
        "\n",
        "    Args:\n",
        "        question: free-text query; it is embedded with ``get_embeddings``.\n",
        "        k: number of nearest examples to retrieve from the FAISS index\n",
        "            (defaults to 5).\n",
        "\n",
        "    Returns:\n",
        "        A pandas DataFrame with the columns ``Drug_Name``, ``Reason`` and\n",
        "        ``scores``, sorted by ``scores`` in descending order.\n",
        "    \"\"\"\n",
        "    question_embedding = get_embeddings([question]).numpy()\n",
        "    scores, samples = embeddings_dataset.get_nearest_examples(\n",
        "        \"embeddings\", question_embedding, k=k\n",
        "    )\n",
        "    samples_df = pd.DataFrame.from_dict(samples)\n",
        "    samples_df[\"scores\"] = scores\n",
        "    # NOTE(review): if the index reports distances (lower = closer), descending\n",
        "    # order puts the best match last -- confirm the intended direction.\n",
        "    samples_df = samples_df.sort_values(\"scores\", ascending=False)\n",
        "    return samples_df[['Drug_Name', 'Reason', 'scores']]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fGRtwhvTcZ9t"
      },
      "outputs": [],
      "source": [
        "question = \"moderate acne\"\n",
        "recommendations(question)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}