Nick088 commited on
Commit
c435396
·
verified ·
1 Parent(s): 7c64f78

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -18
app.py CHANGED
@@ -1,5 +1,3 @@
1
- import functools as ft
2
-
3
  import gradio as gr
4
  import torch
5
  import random
@@ -17,7 +15,6 @@ tokenizer = T5Tokenizer.from_pretrained("roborovski/superprompt-v1")
17
  model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1", torch_dtype=torch.float16)
18
  model.to(device)
19
 
20
- @ft.lru_cache(maxsize=1024)
21
  def generate(your_prompt, max_new_tokens, repetition_penalty, temperature, model_precision_type, top_p, top_k, seed):
22
  if seed == 0:
23
  seed = random.randint(1, 2**32-1)
@@ -30,22 +27,21 @@ def generate(your_prompt, max_new_tokens, repetition_penalty, temperature, model
30
 
31
  model.to(dtype)
32
 
33
- with torch.inference_mode():
34
- input_text = f"Expand the following prompt to add more detail: {your_prompt}"
35
- input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
36
 
37
- outputs = model.generate(
38
- input_ids,
39
- max_new_tokens=max_new_tokens,
40
- repetition_penalty=repetition_penalty,
41
- do_sample=True,
42
- temperature=temperature,
43
- top_p=top_p,
44
- top_k=top_k,
45
- )
46
 
47
- better_prompt = tokenizer.decode(outputs[0], skip_special_tokens=True)
48
- return better_prompt
49
 
50
 
51
  your_prompt = gr.Textbox(label="Your Prompt", interactive=True)
@@ -62,7 +58,7 @@ top_p = gr.Slider(value=1, minimum=0, maximum=2, step=0.05, interactive=True, la
62
 
63
  top_k = gr.Slider(value=1, minimum=1, maximum=100, step=1, interactive=True, label="Top K", info="Higher k means more diverse outputs by considering a range of tokens")
64
 
65
- seed = gr.Number(value=42, interactive=True, label="Seed", info="A starting point to initiate the generation process, put 0 for a random one")
66
 
67
  examples = [
68
  ["A storefront with 'Text to Image' written on it.", 512, 1.2, 0.5, "fp16", 1, 50, 42]
 
 
 
1
  import gradio as gr
2
  import torch
3
  import random
 
15
  model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1", torch_dtype=torch.float16)
16
  model.to(device)
17
 
 
18
  def generate(your_prompt, max_new_tokens, repetition_penalty, temperature, model_precision_type, top_p, top_k, seed):
19
  if seed == 0:
20
  seed = random.randint(1, 2**32-1)
 
27
 
28
  model.to(dtype)
29
 
30
+ input_text = f"Expand the following prompt to add more detail: {your_prompt}"
31
+ input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
 
32
 
33
+ outputs = model.generate(
34
+ input_ids,
35
+ max_new_tokens=max_new_tokens,
36
+ repetition_penalty=repetition_penalty,
37
+ do_sample=True,
38
+ temperature=temperature,
39
+ top_p=top_p,
40
+ top_k=top_k,
41
+ )
42
 
43
+ better_prompt = tokenizer.decode(outputs[0], skip_special_tokens=True)
44
+ return better_prompt
45
 
46
 
47
  your_prompt = gr.Textbox(label="Your Prompt", interactive=True)
 
58
 
59
  top_k = gr.Slider(value=1, minimum=1, maximum=100, step=1, interactive=True, label="Top K", info="Higher k means more diverse outputs by considering a range of tokens")
60
 
61
+ seed = gr.Slider(value=42, minimum=0, maximum=2**32-1, interactive=True, label="Seed", info="A starting point to initiate the generation process, put 0 for a random one")
62
 
63
  examples = [
64
  ["A storefront with 'Text to Image' written on it.", 512, 1.2, 0.5, "fp16", 1, 50, 42]