Somunia commited on
Commit
68918ad
·
verified ·
1 Parent(s): 306b4ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -20
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import torch
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import time
 
4
 
5
  def generate_prompt(instruction, input=""):
6
  instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
@@ -25,10 +26,10 @@ model_path = "models/rwkv-6-world-1b6/" # Path to your local model directory
25
  model = AutoModelForCausalLM.from_pretrained(
26
  model_path,
27
  trust_remote_code=True,
28
- use_flash_attention_2=False # Explicitly disable Flash Attention
29
  ).to(torch.float32)
30
 
31
-
32
  tokenizer = AutoTokenizer.from_pretrained(
33
  model_path,
34
  bos_token="</s>",
@@ -40,23 +41,41 @@ tokenizer = AutoTokenizer.from_pretrained(
40
  clean_up_tokenization_spaces=False # Or set to True if you prefer
41
  )
42
 
43
- print(tokenizer.special_tokens_map)
44
-
45
- text = "Hi"
46
-
47
- prompt = generate_prompt(text)
48
-
49
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids
50
-
51
- # Generate text word by word with stop sequence
52
- generated_text = ""
53
- for i in range(333): # Generate up to 333 tokens
54
- output = model.generate(input_ids, max_new_tokens=1, do_sample=True, temperature=1.0, top_p=0.3, top_k=0)
55
- new_word = tokenizer.decode(output[0][-1:], skip_special_tokens=True)
56
-
57
- print(new_word, end="", flush=True) # Print word-by-word
58
- generated_text += new_word
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- input_ids = output # Update input_ids for next iteration
 
61
 
62
- print() # Add a newline at the end
 
1
  import torch
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import time
4
+ import gradio as gr
5
 
6
  def generate_prompt(instruction, input=""):
7
  instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
 
26
  model = AutoModelForCausalLM.from_pretrained(
27
  model_path,
28
  trust_remote_code=True,
29
+ use_flash_attention_2=False
30
  ).to(torch.float32)
31
 
32
+ # Create a custom tokenizer (make sure to download vocab.json)
33
  tokenizer = AutoTokenizer.from_pretrained(
34
  model_path,
35
  bos_token="</s>",
 
41
  clean_up_tokenization_spaces=False # Or set to True if you prefer
42
  )
43
 
44
+ # Function to handle text generation with word-by-word output and stop sequence
45
+ def generate_text(input_text):
46
+ prompt = generate_prompt(input_text)
47
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
48
+
49
+ generated_text = ""
50
+ stop_sequence_found = False
51
+ for i in range(333):
52
+ output = model.generate(input_ids, max_new_tokens=1, do_sample=True, temperature=1.0, top_p=0.3, top_k=0)
53
+ new_word = tokenizer.decode(output[0][-1:], skip_special_tokens=True)
54
+
55
+ print(new_word, end="", flush=True)
56
+ generated_text += new_word
57
+
58
+ if new_word == '\n' or new_word == '.':
59
+ stop_sequence_found = True
60
+ break
61
+
62
+ input_ids = output
63
+
64
+ if stop_sequence_found:
65
+ print("\n(Stop sequence found)")
66
+ print()
67
+ return generated_text
68
+
69
+ # Create the Gradio interface
70
+ iface = gr.Interface(
71
+ fn=generate_text,
72
+ inputs="text",
73
+ outputs="text",
74
+ title="RWKV Chatbot",
75
+ description="Enter your prompt below:",
76
+ )
77
 
78
+ # For local testing:
79
+ # iface.launch()
80
 
81
+ # Hugging Face Spaces will automatically launch the interface.