Moses25 committed
Commit 4ba4c08 · 1 Parent(s): a3ee9d8

add gradio file

Files changed (2)
  1. attn_and_long_ctx_patches.py +223 -0
  2. gradio_demo.py +626 -0
attn_and_long_ctx_patches.py ADDED
@@ -0,0 +1,223 @@
import torch
from torch import nn
from typing import Optional, Tuple, Union
import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, rotate_half
import math

try:
    from xformers import ops as xops
except ImportError:
    xops = None
    print(
        "Xformers is not installed correctly. If you want to use memory_efficient_attention use the following command to install Xformers\npip install xformers."
    )


STORE_KV_BEFORE_ROPE = False
USE_MEM_EFF_ATTENTION = False
ALPHA = 1.0
AUTO_COEFF = 1.0
SCALING_FACTOR = None


def apply_rotary_pos_emb_single(q, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    return q_embed


def xformers_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    if STORE_KV_BEFORE_ROPE is False:
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None
    else:
        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        past_key_value = (key_states, value_states) if use_cache else None

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
        position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=cos.device)
        position_ids = position_ids.unsqueeze(0).view(-1, kv_seq_len)
        key_states = apply_rotary_pos_emb_single(key_states, cos, sin, position_ids)

    if xops is not None and USE_MEM_EFF_ATTENTION:
        attn_weights = None
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        attn_bias = None if (query_states.size(1) == 1 and key_states.size(1) > 1) else xops.LowerTriangularMask()
        attn_output = xops.memory_efficient_attention(
            query_states, key_states, value_states, attn_bias=attn_bias, p=0)
    else:
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask
            attn_weights = torch.max(
                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
            )

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2)

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__


def _set_cos_sin_cache(self, seq_len, device, dtype):
    self.max_seq_len_cached = seq_len
    t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
    t = t / self.scaling_factor

    freqs = torch.einsum("i,j->ij", t, self.ntk_inv_freq.to(device))
    # Different from paper, but it uses a different permutation in order to obtain the same calculation
    emb = torch.cat((freqs, freqs), dim=-1)
    self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
    self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)


def adaptive_ntk_init(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=None):
    self.alpha = ALPHA
    if SCALING_FACTOR is None:
        self.scaling_factor = scaling_factor or 1.0
    else:
        self.scaling_factor = SCALING_FACTOR
    if isinstance(ALPHA, (float, int)):
        base = base * ALPHA ** (dim / (dim - 2))
        self.base = base
    elif ALPHA == 'auto':
        self.base = base
    else:
        raise ValueError(ALPHA)
    old_init(self, dim, max_position_embeddings, base, device)
    self.ntk_inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))

    self._set_cos_sin_cache = _set_cos_sin_cache
    self._set_cos_sin_cache(
        self, seq_len=max_position_embeddings, device=self.ntk_inv_freq.device, dtype=torch.get_default_dtype()
    )


def adaptive_ntk_forward(self, x, seq_len=None):
    if seq_len > self.max_seq_len_cached:
        if isinstance(self.alpha, (float, int)):
            self._set_cos_sin_cache(self, seq_len=seq_len, device=x.device, dtype=x.dtype)
        elif self.alpha == 'auto':
            t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
            t = t / self.scaling_factor
            dim = self.dim
            alpha = (seq_len / (self.max_position_embeddings / 2) - 1) * AUTO_COEFF
            base = self.base * alpha ** (dim / (dim - 2))
            ntk_inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(x.device) / dim))

            freqs = torch.einsum("i,j->ij", t, ntk_inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            cos_cached = emb.cos()[None, None, :, :]
            sin_cached = emb.sin()[None, None, :, :]
            return (
                cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
                sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
            )
    return (
        self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
    )


def apply_attention_patch(
    use_memory_efficient_attention=False,
    store_kv_before_rope=False
):
    global USE_MEM_EFF_ATTENTION, STORE_KV_BEFORE_ROPE
    if use_memory_efficient_attention is True and xops is not None:
        USE_MEM_EFF_ATTENTION = use_memory_efficient_attention
        print("USE_MEM_EFF_ATTENTION: ", USE_MEM_EFF_ATTENTION)
    STORE_KV_BEFORE_ROPE = store_kv_before_rope
    print("STORE_KV_BEFORE_ROPE:", STORE_KV_BEFORE_ROPE)
    transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward


def apply_ntk_scaling_patch(alpha: Union[float, str], scaling_factor: Optional[float] = None):
    global ALPHA
    global SCALING_FACTOR
    ALPHA = alpha
    SCALING_FACTOR = scaling_factor
    try:
        ALPHA = float(ALPHA)
    except ValueError:
        if ALPHA != "auto":
            raise ValueError(f"Alpha can only be a float or 'auto', but given {ALPHA}")
    print(f"Apply NTK scaling with ALPHA={ALPHA}")
    if scaling_factor is None:
        print(f"The value of scaling factor will be read from model config file, or set to 1.")
    else:
        print(f"Warning: scaling factor is set to {SCALING_FACTOR}. \
If you set the value by hand, do not forget to update \
max_position_embeddings in the model config file.")

    transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = adaptive_ntk_init
    if hasattr(transformers.models.llama.modeling_llama, 'LlamaLinearScalingRotaryEmbedding'):
        transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding.__init__ = adaptive_ntk_init
    transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward = adaptive_ntk_forward
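For context, the two entry points above monkey-patch the transformers LLaMA classes, so they are meant to be called before the model is instantiated (otherwise the patched rotary-embedding __init__ has no effect). A minimal usage sketch, assuming a LLaMA checkpoint is available; the model path string is a placeholder, not part of this commit:

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from attn_and_long_ctx_patches import apply_attention_patch, apply_ntk_scaling_patch

# Apply the patches first: memory-efficient attention via xformers (if installed)
# and adaptive NTK RoPE scaling, with alpha derived from the sequence length ("auto").
apply_attention_patch(use_memory_efficient_attention=True)
apply_ntk_scaling_patch(alpha="auto")

tokenizer = LlamaTokenizer.from_pretrained("path/to/llama-model", legacy=True)  # placeholder path
model = LlamaForCausalLM.from_pretrained(
    "path/to/llama-model",  # placeholder path
    torch_dtype=torch.float16,
    device_map="auto",
)
model.eval()

This mirrors how gradio_demo.py below applies both patches before calling LlamaForCausalLM.from_pretrained.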
gradio_demo.py ADDED
@@ -0,0 +1,626 @@
import torch
from transformers import (
    LlamaForCausalLM,
    LlamaTokenizer,
    StoppingCriteria,
    BitsAndBytesConfig
)
import gradio as gr
import argparse
import os
from queue import Queue
from threading import Thread
import traceback
import gc
import json
import requests
from typing import Iterable, List
import subprocess
import re

DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Help as much as you can."""

TEMPLATE_WITH_SYSTEM_PROMPT = (
    "[INST] <<SYS>>\n"
    "{system_prompt}\n"
    "<</SYS>>\n\n"
    "{instruction} [/INST]"
)

TEMPLATE_WITHOUT_SYSTEM_PROMPT = "[INST] {instruction} [/INST]"

# Parse command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument(
    '--base_model',
    default=None,
    type=str,
    required=True,
    help='Base model path')
parser.add_argument('--lora_model', default=None, type=str,
                    help="If None, perform inference on the base model")
parser.add_argument(
    '--tokenizer_path',
    default=None,
    type=str,
    help='If None, lora model path or base model path will be used')
parser.add_argument(
    '--gpus',
    default="0",
    type=str,
    help='If None, cuda:0 will be used. Inference using multi-cards: --gpus=0,1,... ')
parser.add_argument('--share', default=True, help='Share gradio domain name')
parser.add_argument('--port', default=19324, type=int, help='Port of gradio demo')
parser.add_argument(
    '--max_memory',
    default=1024,
    type=int,
    help='Maximum number of input tokens (including system prompt) to keep. If exceeded, earlier history will be discarded.')
parser.add_argument(
    '--load_in_8bit',
    action='store_true',
    default=False,
    help='Use 8 bit quantized model')
parser.add_argument(
    '--load_in_4bit',
    action='store_true',
    default=False,
    help='Use 4 bit quantized model')
parser.add_argument(
    '--only_cpu',
    action='store_true',
    help='Only use CPU for inference')
parser.add_argument(
    '--alpha',
    type=str,
    default="1.0",
    help="The scaling factor of NTK method, can be a float or 'auto'. ")
parser.add_argument(
    "--use_vllm",
    action='store_true',
    help="Use vLLM as back-end LLM service.")
parser.add_argument(
    "--post_host",
    type=str,
    default="0.0.0.0",
    help="Host of vLLM service.")
parser.add_argument(
    "--post_port",
    type=int,
    default=7777,
    help="Port of vLLM service.")
args = parser.parse_args()

ENABLE_CFG_SAMPLING = True
try:
    from transformers.generation import UnbatchedClassifierFreeGuidanceLogitsProcessor
except ImportError:
    ENABLE_CFG_SAMPLING = False
    print("Install the latest transformers (commit equal or later than d533465) to enable CFG sampling.")
if args.use_vllm is True:
    print("CFG sampling is disabled when using vLLM.")
    ENABLE_CFG_SAMPLING = False

if args.only_cpu is True:
    args.gpus = ""
    if args.load_in_8bit or args.load_in_4bit:
        raise ValueError("Quantization is unavailable on CPU.")
if args.load_in_8bit and args.load_in_4bit:
    raise ValueError("Only one quantization method can be chosen for inference. Please check your arguments")
import sys
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
from attn_and_long_ctx_patches import apply_attention_patch, apply_ntk_scaling_patch
if not args.only_cpu:
    apply_attention_patch(use_memory_efficient_attention=True)
apply_ntk_scaling_patch(args.alpha)

# Set CUDA devices if available
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus


# Peft library can only be imported after setting CUDA devices
from peft import PeftModel


# Set up the required components: model and tokenizer
def setup():
    global tokenizer, model, device, share, port, max_memory
    if args.use_vllm:
        # global share, port, max_memory
        max_memory = args.max_memory
        port = args.port
        share = args.share

        if args.lora_model is not None:
            raise ValueError("vLLM currently does not support LoRA, please merge the LoRA weights to the base model.")
        if args.load_in_8bit or args.load_in_4bit:
            raise ValueError("vLLM currently does not support quantization, please use fp16 (default) or unuse --use_vllm.")
        if args.only_cpu:
            raise ValueError("vLLM requires GPUs with compute capability not less than 7.0. If you want to run only on CPU, please unuse --use_vllm.")

        if args.tokenizer_path is None:
            args.tokenizer_path = args.base_model
        tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path, legacy=True)

        print("Start launch vllm server.")
        cmd = f"python -m vllm.entrypoints.api_server \
            --model={args.base_model} \
            --tokenizer={args.tokenizer_path} \
            --tokenizer-mode=slow \
            --tensor-parallel-size={len(args.gpus.split(','))} \
            --host {args.post_host} \
            --port {args.post_port} \
            &"
        subprocess.check_call(cmd, shell=True)
    else:
        max_memory = args.max_memory
        port = args.port
        share = args.share
        load_type = torch.float16
        if torch.cuda.is_available():
            device = torch.device(0)
        else:
            device = torch.device('cpu')
        if args.tokenizer_path is None:
            args.tokenizer_path = args.base_model
            # if args.lora_model is None:
            #     args.tokenizer_path = args.base_model
        tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path, legacy=True)
        tokenizer.pad_token_id = 0
        # tokenizer.pad_token = "<>"
        base_model = LlamaForCausalLM.from_pretrained(
            args.base_model,
            torch_dtype=load_type,
            low_cpu_mem_usage=True,
            device_map='auto',
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=args.load_in_4bit,
                load_in_8bit=args.load_in_8bit,
                bnb_4bit_compute_dtype=load_type,
                # load_in_8bit_fp32_cpu_offload=True
            )
        )

        model_vocab_size = base_model.get_input_embeddings().weight.size(0)
        tokenizer_vocab_size = len(tokenizer)
        print(f"Vocab of the base model: {model_vocab_size}")
        print(f"Vocab of the tokenizer: {tokenizer_vocab_size}")
        if model_vocab_size != tokenizer_vocab_size:
            print("Resize model embeddings to fit tokenizer")
            base_model.resize_token_embeddings(tokenizer_vocab_size)
        if args.lora_model is not None:
            print("loading peft model")
            model = PeftModel.from_pretrained(
                base_model,
                args.lora_model,
                torch_dtype=load_type,
                device_map='auto',
            ).half()
        else:
            model = base_model

        if device == torch.device('cpu'):
            model.float()

        model.eval()


# Reset the user input
def reset_user_input():
    return gr.update(value='')


# Reset the state
def reset_state():
    return []


def generate_prompt(instruction, response="", with_system_prompt=True, system_prompt=DEFAULT_SYSTEM_PROMPT):
    if with_system_prompt is True:
        prompt = TEMPLATE_WITH_SYSTEM_PROMPT.format_map({'instruction': instruction, 'system_prompt': system_prompt})
    else:
        prompt = TEMPLATE_WITHOUT_SYSTEM_PROMPT.format_map({'instruction': instruction})
    if len(response) > 0:
        prompt += " " + response
    return prompt


# User interaction function for chat
def user(user_message, history):
    return gr.update(value="", interactive=False), history + \
        [[user_message, None]]


class Stream(StoppingCriteria):
    def __init__(self, callback_func=None):
        self.callback_func = callback_func

    def __call__(self, input_ids, scores) -> bool:
        if self.callback_func is not None:
            self.callback_func(input_ids[0])
        return False


class Iteratorize:
    """
    Transforms a function that takes a callback
    into a lazy iterator (generator).

    Adapted from: https://stackoverflow.com/a/9969000
    """
    def __init__(self, func, kwargs=None, callback=None):
        self.mfunc = func
        self.c_callback = callback
        self.q = Queue()
        self.sentinel = object()
        self.kwargs = kwargs or {}
        self.stop_now = False

        def _callback(val):
            if self.stop_now:
                raise ValueError
            self.q.put(val)

        def gentask():
            try:
                ret = self.mfunc(callback=_callback, **self.kwargs)
            except ValueError:
                pass
            except Exception:
                traceback.print_exc()

            clear_torch_cache()
            self.q.put(self.sentinel)
            if self.c_callback:
                self.c_callback(ret)

        self.thread = Thread(target=gentask)
        self.thread.start()

    def __iter__(self):
        return self

    def __next__(self):
        obj = self.q.get(True, None)
        if obj is self.sentinel:
            raise StopIteration
        else:
            return obj

    def __del__(self):
        clear_torch_cache()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop_now = True
        clear_torch_cache()


def clear_torch_cache():
    gc.collect()
    if torch.cuda.device_count() > 0:
        torch.cuda.empty_cache()


def post_http_request(prompt: str,
                      api_url: str,
                      n: int = 1,
                      top_p: float = 0.9,
                      top_k: int = 40,
                      temperature: float = 0.2,
                      max_tokens: int = 1024,
                      presence_penalty: float = 1.0,
                      use_beam_search: bool = False,
                      stream: bool = False) -> requests.Response:
    headers = {"User-Agent": "Test Client"}
    pload = {
        "prompt": prompt,
        "n": n,
        "top_p": 1 if use_beam_search else top_p,
        "top_k": -1 if use_beam_search else top_k,
        "temperature": 0 if use_beam_search else temperature,
        "max_tokens": max_tokens,
        "use_beam_search": use_beam_search,
        "best_of": 5 if use_beam_search else n,
        "presence_penalty": presence_penalty,
        "stream": stream,
    }
    print(pload)

    response = requests.post(api_url, headers=headers, json=pload, stream=True)
    return response


def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
    for chunk in response.iter_lines(chunk_size=8192,
                                     decode_unicode=False,
                                     delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"]
            yield output


# Perform prediction based on the user input and history
@torch.no_grad()
def predict(
    history,
    system_prompt,
    negative_prompt,
    max_new_tokens=1024,
    top_p=0.89,
    temperature=0.85,
    top_k=40,
    do_sample=True,
    repetition_penalty=1.2,
    guidance_scale=1.0,
    presence_penalty=0.0,
):
    if len(system_prompt) == 0:
        system_prompt = DEFAULT_SYSTEM_PROMPT
    while True:
        print("len(history):", len(history))
        print("history: ", history)
        history[-1][1] = ""
        if len(history) == 1:
            input = history[0][0]
            prompt = generate_prompt(input, response="", with_system_prompt=True, system_prompt=system_prompt)
            print(f"prompt:{prompt}")
        else:
            input = history[0][0]
            response = history[0][1]
            prompt = generate_prompt(input, response=response, with_system_prompt=True, system_prompt=system_prompt) + '</s>'
            for hist in history[1:-1]:
                input = hist[0]
                response = hist[1]
                prompt = prompt + '<s>' + generate_prompt(input, response=response, with_system_prompt=False) + '</s>'
            input = history[-1][0]
            check_text = input.replace("<br>", "").replace(" ", "").replace("\n", "")
            if len(check_text) == 0:
                input = ""
            prompt = prompt + '<s>' + generate_prompt(input, response="", with_system_prompt=False)
            print(f"prompt1:{prompt}")
        input_length = len(tokenizer.encode(prompt, add_special_tokens=True))
        print(f"Input length: {input_length}")
        if input_length > max_memory and len(history) > 1:
            print(f"The input length ({input_length}) exceeds the max memory ({max_memory}). The earlier history will be discarded.")
            history = history[1:]
            print("history: ", history)
        else:
            break

    if args.use_vllm:
        generate_params = {
            'max_tokens': max_new_tokens,
            'top_p': top_p,
            'temperature': temperature,
            'top_k': top_k,
            "use_beam_search": not do_sample,
            'presence_penalty': presence_penalty,
        }

        api_url = f"http://{args.post_host}:{args.post_port}/generate"

        response = post_http_request(prompt, api_url, **generate_params, stream=True)

        for h in get_streaming_response(response):
            for line in h:
                line = line.replace(prompt, '')
                history[-1][1] = line
                yield history

    else:
        negative_text = None
        if len(negative_prompt) != 0:
            negative_text = re.sub(r"<<SYS>>\n(.*)\n<</SYS>>", f"<<SYS>>\n{negative_prompt}\n<</SYS>>", prompt)
        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)
        if negative_text is None:
            negative_prompt_ids = None
            negative_prompt_attention_mask = None
        else:
            negative_inputs = tokenizer(negative_text, return_tensors="pt")
            negative_prompt_ids = negative_inputs["input_ids"].to(device)
            negative_prompt_attention_mask = negative_inputs["attention_mask"].to(device)
        generate_params = {
            'input_ids': input_ids,
            'max_new_tokens': max_new_tokens,
            'top_p': top_p,
            'temperature': temperature,
            'top_k': top_k,
            'do_sample': do_sample,
            'repetition_penalty': repetition_penalty,
        }
        if ENABLE_CFG_SAMPLING is True:
            generate_params['guidance_scale'] = guidance_scale
            generate_params['negative_prompt_ids'] = negative_prompt_ids
            generate_params['negative_prompt_attention_mask'] = negative_prompt_attention_mask

        def generate_with_callback(callback=None, **kwargs):
            if 'stopping_criteria' in kwargs:
                kwargs['stopping_criteria'].append(Stream(callback_func=callback))
            else:
                kwargs['stopping_criteria'] = [Stream(callback_func=callback)]
            clear_torch_cache()
            with torch.no_grad():
                model.generate(**kwargs)

        def generate_with_streaming(**kwargs):
            return Iteratorize(generate_with_callback, kwargs, callback=None)

        with generate_with_streaming(**generate_params) as generator:
            for output in generator:
                next_token_ids = output[len(input_ids[0]):]
                if next_token_ids[0] in [tokenizer.eos_token_id, 0]:
                    break
                new_tokens = tokenizer.decode(
                    next_token_ids, skip_special_tokens=True)
                if isinstance(tokenizer, LlamaTokenizer) and len(next_token_ids) > 0:
                    if tokenizer.convert_ids_to_tokens(int(next_token_ids[0])).startswith('▁'):
                        new_tokens = ' ' + new_tokens

                history[-1][1] = new_tokens
                yield history
                if len(next_token_ids) >= max_new_tokens:
                    break


# Call the setup function to initialize the components
setup()


# Create the Gradio interface
with gr.Blocks(
        theme=gr.themes.Soft(),
        css=".disclaimer {font-variant-caps: all-small-caps;}") as demo:
    github_banner_path = 'https://raw.githubusercontent.com/moseshu/llama2-chat/main/llama2.jpg'
    gr.HTML(f'<p align="center"><a href="https://huggingface.co/Moses25/Llama2-Moses-7b-chat"><img src={github_banner_path} width="100" height="40"/>Llama2-Moses-7b</a></p>')
    chatbot = gr.Chatbot().style(height=300)
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=3):
                system_prompt_input = gr.Textbox(
                    show_label=True,
                    label="System prompt (changes only take effect before the conversation starts or after clearing the history; edits made mid-conversation are ignored)",
                    placeholder=DEFAULT_SYSTEM_PROMPT,
                    lines=1).style(
                    container=True)
                negative_prompt_input = gr.Textbox(
                    show_label=True,
                    label="Negative prompt (changes only take effect before the conversation starts or after clearing the history; edits made mid-conversation are ignored)",
                    placeholder="optional",
                    lines=1,
                    visible=ENABLE_CFG_SAMPLING).style(
                    container=True)
            with gr.Column(scale=10):
                user_input = gr.Textbox(
                    show_label=True,
                    label="ChatBox",
                    placeholder="Press Shift + Enter to send a message...",
                    lines=10).style(
                    container=True)
            with gr.Column(min_width=24, scale=1):
                with gr.Row():
                    stop = gr.Button("Stop", variant='stop')
                    submitBtn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_new_token = gr.Slider(
                0,
                4096,
                value=1024,
                step=1.0,
                label="Maximum New Token Length",
                interactive=True)
            top_p = gr.Slider(0, 1, value=0.9, step=0.01,
                              label="Top P", interactive=True)
            temperature = gr.Slider(
                0,
                1,
                value=0.7,
                step=0.01,
                label="Temperature",
                interactive=True)
            top_k = gr.Slider(1, 40, value=40, step=1,
                              label="Top K", interactive=True)
            do_sample = gr.Checkbox(
                value=True,
                label="Do Sample",
                info="use random sample strategy",
                interactive=True)
            repetition_penalty = gr.Slider(
                1.0,
                3.0,
                value=1.1,
                step=0.1,
                label="Repetition Penalty",
                interactive=True,
                visible=False if args.use_vllm else True)
            guidance_scale = gr.Slider(
                1.0,
                3.0,
                value=1.0,
                step=0.1,
                label="Guidance Scale",
                interactive=True,
                visible=ENABLE_CFG_SAMPLING)
            presence_penalty = gr.Slider(
                -2.0,
                2.0,
                value=1.0,
                step=0.1,
                label="Presence Penalty",
                interactive=True,
                visible=True if args.use_vllm else False)

    params = [user_input, chatbot]
    predict_params = [
        chatbot,
        system_prompt_input,
        negative_prompt_input,
        max_new_token,
        top_p,
        temperature,
        top_k,
        do_sample,
        repetition_penalty,
        guidance_scale,
        presence_penalty]
    with gr.Row():
        gr.Markdown(
            "Disclaimer: This model may produce outputs that are inconsistent with the facts and should not be relied upon for factually accurate information. The model was trained on various public datasets as well as some Dewu product information. Although extensive data cleaning was performed, the model's outputs may still contain issues.",
            elem_classes=["disclaimer"],
        )
    submit_click_event = submitBtn.click(
        user,
        params,
        params,
        queue=False).then(
        predict,
        predict_params,
        chatbot).then(
        lambda: gr.update(
            interactive=True),
        None,
        [user_input],
        queue=True)

    submit_event = user_input.submit(
        user,
        params,
        params,
        queue=False).then(
        predict,
        predict_params,
        chatbot).then(
        lambda: gr.update(
            interactive=True),
        None,
        [user_input],
        queue=True)

    submitBtn.click(reset_user_input, [], [user_input])

    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[submit_event, submit_click_event],
        queue=False,
    )
    emptyBtn.click(reset_state, outputs=[chatbot], show_progress=True)


# Launch the Gradio interface
demo.queue().launch(
    share=share,
    inbrowser=True,
    server_name='0.0.0.0',
    server_port=port)
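As a usage note: assuming attn_and_long_ctx_patches.py is importable (in this commit it sits alongside gradio_demo.py) and the model weights are available locally or on the Hub, a typical single-GPU launch would look something like

python gradio_demo.py --base_model Moses25/Llama2-Moses-7b-chat --load_in_4bit --alpha 1.0 --port 19324

where --base_model is the only required argument and the model id above is just the checkpoint linked in the banner; the quantization, --alpha, --gpus, and --use_vllm flags defined in the argument parser are optional. The script applies the attention and NTK patches, loads the tokenizer and (optionally LoRA-wrapped) model in setup(), and then serves the chat UI on 0.0.0.0 at the chosen port.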