hengyu commited on
Commit
a60b021
·
1 Parent(s): 191c29f

add script

Browse files

Signed-off-by: Meng, Hengyu <hengyu.meng@intel.com>

Files changed (1) hide show
  1. evaluation.ipynb +159 -0
evaluation.ipynb ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "attachments": {},
5
+ "cell_type": "markdown",
6
+ "metadata": {},
7
+ "source": [
8
+ "## Introduction\n",
9
+ "\n",
10
+ "This tutorial demonstrates how to perform evaluation on a gpt-j-6B-int8 model."
11
+ ]
12
+ },
13
+ {
14
+ "attachments": {},
15
+ "cell_type": "markdown",
16
+ "metadata": {},
17
+ "source": [
18
+ "## Prerequisite"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "metadata": {
25
+ "vscode": {
26
+ "languageId": "plaintext"
27
+ }
28
+ },
29
+ "outputs": [],
30
+ "source": [
31
+ "!pip install onnx onnxruntime torch transformers datasets accelerate"
32
+ ]
33
+ },
34
+ {
35
+ "attachments": {},
36
+ "cell_type": "markdown",
37
+ "metadata": {},
38
+ "source": [
39
+ "## Run\n",
40
+ "\n",
41
+ "### 1. Get lambada acc"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": null,
47
+ "metadata": {
48
+ "vscode": {
49
+ "languageId": "plaintext"
50
+ }
51
+ },
52
+ "outputs": [],
53
+ "source": [
54
+ "from transformers import AutoTokenizer\n",
55
+ "import torch\n",
56
+ "from datasets import load_dataset\n",
57
+ "import onnxruntime as ort\n",
58
+ "from torch.nn.functional import pad\n",
59
+ "\n",
60
+ "# load model\n",
61
+ "model_id = \"EleutherAI/gpt-j-6B\"\n",
62
+ "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
63
+ "\n",
64
+ "def tokenize_function(examples):\n",
65
+ " example = tokenizer(examples['text'])\n",
66
+ " return example\n",
67
+ "\n",
68
+ "# create dataset\n",
69
+ "dataset = load_dataset('lambada', split='validation')\n",
70
+ "dataset = dataset.shuffle(seed=42)\n",
71
+ "dataset = dataset.map(tokenize_function, batched=True)\n",
72
+ "dataset.set_format(type='torch', columns=['input_ids'])\n",
73
+ "\n",
74
+ "# create session\n",
75
+ "options = ort.SessionOptions()\n",
76
+ "options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL\n",
77
+ "session = ort.InferenceSession('/path/to/model.onnx', options, providers=ort.get_available_providers())\n",
78
+ "total, hit = 0, 0\n",
79
+ "index = 1\n",
80
+ "\n",
81
+ "# inference\n",
82
+ "for idx, batch in enumerate(dataset):\n",
83
+ " input_ids = batch['input_ids'].unsqueeze(0)\n",
84
+ " label = input_ids[:, -1]\n",
85
+ " pad_len = 0 ##set to 0\n",
86
+ " input_ids = pad(input_ids, (0, pad_len), value=1)\n",
87
+ " ort_inputs = {\n",
88
+ " 'input_ids': input_ids.detach().cpu().numpy(),\n",
89
+ " 'attention_mask': torch.ones(input_ids.shape).detach().cpu().numpy().astype('int64')\n",
90
+ " }\n",
91
+ " predictions = session.run(None, ort_inputs)\n",
92
+ " outputs = torch.from_numpy(predictions[0]) \n",
93
+ " last_token_logits = outputs[:, -2 - pad_len, :]\n",
94
+ " pred = last_token_logits.argmax(dim=-1)\n",
95
+ " total += label.size(0)\n",
96
+ " hit += (pred == label).sum().item()\n",
97
+ "acc = hit / total\n",
98
+ "print('acc: ', acc)"
99
+ ]
100
+ },
101
+ {
102
+ "attachments": {},
103
+ "cell_type": "markdown",
104
+ "metadata": {},
105
+ "source": [
106
+ "### 2. Text Generation"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": null,
112
+ "metadata": {
113
+ "vscode": {
114
+ "languageId": "plaintext"
115
+ }
116
+ },
117
+ "outputs": [],
118
+ "source": [
119
+ "import os\n",
120
+ "import time\n",
121
+ "import sys\n",
122
+ "\n",
123
+ "# create session\n",
124
+ "sess_options = ort.SessionOptions()\n",
125
+ "sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL\n",
126
+ "session = ort.InferenceSession('/path/to/model.onnx', sess_options)\n",
127
+ "\n",
128
+ "# input prompt\n",
129
+ "# 32 tokens input\n",
130
+ "prompt = \"Once upon a time, there existed a little girl, who liked to have adventures.\" + \\\n",
131
+ " \" She wanted to go to places and meet new people, and have fun.\"\n",
132
+ "\n",
133
+ "print(\"prompt: \", prompt)\n",
134
+ "\n",
135
+ "# start\n",
136
+ "input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n",
137
+ "for i in range(32):\n",
138
+ " inp = {'input_ids': input_ids.detach().cpu().numpy(),\n",
139
+ " 'attention_mask': torch.ones(input_ids.shape).detach().cpu().numpy().astype('int64')}\n",
140
+ " output = session.run(None, inp)\n",
141
+ " logits = output[0]\n",
142
+ " logits = torch.from_numpy(logits)\n",
143
+ " next_token_logits = logits[:, -1, :]\n",
144
+ " probs = torch.nn.functional.softmax(next_token_logits, dim=-1)\n",
145
+ " next_tokens = torch.argmax(probs, dim=-1)\n",
146
+ " input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)\n",
147
+ "print(tokenizer.decode(input_ids[0]))"
148
+ ]
149
+ }
150
+ ],
151
+ "metadata": {
152
+ "language_info": {
153
+ "name": "python"
154
+ },
155
+ "orig_nbformat": 4
156
+ },
157
+ "nbformat": 4,
158
+ "nbformat_minor": 2
159
+ }