# Hugging Face Spaces page header residue: "Spaces: Sleeping"
import functools

import gradio as gr
import torch
from palm_rlhf_pytorch import PaLM
from transformers import AutoTokenizer
@functools.lru_cache(maxsize=1)
def _load_model_and_tokenizer():
    """Build the PaLM model, load its checkpoint, and fetch the tokenizer.

    Cached so the weights are read from disk and the tokenizer is fetched
    only once per process, instead of on every generation request.

    Returns:
        A ``(model, tokenizer, device)`` tuple; the model is already moved to
        ``device`` and switched to eval mode.
    """
    device = torch.device("cpu")
    model = PaLM(
        num_tokens=50304,
        dim=1024,
        depth=24,
        dim_head=128,
        heads=8,
        flash_attn=False,
        qk_rmsnorm=False,
    ).to(device).eval()
    # NOTE(review): torch.load unpickles the file, so only point this at a
    # trusted checkpoint.
    checkpoint = torch.load("./palm_410m_8k_v0.pt", map_location=device)
    model.load_state_dict(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    return model, tokenizer, device


def generate(prompt, seq_len=128, temperature=0.8, filter_thres=0.9):
    """Sample a continuation of *prompt* from the PaLM checkpoint.

    Args:
        prompt: Input text to continue.
        seq_len: Passed through as ``seq_len`` to ``PaLM.generate``.
        temperature: Sampling temperature passed to ``PaLM.generate``.
        filter_thres: Filtering threshold passed to ``PaLM.generate``.

    Returns:
        The decoded text of the first sequence produced by ``PaLM.generate``.
        The prompt is included because ``return_seq_without_prompt=False``.
        Special tokens are stripped.
    """
    model, tokenizer, device = _load_model_and_tokenizer()
    encoded_text = tokenizer(prompt, return_tensors="pt")
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output_tensor = model.generate(
            seq_len=seq_len,
            prompt=encoded_text["input_ids"].to(device),
            temperature=temperature,
            filter_thres=filter_thres,
            pad_value=0.0,
            eos_token=tokenizer.eos_token_id,
            return_seq_without_prompt=False,
            use_tqdm=True,
        )
    # Assumes generate returns batch-first ids, so output_tensor[0] is one
    # token sequence. decode() turns it into a single string, whereas
    # batch_decode() would treat every id as its own sequence and return a
    # list of one-token strings.
    return tokenizer.decode(output_tensor[0], skip_special_tokens=True)
# Gradio front end: one free-text prompt box in, generated text out.
# Only the prompt is wired up as an input, so generate() always runs with its
# keyword defaults for seq_len, temperature and filter_thres; exposing them
# would mean adding components to `inputs`, not extra Interface kwargs.
iface = gr.Interface(
    fn=generate,
    inputs="text",
    outputs="text",
    title="PaLM",
    description="Open-source PaLM demo.",
)

# Start the web server (blocks until the app is shut down).
iface.launch()