mychen76 commited on
Commit
6ef314e
·
verified ·
1 Parent(s): bf05b20

Upload finetune-llama3-using-qlora-embed.ipynb

Browse files
finetune-llama3-using-qlora-embed.ipynb CHANGED
@@ -900,19 +900,6 @@
900
  "torch.cuda.empty_cache()"
901
  ]
902
  },
903
- {
904
- "cell_type": "code",
905
- "execution_count": null,
906
- "id": "b7522c18-31dc-4ed2-a205-0ac080c4b59b",
907
- "metadata": {},
908
- "outputs": [],
909
- "source": [
910
- "# from transformers import TextStreamer\n",
911
- "# text_streamer = TextStreamer(tokenizer, skip_prompt = True)\n",
912
- "# _ = model.generate(input_ids = inputs, streamer = text_streamer, max_new_tokens = 128,\n",
913
- "# use_cache = True, temperature = 1.5, min_p = 0.1)"
914
- ]
915
- },
916
  {
917
  "cell_type": "markdown",
918
  "id": "841a48fe",
@@ -927,7 +914,7 @@
927
  "tags": []
928
  },
929
  "source": [
930
- "### Prepare the dataset \n",
931
  "\n",
932
  "We will use 10K rows from the `ultrachat_200k` database."
933
  ]
@@ -1343,10 +1330,67 @@
1343
  "task_input = format_task_input(task_type, task_context)\n",
1344
  "\n",
1345
  "messages = [{\"role\": \"user\", \"content\": task_input},]\n",
1346
- "inputs = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
 
1347
  "generate_response(inputs)\n"
1348
  ]
1349
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1350
  {
1351
  "cell_type": "code",
1352
  "execution_count": null,
 
900
  "torch.cuda.empty_cache()"
901
  ]
902
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
903
  {
904
  "cell_type": "markdown",
905
  "id": "841a48fe",
 
914
  "tags": []
915
  },
916
  "source": [
917
+ "##### Prepare the dataset \n",
918
  "\n",
919
  "We will use 10K rows from the `ultrachat_200k` database."
920
  ]
 
1330
  "task_input = format_task_input(task_type, task_context)\n",
1331
  "\n",
1332
  "messages = [{\"role\": \"user\", \"content\": task_input},]\n",
1333
+ "inputs = tokenizer.apply_chat_template(messages, tokenize=False, \n",
1334
+ " add_generation_prompt=True)\n",
1335
  "generate_response(inputs)\n"
1336
  ]
1337
  },
1338
+ {
1339
+ "cell_type": "markdown",
1340
+ "id": "1962fbcc-5d11-4cc3-9abe-0d78d8d780e8",
1341
+ "metadata": {},
1342
+ "source": [
1343
+ "#### Test Sample-3 Streaming"
1344
+ ]
1345
+ },
1346
+ {
1347
+ "cell_type": "code",
1348
+ "execution_count": 28,
1349
+ "id": "a9183096-5882-4595-b872-911a51557703",
1350
+ "metadata": {
1351
+ "execution": {
1352
+ "iopub.execute_input": "2024-09-21T10:54:14.949160Z",
1353
+ "iopub.status.busy": "2024-09-21T10:54:14.948971Z",
1354
+ "iopub.status.idle": "2024-09-21T10:54:19.811260Z",
1355
+ "shell.execute_reply": "2024-09-21T10:54:19.810881Z",
1356
+ "shell.execute_reply.started": "2024-09-21T10:54:14.949147Z"
1357
+ }
1358
+ },
1359
+ "outputs": [
1360
+ {
1361
+ "name": "stdout",
1362
+ "output_type": "stream",
1363
+ "text": [
1364
+ "<|tasktype|>\n",
1365
+ "extractive question answering\n",
1366
+ "<|taskinput|>\n",
1367
+ "{{context}}\n",
1368
+ "\n",
1369
+ "Q: How would you decide whether to keep the same set of chat tokens when training further models? \n",
1370
+ "\n",
1371
+ "Context:\n",
1372
+ "When setting the template for a model that’s already been trained for chat, you should ensure that the template exactly matches the message formatting that the model saw during training, or else you will probably experience performance degradation. This is true even if you’re training the model further - you will probably get the best performance if you keep the chat tokens constant. \n",
1373
+ "\n",
1374
+ "What is the reason behind the statement that the model will probably experience performance degradation?\n",
1375
+ "<|taskoutput|>\n",
1376
+ "Tokenization\n",
1377
+ "<|eot_id|>\n"
1378
+ ]
1379
+ }
1380
+ ],
1381
+ "source": [
1382
+ "from transformers import TextStreamer\n",
1383
+ "text_streamer = TextStreamer(tokenizer, skip_prompt = True)\n",
1384
+ "\n",
1385
+ "prompt = tokenizer(inputs, return_tensors=\"pt\").to('cuda')\n",
1386
+ "\n",
1387
+ "_ = model.generate(input_ids = prompt['input_ids'], \n",
1388
+ " streamer = text_streamer, \n",
1389
+ " max_new_tokens = 128,\n",
1390
+ " use_cache = True, temperature = 1.5, min_p = 0.1, \n",
1391
+ " eos_token_id=terminators)\n"
1392
+ ]
1393
+ },
1394
  {
1395
  "cell_type": "code",
1396
  "execution_count": null,