|
|
|
import torch |
|
from transformers import GPT2Tokenizer, GPT2LMHeadModel,GPT2Model |
|
from transformers import TFAutoModel |
|
|
|
# Directory the model weights and tokenizer files are loaded from
# (presumably a fine-tuned GPT-2 joke checkpoint -- confirm against training).
model_path = "./trained_gpt2_jokes/5/"

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# from_tf=True asks from_pretrained to read the weights from a TensorFlow
# checkpoint; the resulting PyTorch module is then moved to the chosen device.
model = GPT2LMHeadModel.from_pretrained(model_path, from_tf=True)
model = model.to(device)

tokenizer = GPT2Tokenizer.from_pretrained(model_path)

# Put the module in evaluation mode before running generation.
model.eval()
|
|
|
def generate_text(prompt_text, max_length=100, temperature=0.5):
    """Generate a continuation of ``prompt_text`` with the module-level model.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        prompt_text: Text the model should continue.
        max_length: Upper bound passed to ``model.generate`` as ``max_length``
            (prompt tokens included). Defaults to 100.
        temperature: Sampling temperature passed to ``model.generate``.
            Defaults to 0.5.

    Returns:
        The decoded text of the single generated sequence, with special
        tokens stripped.
    """
    # Calling the tokenizer (instead of ``encode``) also returns the
    # attention mask, which ``generate`` needs because the pad token is
    # aliased to EOS below and therefore cannot be inferred from the ids.
    encoded = tokenizer(prompt_text, return_tensors="pt").to(device)

    with torch.no_grad():
        output = model.generate(
            encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            max_length=max_length,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
            # Without do_sample=True decoding is greedy and ``temperature``
            # is ignored; enable sampling so the temperature takes effect.
            do_sample=True,
            temperature=temperature,
        )

    # Only one sequence is requested, so decode the first (and only) row.
    return tokenizer.decode(output[0], skip_special_tokens=True)
|
|
|
# Seed text handed to the model for continuation.
prompt = "Why man are "

# Produce one continuation for the seed text and write it to stdout.
generated_joke = generate_text(prompt)
print(generated_joke)
|
|
|
|
|
|
|
|
|
|