# PaLM_models / app.py
import torch
from transformers import AutoTokenizer
from palm_rlhf_pytorch import PaLM
import gradio as gr
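
# Gradio demo that serves a ~410M-parameter PaLM (palm_rlhf_pytorch) checkpoint
# for text generation on CPU, using the GPT-NeoX-20B tokenizer.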
device = torch.device("cpu")

# Architecture hyperparameters; these must match the palm_410m_8k_v0 checkpoint
# so the state dict loads without shape mismatches.
num_tokens = 50304
dim = 1024
depth = 24
dim_head = 128
heads = 8
flash_attn = False

model = PaLM(
    num_tokens=num_tokens, dim=dim, depth=depth, dim_head=dim_head,
    heads=heads, flash_attn=flash_attn, qk_rmsnorm=False,
).to(device).eval()

# Load the pretrained weights once at startup instead of on every request.
checkpoint = torch.load('./palm_410m_8k_v0.pt', map_location=device)
model.load_state_dict(checkpoint)

# GPT-NeoX-20B tokenizer for encoding prompts and decoding model output.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")


def generate(prompt, seq_len=128, temperature=0.8, filter_thres=0.9):
    encoded_text = tokenizer(prompt, return_tensors="pt")
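    # Sample autoregressively from the model; generation runs on CPU, so long
    # sequence lengths can be slow. use_tqdm=True prints progress to the server logs.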
    output_tensor = model.generate(
        seq_len=seq_len,
        prompt=encoded_text["input_ids"].to(device),
        temperature=temperature,
        filter_thres=filter_thres,  # logit filtering threshold
        pad_value=0.0,
        eos_token=tokenizer.eos_token_id,
        return_seq_without_prompt=False,  # keep the prompt in the returned sequence
        use_tqdm=True,
    )
    # output_tensor has shape (batch, seq); decode the single sequence in the batch
    # (batch_decode on a 1-D tensor would decode token by token and return a list).
    decoded_output = tokenizer.decode(output_tensor[0], skip_special_tokens=True)
    return decoded_output
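
# Note: generate() also works outside Gradio, e.g. generate("Once upon a time", seq_len=64).
# Because return_seq_without_prompt=False, the result includes the prompt followed by
# the sampled continuation.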
iface = gr.Interface(
    fn=generate,
    title="PaLM",
    description="Open-source PaLM demo.",
    # Inputs map positionally onto generate(prompt, seq_len, temperature, filter_thres).
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(minimum=1, maximum=8192, step=1, value=128, label="Sequence Length"),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.8, label="Temperature"),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.9, label="Filter Threshold"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
)
iface.launch()
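
# Run with `python app.py`; Gradio serves the demo at http://127.0.0.1:7860 by default.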