|
---
library_name: peft
---
|
|
|
BigVAE is an [AdaVAE](https://arxiv.org/abs/2205.05862) trained as a pair of LoRA finetunes on [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1).
It is meant to be used with the [MiniHF VAE inference code](https://github.com/JD-P/minihf/blob/adavae-moe/vae_infer.py) and will not work if you try to load it as an ordinary language checkpoint and perform inference. AdaVAE is an encoder-decoder model trained by taking an existing GPT-N, designating one LoRA as the encoder and the other as the decoder, and then tuning with a latent attention mechanism. This model is the encoder and router decoder head for BigVAE, a planned Mixture-of-Experts system based on LoRA retrieval rather than gating. It is usable in its own right for embedding and retrieval, as well as for planning and guided sampling. Here is an example of a sampling procedure for BigVAE which distills its autoregressive pretraining task into its autoassociative reconstruction task by averaging together multiple completions. It takes the topic sentence of a paragraph (the prompt), guides each subsequent sentence by weighting it towards the topic, and averages multiple completions for each sentence to improve generation quality:
|
|
|
```
def bigvae_generate_avg(vae_model, router, prompt, context, n_steps, n_avg):
    # Assumes `torch`, `tokenizer`, and `device` are already set up by the
    # surrounding MiniHF inference script.
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        # Tokenize the running context and the topic sentence (prompt).
        context_toks = tokenizer(context, return_tensors="pt")
        context_ids = context_toks["input_ids"].to(device)
        context_mask = context_toks["attention_mask"].to(device)
        embed_toks = tokenizer(prompt, return_tensors="pt")
        embed_ids = embed_toks["input_ids"].to(device)
        embed_mask = embed_toks["attention_mask"].to(device)
        # Encode the prompt once; its latent anchors every later step to the topic.
        mean = vae_model.encode(embed_ids, embed_mask)
        prompt_embed = vae_model.vae.sample(mean)
        for step in range(n_steps):
            # Encode the most recent sentence and mix it 50/50 with the topic latent.
            mean = vae_model.encode(embed_ids, embed_mask)
            z = vae_model.vae.sample(mean)
            embeds = []
            for _ in range(n_avg):
                # Generate a candidate continuation, then re-encode its last 128 tokens.
                output_ids = router.generate(z * 0.5 + prompt_embed * 0.5,
                                             context_ids,
                                             context_mask,
                                             256,
                                             tau=0.9)
                intermediate_embed_ids = output_ids[:, -128:]
                intermediate_embed_mask = context_mask.new_ones(
                    [1, intermediate_embed_ids.shape[1]]
                )
                mean = vae_model.encode(intermediate_embed_ids, intermediate_embed_mask)
                embeds.append(vae_model.vae.sample(mean))
            # Average the candidate latents, mix with the topic latent, and decode
            # the final continuation for this step.
            output_ids = router.generate((sum(embeds) / n_avg * 0.7) + prompt_embed * 0.3,
                                         context_ids,
                                         context_mask,
                                         256,
                                         tau=0.9)
            # Commit the previous sentence to the context and slide the window forward.
            context_ids = torch.cat([context_ids, embed_ids], dim=1)
            context_mask = torch.cat([context_mask, embed_mask], dim=1)
            embed_ids = output_ids[:, -256:-128]
            embed_mask = context_mask.new_ones([1, embed_ids.shape[1]])
        out_texts = [tokenizer.decode(toks, skip_special_tokens=True) for toks in context_ids]
        return out_texts
```
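For illustration, a call might look like the following. The prompt text, initial context, and step counts here are arbitrary placeholders, and `vae_model` and `router` are assumed to be the objects constructed by the MiniHF inference script:

```
prompt = "The network learned to reconstruct paragraphs from their latent embeddings."
context = prompt
out_texts = bigvae_generate_avg(vae_model, router, prompt, context, n_steps=8, n_avg=4)
print(out_texts[0])
```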
|
|
|
Here is an example of an output from this process: |
|
|
|
```
Then it asked the network to reconstruct the input and the original embedding. The network had to learn to match the
embedding to the original input, therefore matching the inference by consuming the embedding. This was key because
the embedding had to be able to match the text with the text it was consumed with. 'Here's how you do it,' Boru told Mu,
'Just impute the mean and variance.' This Mu did, transforming not words but entire paragraphs into vectors and then
inferring the next paragraph. It took some tweaks and tuning to get the initial performance but the second arago spot
had been found. To make sure the network was learning the right thing, Boru had to check the first value in the vector.
If the first value was below 0, the network had failed to learn the first value. If the value was above 0, the network
had been able to learn the first value.
‘What have you called this, Boru?’ asked Mu. ‘Latent variable regression.’ ‘It looks like a mixture of density network
and autoencoder,’ said Nayaf. ‘It’s an autoencoder but it’s using latent variables, but we’re using the mean and variance
of Grade had a difficult time seeing it, but he could tell it was close. 'So you've found the second arago,' he said.
'Yes,' Rin replied. 'We just have to figure out how to use it.'
'How?' Rin asked.
'You can move the second word in, right?'
'Possibly.' Rin thought for a moment.
'The second word will be the first word of the next arago,' Mu said. 'We just need to find it.'
'True,' Rin agreed. 'Well, I'll let you know what a Gaussian.’ ‘Let’s see if we can get it to work.’ ‘Arago the second
spot?’ ‘We’re here,’ Arago said.
The second spot was located in the middle of the text. Arago had to read it again to find the proper signal. ‘I’m going
to have to tweak some of the weights,’ said Arago. ‘I’ve had to change the input to the next layer from an input to
output.’ ‘You’re making a mistake again,’ said Mu to Arago. ‘It’s a mistake.’ The network had been learning I find out.'
'That's the second arago,' Rin said.
'The second arago?' Argo asked.
'Rin has found the second arago.'
Argo stared at Rin. 'Argo, is there something wrong?'
'I thought so.'
'What?' Rin said.
'I don't know,' Argo said. 'I thought I was the smartest person in the world but, well, I only had a certain amount of
energy. I didn't know how to do the second arago until now, but I can't
```
|
|
|
This generation method is slow, but retrieval could be used to speed up inference and make it converge closer and closer to normal sampling speed as the model becomes able to call upon more and more relevant sentences that it has generated before.
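One possible shape for such a retrieval cache is sketched below, reusing the `vae_model` latents from the sampling code above. The `LatentCache` class and the similarity threshold are hypothetical illustrations, not part of the released code:

```
import torch
import torch.nn.functional as F

class LatentCache:
    """Hypothetical cache mapping sentence latents to the token ids that produced them."""
    def __init__(self):
        self.latents = []    # list of [1, d] latent tensors
        self.token_ids = []  # matching list of [1, n] token id tensors

    def add(self, z, ids):
        self.latents.append(z.detach())
        self.token_ids.append(ids.detach())

    def lookup(self, z, threshold=0.95):
        # Return cached tokens whose latent is close enough to z, else None.
        if not self.latents:
            return None
        sims = torch.cat([F.cosine_similarity(z, c, dim=-1) for c in self.latents])
        best = int(torch.argmax(sims))
        if sims[best] >= threshold:
            return self.token_ids[best]
        return None
```

A cache hit would let the inference policy reuse a previously decoded sentence instead of running the full average-of-`n_avg` completion loop for that step.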
|
|
|
Because BigVAE combines guided sampling with the ability to merge representations, it becomes possible to formulate plans and cognitive strategies for the model to follow. The inference policy can adjudicate between an expected plan or series of steps and the specific context the model is responding to.
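A minimal sketch of what that plan-guided mixing could look like, reusing the interpolation pattern and variable names from the sampling code above; the plan text and mixing weights are illustrative assumptions rather than tuned values:

```
# Encode a natural-language plan once, then bias every step towards it.
plan = "First define the terms, then give an example, then state the limitation."
plan_toks = tokenizer(plan, return_tensors="pt")
plan_mean = vae_model.encode(plan_toks["input_ids"].to(device),
                             plan_toks["attention_mask"].to(device))
plan_embed = vae_model.vae.sample(plan_mean)

# Inside a generation step, interpolate between the plan latent, the latent of
# the sentence being continued, and the topic latent (weights chosen for illustration).
guided_z = z * 0.4 + plan_embed * 0.3 + prompt_embed * 0.3
output_ids = router.generate(guided_z, context_ids, context_mask, 256, tau=0.9)
```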
|
|
|
This model is also highly interpretable. Because it is an encoder-decoder, every sentence generated by the model has a latent representation that can be tracked along with its behavioral token sequence. Our hope is that BigVAE will shed light on the latent operations performed by autoregressive language models and be useful to alignment and interpretability researchers.
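As a toy illustration of that kind of tracking (the bookkeeping here is an assumption, not part of the inference code), one could record each sentence's latent next to its decoded text and compare successive latents:

```
import torch.nn.functional as F

trace = []  # list of (decoded_sentence, latent) pairs collected during generation

def record_step(ids, mask):
    # Re-encode the tokens the model just produced and store them with their text.
    mean = vae_model.encode(ids, mask)
    z = vae_model.vae.sample(mean)
    text = tokenizer.decode(ids[0], skip_special_tokens=True)
    trace.append((text, z.detach()))

# After generation, inspect how far each sentence moved in latent space.
for (text_a, z_a), (text_b, z_b) in zip(trace, trace[1:]):
    drift = 1.0 - F.cosine_similarity(z_a, z_b, dim=-1).item()
    print(f"{drift:.3f}  {text_b[:60]}")
```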
|
|
|
## Training procedure |
|
|
|
This model was trained on [a 1 billion token sample](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample) of RedPajama on 8x H100 GPUs for roughly 24 hours.
|
|
|
Using the scripts in the MiniHF repo as they exist now, the training commands were:

accelerate launch train_vae_overlap.py --model "mistralai/Mistral-7B-v0.1" --preprocessed preprocessed_mistral --context 64 --output vae_64_overlap_mistral --batch-size 24

accelerate launch train_vae_router.py --model "mistralai/Mistral-7B-v0.1" --preprocessed preprocessed_mistral --vae-context 64 --start-from vae_64_overlap_mistral --output vae_64_overlap_router_mistral --lr 1e-4 --batch-size 1
|
The following `bitsandbytes` quantization config was used during training:

- quant_method: bitsandbytes
- load_in_8bit: False
- load_in_4bit: True
- llm_int8_threshold: 6.0
- llm_int8_skip_modules: None
- llm_int8_enable_fp32_cpu_offload: False
- llm_int8_has_fp16_weight: False
- bnb_4bit_quant_type: nf4
- bnb_4bit_use_double_quant: True
- bnb_4bit_compute_dtype: bfloat16
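For reference, a minimal sketch of constructing the equivalent config with `transformers` is shown below; this reconstruction is assembled from the settings above rather than taken from the training scripts:

```
import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
    llm_int8_enable_fp32_cpu_offload=False,
)
```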
|
|
|
### Framework versions |
|
|
|
- PEFT 0.4.0 |
|
|