# File size: 528 Bytes
# Revisions: 455a257, de02b26
from transformers import (
    FlaxGPT2LMHeadModel,
    GPT2Model,
    GPT2Tokenizer,
    pipeline,
    set_seed,
)
 
# Hugging Face Hub id of the Swedish GPT-2 model used throughout this script.
MODEL_ID = "flax-community/swe-gpt-wiki"

tokenizer = GPT2Tokenizer.from_pretrained(MODEL_ID)
model = FlaxGPT2LMHeadModel.from_pretrained(MODEL_ID)

# Single forward pass with the Flax model. Flax models' __call__ takes no
# `labels` argument (passing one raises TypeError) and computes no loss, so
# only the logits are read here.
inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
outputs = model(**inputs)
logits = outputs.logits  # presumably (batch, seq_len, vocab_size) — confirm

# Sampling via the high-level pipeline. Pipelines run PyTorch/TensorFlow
# models, not Flax model instances, so hand over the hub id (plus the
# tokenizer already loaded) and let the pipeline load a supported backend.
# NOTE(review): this needs PyTorch or TF weights in the hub repo (or loading
# with from_flax=True via model_kwargs) — confirm the repo provides them.
generator = pipeline("text-generation", model=MODEL_ID, tokenizer=tokenizer)
set_seed(42)  # seed the RNGs so the sampled continuations are reproducible
results = generator(
    "Hej, jag är en språkmodell,", max_length=30, num_return_sequences=5
)
# In a script (unlike a REPL) the return value is not echoed; print it.
for result in results:
    print(result["generated_text"])