# ## Train Dialog-Fact Encoder
#
# Goal: train an embedding model that matches dialogs with (possibly) relevant facts.

# ### Constants

model_name = "BAAI/bge-base-en-v1.5"
query_prefix = "Represent this sentence for searching relevant passages: "
max_len = 512
training_hn_file = "./data/train.jsonl"
eval_file = "./data/eval.jsonl"
batch_size = 1350
output_model_path = "./dfe-base-en"
hf_repo_name = "julep-ai/dfe-base-en"

# ### Imports

import itertools as it

import graphviz
import jsonlines as jsonl
from lion_pytorch import Lion
from sentence_transformers import InputExample, SentenceTransformer, losses as ls, models as ml, util
from sentence_transformers.evaluation import SimilarityFunction, TripletEvaluator
import torch
from torch.utils.data import DataLoader, IterableDataset
from tqdm.auto import tqdm

# ### Dataset


def hn_output(file):
    """Stream triplet training examples from a hard-negatives JSONL file.

    Each JSONL entry carries a "query" (a fact) plus "pos" / "neg" dialog
    lists. Every (positive, negative) pairing yields one InputExample whose
    texts are Asym-keyed dicts: [{fact: query}, {dialog: pos}, {dialog: neg}].
    """
    with jsonl.open(file) as reader:
        for entry in reader:
            anchor = dict(fact=entry["query"])
            positives = [dict(dialog=d) for d in entry["pos"]]
            negatives = [dict(dialog=d) for d in entry["neg"]]

            for pos_item, neg_item in it.product(positives, negatives):
                yield InputExample(texts=[anchor, pos_item, neg_item])


training_data = list(tqdm(hn_output(training_hn_file)))
eval_data = list(tqdm(hn_output(eval_file)))

dataloader = DataLoader(training_data, shuffle=True, batch_size=batch_size)
eval_dataloader = DataLoader(eval_data, shuffle=True, batch_size=batch_size // 10)

# ### DFE Model Architecture

# Base model
base_model = SentenceTransformer(model_name)

# Freeze the base transformer: only the projection heads added below will train.
for param in base_model.parameters():
    param.requires_grad = False

device = torch.device("cuda:0")

# _target_device must be set as well, or any SentenceTransformer.fit() call
# will move the body back to its own default device.
base_model._target_device = device
base_model = base_model.to(device)

emb_dims = base_model._first_module().get_word_embedding_dimension()  # 768


def dense_projector(dims: int):
    """Build a 5-layer MLP head: dims -> 2*dims -> 2*dims -> 2*dims -> dims,
    with dropout in the middle."""
    proj_dims = dims * 2  # 1536
    return [
        ml.Dense(dims, proj_dims),       # 768 -> 1536
        ml.Dense(proj_dims, proj_dims),  # 1536 -> 1536
        ml.Dropout(0.1),
        ml.Dense(proj_dims, proj_dims),  # 1536 -> 1536
        ml.Dense(proj_dims, dims),       # 1536 -> 768
    ]


def asym_module(dims: int, keys: list[str], allow_empty_key: bool = False):
    """Wrap one independent dense projector per input key in an Asym module."""
    projectors = {key: dense_projector(dims) for key in keys}
    return ml.Asym(projectors, allow_empty_key=allow_empty_key)


# Attach separate "dialog" / "fact" projection heads after the pooling layer.
base_model._modules["2"] = asym_module(emb_dims, ["dialog", "fact"])

base_model._modules

# ### Prepare training loss and evaluator

train_loss = ls.TripletLoss(model=base_model)

triplet_evaluator = TripletEvaluator.from_input_examples(
    eval_data,  # each triplet is ({fact: ...}, {dialog: ...}, {dialog: ...})
    batch_size=batch_size // 10,
    main_distance_function=SimilarityFunction.COSINE,
    show_progress_bar=True,
    write_csv=True,
)

# ### Train model
"cell_type": "code", "execution_count": null, "id": "dbf3b8c9-8ef8-4198-b284-910c57f2cbca", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ea0ed014f83b4651b810c0abd317add9", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Epoch: 0%| | 0/15 [00:00