# NOTE: removed non-code page-scrape artifacts that preceded this file
# ("Spaces:" / "Runtime error" status lines from the hosting page).
# -*- coding: utf-8 -*-
import gradio as gr
import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import os
# --- Model setup -------------------------------------------------------------
# Load the locally fine-tuned model through a text-generation pipeline and
# expose the underlying model/tokenizer for manual streamed generation below.
model_id = './model'
CUDA_AVAILABLE = torch.cuda.is_available()
device = torch.device("cuda" if CUDA_AVAILABLE else "cpu")
generator = pipeline('text-generation', model=model_id,
                     tokenizer=model_id,
                     # NOTE(review): load_in_8bit needs bitsandbytes + CUDA —
                     # confirm it is intended on CPU-only hosts.
                     load_in_8bit=True,
                     device=device)
# Generation is cut off as soon as this pattern shows up in the output.
early_stop_pattern = "\n\n\n"
print(f'Early stop pattern = \"{early_stop_pattern}\"')
model = generator.model
tok = generator.tokenizer
# End-of-sequence token string; also used as a stop marker while streaming.
stop_token = tok.eos_token
print(f'stop_token = \"{stop_token}\"')
def generate(text = ""):
    """Stream generated Hebrew text that continues *text*.

    Runs ``model.generate`` in a background thread and yields the
    accumulated output chunk by chunk as it arrives from the streamer.
    Stops early when ``early_stop_pattern`` appears in the accumulated
    text or the EOS token appears in the latest chunk, trimming the
    output at the first stop marker before the final yield.
    """
    print("Create streamer")
    # Placeholder shown in the UI while the first tokens are produced.
    yield "[ืื ื ืืืชืื ื ืืชืฉืืื]"
    streamer = TextIteratorStreamer(tok, timeout=5.)
    if len(text) == 0:
        # The model needs a non-empty prompt; use a bare newline.
        text = "\n"
    inputs = tok([text], return_tensors="pt").to(device)
    generation_kwargs = dict(inputs, streamer=streamer, repetition_penalty=2.5,
                             do_sample=True, top_k=40, top_p=0.2, temperature=0.4,
                             num_beams=1, max_new_tokens=128,
                             pad_token_id=model.config.eos_token_id,
                             early_stopping=True, no_repeat_ngram_size=4)
    # Generate in a worker thread so this generator can consume the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    generated_text = ""
    for new_text in streamer:
        yield generated_text + new_text
        print(new_text, end="")
        generated_text += new_text
        # Guard stop_token against None (some tokenizers define no eos token).
        if (early_stop_pattern in generated_text) or (stop_token and stop_token in new_text):
            # BUGFIX: only trim at a marker that is actually present.
            # str.find returns -1 for a missing marker, and slicing with
            # [:-1] silently dropped the last character of the output
            # whenever just one of the two markers had matched.
            for marker in (early_stop_pattern, stop_token):
                if marker:
                    cut = generated_text.find(marker)
                    if cut != -1:
                        generated_text = generated_text[:cut]
            streamer.end()
            print("\n--\n")
            yield generated_text
            return generated_text
    return generated_text
# --- UI ----------------------------------------------------------------------
# Gradio interface: right-to-left Hebrew text boxes wired to the streaming
# `generate` function. Queuing is required for generator-based streaming.
demo = gr.Interface(
    title="Hebrew text generator: Science Fiction and Fantasy (GPT-Neo)",
    fn=generate,
    inputs=gr.Textbox(label="ืืชืื ืืื ืืช ืืืงืกื ืฉืืื ืื ืืฉืืืจื ืจืืง", elem_id="input_text"),
    outputs=gr.Textbox(type="text", label="ืคื ืืืคืืข ืืืงืกื ืฉืืืืืื ืืืืื", elem_id="output_text"),
    # Force RTL rendering for both Hebrew text boxes.
    css="#output_text{direction: rtl} #input_text{direction: rtl}",
    examples=['ืืฉื ืืืคืืข ืืื', 'ืงืืื ืฉืืคื ืืช', 'ืคืขื ืืืช ืืคื ื ืฉื ืื ืจืืืช', 'ืืืจื ืคืืืจ ืืืื ืืืื ื ืืื', 'ืืื ืืคืจืชื ืืช ืื ืืืื ืืืงืก ืืฉ'],
    allow_flagging="never"
)
demo.queue()
demo.launch()