# app.py by Norod78 (captured from a web file view).
# Last commit: "Update app.py" (2adb97c, verified) - file size 2.63 kB, scan result: no virus.
# -*- coding: utf-8 -*-
import gradio as gr
import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import os
# Directory holding both the model weights and the tokenizer files.
model_id = './model'
# Run on the GPU when CUDA is available, otherwise on the CPU.
CUDA_AVAILABLE = torch.cuda.is_available()
device = torch.device("cuda" if CUDA_AVAILABLE else "cpu")
# Load model + tokenizer through a text-generation pipeline; below, its
# `.model` and `.tokenizer` are pulled out for direct use.
# NOTE(review): `load_in_8bit` is passed as a plain pipeline kwarg rather than
# via `model_kwargs`; depending on the transformers version it may never reach
# model loading, so the model may not actually be loaded in 8-bit. 8-bit
# loading would also typically conflict with `device=` placement - TODO confirm.
generator = pipeline('text-generation', model=model_id,
tokenizer=model_id,
load_in_8bit=True,
device=device)
# Three consecutive newlines: the streaming loop stops and cuts the output
# at the first occurrence of this marker.
early_stop_pattern = "\n\n\n"
print(f'Early stop pattern = \"{early_stop_pattern}\"')
model = generator.model
tok = generator.tokenizer
# Text form of the tokenizer's EOS token; streamed output is also cut here.
# Presumably None if the tokenizer defines no EOS token - TODO confirm.
stop_token = tok.eos_token
print(f'stop_token = \"{stop_token}\"')
def _cut_at(text, marker):
    """Return *text* cut just before the first occurrence of *marker*.

    *text* is returned unchanged when *marker* is empty/None or absent.
    Slicing with ``text[:text.find(marker)]`` directly is wrong in the absent
    case: ``find`` returns -1 and the slice silently drops the last character.
    """
    if not marker:
        return text
    index = text.find(marker)
    return text if index == -1 else text[:index]


def generate(text=""):
    """Stream a model continuation of *text* for the Gradio output box.

    Yields progressively longer strings: first a Hebrew "please wait"
    placeholder, then the streamed text accumulated so far. The streamer is
    created without ``skip_prompt``, so the decoded prompt is part of that
    stream. Streaming stops once the accumulated text contains
    ``early_stop_pattern`` or the tokenizer's EOS text (``stop_token``); the
    last value yielded is then cut just before the first of those markers.

    Args:
        text: Prompt to continue. An empty (or None) prompt is replaced by a
            single newline, presumably so the tokenizer yields at least one
            input token.

    Returns:
        The final text, exposed as ``StopIteration.value`` of the generator.
    """
    print("Create streamer")
    # "[Please wait for a response]" - shown until the first chunk arrives.
    yield "[אנא המתינו לתשובה]"
    # `timeout` bounds how long iteration waits for the next streamed chunk.
    streamer = TextIteratorStreamer(tok, timeout=5.)
    if not text:
        text = "\n"
    inputs = tok([text], return_tensors="pt").to(device)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        repetition_penalty=2.5,
        do_sample=True,
        top_k=40,
        top_p=0.2,
        temperature=0.4,
        num_beams=1,
        max_new_tokens=128,
        pad_token_id=model.config.eos_token_id,
        early_stopping=True,
        no_repeat_ngram_size=4,
    )
    # model.generate runs in a worker thread and pushes decoded text into
    # `streamer`, which this generator consumes below.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        print(new_text, end="")
        generated_text += new_text
        # Check the accumulated text (not only the latest chunk) so a marker
        # split across chunks is still detected; empty/None markers are skipped.
        if any(marker and marker in generated_text
               for marker in (early_stop_pattern, stop_token)):
            generated_text = _cut_at(_cut_at(generated_text, early_stop_pattern), stop_token)
            # Marks the stream as finished. This does not interrupt
            # model.generate, which keeps running in `thread` until it emits
            # EOS or reaches max_new_tokens.
            streamer.end()
            print("\n--\n")
            yield generated_text
            return generated_text
        yield generated_text
    return generated_text
# Gradio UI: one right-to-left Hebrew input box, one right-to-left output box,
# and a few Hebrew prompt examples. `generate` yields partial results, which
# Gradio streams into the output box.
demo = gr.Interface(
    title="Hebrew text generator: Science Fiction and Fantasy (GPT-Neo)",
    fn=generate,
    # Label: "Write your text here or leave it empty"
    inputs=gr.Textbox(label="כתבו כאן את הטקסט שלכם או השאירו ריק", text_align='right', rtl=True, elem_id="input_text"),
    # Label: "The text the generator produces will appear here"
    outputs=gr.Textbox(type="text", label="פה יופיע הטקסט שהמחולל יחולל", text_align='right', rtl=True, elem_id="output_text"),
    css="#output_text{direction: rtl} #input_text{direction: rtl}",
    examples=[
        'השד הופיע מול',  # "The demon appeared in front of"
        'קאלי שלפה את',  # "Kali drew out"
        'פעם אחת לפני שנים רבות',  # "Once upon a time, many years ago"
        'הארי פוטר חייך חיוך נבוך',  # "Harry Potter smiled an embarrassed smile"
        'ואז הפרתי את כל כללי הטקס כש',  # "And then I broke all the rules of the ceremony when"
    ],
    allow_flagging="never",
    cache_examples=False,
)
# Enable Gradio's request queue, then start the web server.
demo.queue()
demo.launch()