# -*- coding: utf-8 -*-
"""Gradio demo: streaming Hebrew text generation with a local causal LM.

Loads a model and tokenizer from ``./model`` through a ``text-generation``
pipeline, then serves a right-to-left Gradio interface whose ``generate``
callback streams the model's continuation of the user's prompt.
"""
import gradio as gr
import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import os

model_id = './model'
CUDA_AVAILABLE = torch.cuda.is_available()
device = torch.device("cuda" if CUDA_AVAILABLE else "cpu")

# NOTE(review): `load_in_8bit` is passed directly to `pipeline`; the usual
# route for model-loading options is `model_kwargs={...}` — confirm the
# installed transformers version actually honors it here.
generator = pipeline('text-generation', model=model_id, tokenizer=model_id, load_in_8bit=True, device=device)

# Generation is cut off once the text contains three consecutive newlines
# (treated as the end of a passage).
early_stop_pattern = "\n\n\n"
print(f'Early stop pattern = \"{early_stop_pattern}\"')

model = generator.model
tok = generator.tokenizer
# Text of the tokenizer's EOS token; seeing it in the stream also ends
# generation. The checks in `generate` tolerate it being None or empty.
stop_token = tok.eos_token
print(f'stop_token = \"{stop_token}\"')


def _truncate_at(text, patterns):
    """Cut *text* at the first occurrence of each non-empty pattern, in order.

    Patterns that are empty/None, or that do not occur in *text*, leave it
    unchanged. ``str.find`` returns -1 for an absent pattern, so its result
    must not be used as a slice bound unchecked (``text[:-1]`` would silently
    drop the last character).
    """
    for pattern in patterns:
        if not pattern:
            continue
        cut = text.find(pattern)
        if cut != -1:
            text = text[:cut]
    return text


def generate(text=""):
    """Stream a continuation of *text* for the Gradio output box.

    First yields a Hebrew "please wait" placeholder. Then it yields the
    cumulative decoded text after every chunk produced by the background
    ``model.generate`` thread. The streamer is created without ``skip_prompt``,
    so presumably the prompt is echoed at the start of that text.

    When ``early_stop_pattern`` appears in the accumulated text, or
    ``stop_token`` appears in a new chunk, the text is cut at the first
    occurrence of either marker. The trimmed text is then yielded and
    returned.

    Args:
        text: The prompt. An empty (or None) prompt is replaced by a single
            newline so the model still has input to condition on.

    Returns:
        The final (possibly trimmed) generated text.
    """
    print("Create streamer")
    # Placeholder shown immediately; Hebrew for "Please wait for the answer".
    yield "[אנא המתינו לתשובה]"
    # NOTE(review): with timeout=5 the streamer iterator presumably raises
    # (queue.Empty) if no text arrives for 5 s, e.g. on a slow CPU run.
    streamer = TextIteratorStreamer(tok, timeout=5.)
    if not text:
        text = "\n"
    inputs = tok([text], return_tensors="pt").to(device)
    # Sampling settings for model.generate. `early_stopping` was dropped: it
    # only affects beam search, and num_beams is 1 here.
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        repetition_penalty=2.5,
        do_sample=True,
        top_k=40,
        top_p=0.2,
        temperature=0.4,
        num_beams=1,
        max_new_tokens=128,
        pad_token_id=model.config.eos_token_id,
        no_repeat_ngram_size=4,
    )
    # Generation runs in a worker thread; this generator consumes its output
    # through the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        print(new_text, end="")
        generated_text += new_text
        hit_pattern = bool(early_stop_pattern) and early_stop_pattern in generated_text
        hit_stop = bool(stop_token) and stop_token in new_text
        if hit_pattern or hit_stop:
            generated_text = _truncate_at(generated_text, (early_stop_pattern, stop_token))
            # NOTE(review): streamer.end() does not appear to interrupt
            # model.generate; the worker thread likely keeps running until
            # max_new_tokens/EOS. A StoppingCriteria would be needed to halt
            # it early — confirm whether that matters.
            streamer.end()
            print("\n--\n")
            yield generated_text
            return generated_text
        # Yield only after the stop check so the stop marker never flashes
        # in the output box before being trimmed.
        yield generated_text
    # The streamer finished on its own, so generation has completed.
    thread.join()
    return generated_text


demo = gr.Interface(
    title="Hebrew text generator: Science Fiction and Fantasy (GPT-Neo)",
    fn=generate,
    # Label (Hebrew): "Write your text here or leave it empty".
    inputs=gr.Textbox(
        label="כתבו כאן את הטקסט שלכם או השאירו ריק",
        text_align='right',
        rtl=True,
        elem_id="input_text",
    ),
    # Label (Hebrew): "The text the generator produces will appear here".
    outputs=gr.Textbox(
        type="text",
        label="פה יופיע הטקסט שהמחולל יחולל",
        text_align='right',
        rtl=True,
        elem_id="output_text",
    ),
    css="#output_text{direction: rtl} #input_text{direction: rtl}",
    # Example prompts (Hebrew): "The demon appeared in front of",
    # "Kali drew out", "Once, many years ago",
    # "Harry Potter smiled an embarrassed smile",
    # "And then I broke all the rules of the ceremony when".
    examples=['השד הופיע מול', 'קאלי שלפה את', 'פעם אחת לפני שנים רבות', 'הארי פוטר חייך חיוך נבוך', 'ואז הפרתי את כל כללי הטקס כש'],
    allow_flagging="never",
    cache_examples=False,
)

# Enable Gradio's request queue, then start the web server.
demo.queue()
demo.launch()