Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,3 @@
|
|
1 |
-
import functools as ft
|
2 |
-
|
3 |
import gradio as gr
|
4 |
import torch
|
5 |
import random
|
@@ -17,7 +15,6 @@ tokenizer = T5Tokenizer.from_pretrained("roborovski/superprompt-v1")
|
|
17 |
model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1", torch_dtype=torch.float16)
|
18 |
model.to(device)
|
19 |
|
20 |
-
@ft.lru_cache(maxsize=1024)
|
21 |
def generate(your_prompt, max_new_tokens, repetition_penalty, temperature, model_precision_type, top_p, top_k, seed):
|
22 |
if seed == 0:
|
23 |
seed = random.randint(1, 2**32-1)
|
@@ -30,22 +27,21 @@ def generate(your_prompt, max_new_tokens, repetition_penalty, temperature, model
|
|
30 |
|
31 |
model.to(dtype)
|
32 |
|
33 |
-
|
34 |
-
|
35 |
-
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
|
36 |
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
|
47 |
-
|
48 |
-
|
49 |
|
50 |
|
51 |
your_prompt = gr.Textbox(label="Your Prompt", interactive=True)
|
@@ -62,7 +58,7 @@ top_p = gr.Slider(value=1, minimum=0, maximum=2, step=0.05, interactive=True, la
|
|
62 |
|
63 |
top_k = gr.Slider(value=1, minimum=1, maximum=100, step=1, interactive=True, label="Top K", info="Higher k means more diverse outputs by considering a range of tokens")
|
64 |
|
65 |
-
seed = gr.
|
66 |
|
67 |
examples = [
|
68 |
["A storefront with 'Text to Image' written on it.", 512, 1.2, 0.5, "fp16", 1, 50, 42]
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
import random
|
|
|
15 |
model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1", torch_dtype=torch.float16)
|
16 |
model.to(device)
|
17 |
|
|
|
18 |
def generate(your_prompt, max_new_tokens, repetition_penalty, temperature, model_precision_type, top_p, top_k, seed):
|
19 |
if seed == 0:
|
20 |
seed = random.randint(1, 2**32-1)
|
|
|
27 |
|
28 |
model.to(dtype)
|
29 |
|
30 |
+
input_text = f"Expand the following prompt to add more detail: {your_prompt}"
|
31 |
+
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
|
|
|
32 |
|
33 |
+
outputs = model.generate(
|
34 |
+
input_ids,
|
35 |
+
max_new_tokens=max_new_tokens,
|
36 |
+
repetition_penalty=repetition_penalty,
|
37 |
+
do_sample=True,
|
38 |
+
temperature=temperature,
|
39 |
+
top_p=top_p,
|
40 |
+
top_k=top_k,
|
41 |
+
)
|
42 |
|
43 |
+
better_prompt = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
44 |
+
return better_prompt
|
45 |
|
46 |
|
47 |
your_prompt = gr.Textbox(label="Your Prompt", interactive=True)
|
|
|
58 |
|
59 |
top_k = gr.Slider(value=1, minimum=1, maximum=100, step=1, interactive=True, label="Top K", info="Higher k means more diverse outputs by considering a range of tokens")
|
60 |
|
61 |
+
seed = gr.Slider(value=42, minimum=0, maximum=2**32-1, interactive=True, label="Seed", info="A starting point to initiate the generation process, put 0 for a random one")
|
62 |
|
63 |
examples = [
|
64 |
["A storefront with 'Text to Image' written on it.", 512, 1.2, 0.5, "fp16", 1, 50, 42]
|