BeveledCube commited on
Commit
f4b0576
·
verified ·
1 Parent(s): 75fe857

Create mamba.py

Browse files
Files changed (1) hide show
  1. models/mamba.py +16 -0
models/mamba.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
+
3
+ model_name = "s3nh/mamba-gpt-3b-v3-GGML"
4
+
5
+ def load():
6
+ global model
7
+ global tokenizer
8
+
9
+ model = AutoModelForCausalLM.from_pretrained(model_name)
10
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+
12
+ def generate(input_text):
13
+ input_ids = tokenizer.encode(input_text, return_tensors="pt")
14
+ output_ids = model.generate(input_ids, num_beams=5, no_repeat_ngram_size=2)
15
+
16
+ return tokenizer.decode(output_ids[0], skip_special_tokens=True)