# -*- coding: utf-8 -*-
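# Gradio app that streams Hebrew science-fiction/fantasy text from a local
# GPT-Neo checkpoint; generation runs on a background thread and partial
# output reaches the UI through TextIteratorStreamer.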

import gradio as gr
import torch
from transformers import pipeline, TextIteratorStreamer
from threading import Thread

# Path to the fine-tuned GPT-Neo checkpoint bundled with the Space.
model_id = './model'

CUDA_AVAILABLE = torch.cuda.is_available()
device = torch.device("cuda" if CUDA_AVAILABLE else "cpu")

# load_in_8bit needs bitsandbytes and a GPU, and must be passed through
# model_kwargs; accelerate then handles placement, so `device` is only
# given on the CPU path.
if CUDA_AVAILABLE:
  generator = pipeline('text-generation', model=model_id, tokenizer=model_id,
                       device_map="auto", model_kwargs={"load_in_8bit": True})
else:
  generator = pipeline('text-generation', model=model_id, tokenizer=model_id,
                       device=device)

# Generation is truncated at the first run of three newlines.
early_stop_pattern = "\n\n\n"
print(f'Early stop pattern = "{early_stop_pattern}"')

model = generator.model
tok = generator.tokenizer

stop_token = tok.eos_token
print(f'stop_token = "{stop_token}"')

def generate(text = ""):
  print("Create streamer")
  yield "[ืื ื ื”ืžืชื™ื ื• ืœืชืฉื•ื‘ื”]"  # placeholder: "[Please wait for the answer]"
  streamer = TextIteratorStreamer(tok, timeout=5.)
  if len(text) == 0:
    text = "\n"

  inputs = tok([text], return_tensors="pt").to(device)
  generation_kwargs = dict(
      inputs,
      streamer=streamer,
      do_sample=True,
      top_k=40,
      top_p=0.2,
      temperature=0.4,
      repetition_penalty=2.5,
      no_repeat_ngram_size=4,
      num_beams=1,
      max_new_tokens=128,
      pad_token_id=model.config.eos_token_id,
  )
  # Run generation on a worker thread so tokens can be consumed as they arrive.
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
  thread.start()
  generated_text = ""
  for new_text in streamer:
    yield generated_text + new_text
    print(new_text, end="")
    generated_text += new_text
    if (early_stop_pattern in generated_text) or (stop_token in new_text):
      # Trim at whichever stop marker is actually present. The `in` guard
      # matters: str.find() returns -1 for a missing pattern, which would
      # slice off the final character.
      for pattern in (early_stop_pattern, stop_token):
        if pattern and pattern in generated_text:
          generated_text = generated_text[:generated_text.find(pattern)]
      streamer.end()
      print("\n--\n")
      yield generated_text
      return generated_text

  return generated_text
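
# Because `generate` is a generator function, Gradio streams each yielded
# string into the output Textbox, so the displayed text grows token by token.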

demo = gr.Interface(
    title="Hebrew text generator: Science Fiction and Fantasy (GPT-Neo)",
    fn=generate,
    # Label: "Write your text here, or leave it empty"
    inputs=gr.Textbox(label="ื›ืชื‘ื• ื›ืืŸ ืืช ื”ื˜ืงืกื˜ ืฉืœื›ื ืื• ื”ืฉืื™ืจื• ืจื™ืง", text_align='right', rtl=True, elem_id="input_text"),
    # Label: "The generated text will appear here"
    outputs=gr.Textbox(type="text", label="ืคื” ื™ื•ืคื™ืข ื”ื˜ืงืกื˜ ืฉื”ืžื—ื•ืœืœ ื™ื—ื•ืœืœ", text_align='right', rtl=True, elem_id="output_text"),
    css="#output_text{direction: rtl} #input_text{direction: rtl}",
    # Prompts: "The demon appeared in front of", "Kali drew the",
    # "Once upon a time, many years ago", "Harry Potter smiled an
    # embarrassed smile", "And then I broke every rule of the ceremony when"
    examples=['ื”ืฉื“ ื”ื•ืคื™ืข ืžื•ืœ', 'ืงืืœื™ ืฉืœืคื” ืืช', 'ืคืขื ืื—ืช ืœืคื ื™ ืฉื ื™ื ืจื‘ื•ืช', 'ื”ืืจื™ ืคื•ื˜ืจ ื—ื™ื™ืš ื—ื™ื•ืš ื ื‘ื•ืš', 'ื•ืื– ื”ืคืจืชื™ ืืช ื›ืœ ื›ืœืœื™ ื”ื˜ืงืก ื›ืฉ'],
    allow_flagging="never",
    cache_examples=False
)

# The queue must be enabled for Gradio to serve generator (streaming) outputs.
demo.queue()
demo.launch()