mamba / app.py
Vaibhav Srivastav
up
db4c88c
raw
history blame
1.11 kB
import torch
import torch.nn.functional as F
from einops import rearrange
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
# Hugging Face Hub identifiers for the demo. The Mamba checkpoint is paired
# with the GPT-NeoX-20B tokenizer.
_TOKENIZER_ID = "EleutherAI/gpt-neox-20b"
_MODEL_ID = "state-spaces/mamba-130m"

# Target device for the model and for the input tensors built in pred().
device = "cuda"

# Both are loaded once, at import time, and shared by every request.
tokenizer = AutoTokenizer.from_pretrained(_TOKENIZER_ID)
model = MambaLMHeadModel.from_pretrained(
    _MODEL_ID,
    device=device,
    dtype=torch.float16,  # half-precision weights
)
def pred(text_in, *, max_new_tokens=100):
    """Continue ``text_in`` with greedy decoding from the Mamba model.

    Args:
        text_in: Prompt text, e.g. the contents of the Gradio textbox.
        max_new_tokens: Number of tokens to generate beyond the prompt.
            Keyword-only; defaults to 100, the budget previously hard-coded.

    Returns:
        The list produced by ``tokenizer.batch_decode`` for the generated
        sequences (the prompt followed by its continuation). An empty or
        ``None`` prompt returns ``[""]`` without calling the model.
    """
    # A zero-length prompt gives the model nothing to condition on and
    # presumably fails inside generate(), so short-circuit it.
    if not text_in:
        return [""]

    input_ids = tokenizer(text_in, return_tensors="pt").input_ids.to(device=device)
    # generate() takes a total length, so add the prompt length to the budget.
    max_length = input_ids.shape[1] + max_new_tokens

    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=True,  # NOTE(review): presumably mamba_ssm's CUDA-graph decode path — confirm
        return_dict_in_generate=True,  # required so the result exposes .sequences
        enable_timing=False,
        # top_k=1 keeps only the highest-scoring token, i.e. greedy decoding.
        temperature=1.0,
        top_k=1,
        top_p=1.0,
    )
    return tokenizer.batch_decode(out.sequences.tolist())
def _pred_text(text_in):
    """Adapt ``pred`` to a single Gradio ``"text"`` output.

    ``pred`` returns the list from ``tokenizer.batch_decode``. Gradio's text
    output stringifies whatever it receives, so a list would be rendered as
    its Python repr (``['...']``). Joining the decoded strings yields plain text.
    """
    return "\n".join(pred(text_in))


demo = gr.Interface(fn=_pred_text, inputs="text", outputs="text")

if __name__ == "__main__":
    demo.launch()