"""Demo of the Swedish GPT-2 model ``flax-community/swe-gpt-wiki``.

The script does two things:

1. Runs a single forward pass with the Flax (JAX) LM-head model and keeps
   the next-token logits.
2. Generates sampled continuations of a Swedish prompt with the
   ``text-generation`` pipeline.
"""
from transformers import (
    FlaxGPT2LMHeadModel,
    GPT2LMHeadModel,
    GPT2Model,
    GPT2Tokenizer,
    pipeline,
    set_seed,
)

MODEL_ID = "flax-community/swe-gpt-wiki"

tokenizer = GPT2Tokenizer.from_pretrained(MODEL_ID)

# Flax (JAX) language-model head; the checkpoint is a Flax one.
model = FlaxGPT2LMHeadModel.from_pretrained(MODEL_ID)

inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")

# Flax LM-head models take no `labels` argument and do not compute a loss;
# they only return logits, so the inputs are passed through unchanged.
outputs = model(**inputs)

logits = outputs.logits

# `pipeline` only supports PyTorch/TensorFlow models, not Flax ones, so load a
# PyTorch copy of the same weights converted from the Flax checkpoint.
generation_model = GPT2LMHeadModel.from_pretrained(MODEL_ID, from_flax=True)
generator = pipeline("text-generation", model=generation_model, tokenizer=tokenizer)

set_seed(42)

# Sampling is required for the seed to matter and for five distinct
# sequences; greedy decoding does not support num_return_sequences > 1.
results = generator(
    "Hej, jag är en språkmodell,",
    max_length=30,
    num_return_sequences=5,
    do_sample=True,
)
for result in results:
    print(result["generated_text"])