Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -10,8 +10,8 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
10 |
|
11 |
device = torch.device("cuda:0")
|
12 |
|
13 |
-
llm = AutoModelForCausalLM.from_pretrained("Azure99/blossom-v5-
|
14 |
-
tokenizer = AutoTokenizer.from_pretrained("Azure99/blossom-v5-
|
15 |
diffusion_pipe = DiffusionPipeline.from_pretrained(
|
16 |
"playgroundai/playground-v2.5-1024px-aesthetic",
|
17 |
torch_dtype=torch.float16,
|
@@ -34,7 +34,7 @@ def save_image(img):
|
|
34 |
|
35 |
|
36 |
LLM_PROMPT = '''你的任务是从输入的[作画要求]中抽取画面描述(description),然后description翻译为英文(en_description),最后对en_description进行扩写(expanded_description),增加足够多的细节,且符合人类的第一直觉。
|
37 |
-
[输出]是一个json,包含description、en_description、expanded_description
|
38 |
|
39 |
下面是一些示例:
|
40 |
[作画要求]->"画一幅画:落霞与孤鹜齐飞,秋水共长天一色。"
|
@@ -60,22 +60,29 @@ def generate(
|
|
60 |
prompt: str,
|
61 |
progress=gr.Progress(track_tqdm=True),
|
62 |
):
|
63 |
-
|
64 |
input_ids = get_input_ids(LLM_PROMPT.replace("$USER_PROMPT", json.dumps(prompt, ensure_ascii=False)), BOT_PREFIX)
|
65 |
generation_kwargs = dict(input_ids=torch.tensor([input_ids]).to(llm.device), do_sample=True,
|
66 |
max_new_tokens=512, temperature=0.5, top_p=0.85, top_k=50, repetition_penalty=1.05)
|
67 |
llm_result = llm.generate(**generation_kwargs)
|
68 |
llm_result = llm_result.cpu()[0][len(input_ids):]
|
69 |
llm_result = BOT_PREFIX + tokenizer.decode(llm_result, skip_special_tokens=True)
|
|
|
|
|
70 |
print(llm_result)
|
71 |
-
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
seed = random.randint(0, 2147483647)
|
75 |
generator = torch.Generator().manual_seed(seed)
|
76 |
|
77 |
images = diffusion_pipe(
|
78 |
-
prompt=
|
79 |
negative_prompt=None,
|
80 |
width=1024,
|
81 |
height=1024,
|
@@ -107,7 +114,7 @@ with gr.Blocks(css=css) as demo:
|
|
107 |
container=False,
|
108 |
)
|
109 |
run_button = gr.Button("Run", scale=0)
|
110 |
-
result = gr.Gallery(label="Result", columns=1, show_label=False)
|
111 |
|
112 |
gr.on(
|
113 |
triggers=[
|
|
|
10 |
|
11 |
device = torch.device("cuda:0")
|
12 |
|
13 |
+
llm = AutoModelForCausalLM.from_pretrained("Azure99/blossom-v5-4b", torch_dtype=torch.float16, device_map="auto")
|
14 |
+
tokenizer = AutoTokenizer.from_pretrained("Azure99/blossom-v5-4b")
|
15 |
diffusion_pipe = DiffusionPipeline.from_pretrained(
|
16 |
"playgroundai/playground-v2.5-1024px-aesthetic",
|
17 |
torch_dtype=torch.float16,
|
|
|
34 |
|
35 |
|
36 |
LLM_PROMPT = '''你的任务是从输入的[作画要求]中抽取画面描述(description),然后description翻译为英文(en_description),最后对en_description进行扩写(expanded_description),增加足够多的细节,且符合人类的第一直觉。
|
37 |
+
[输出]是一个json,包含description、en_description、expanded_description三个字符串字段,请直接输出一个完整的json,不要输出任何解释或其他无关内容。
|
38 |
|
39 |
下面是一些示例:
|
40 |
[作画要求]->"画一幅画:落霞与孤鹜齐飞,秋水共长天一色。"
|
|
|
60 |
prompt: str,
|
61 |
progress=gr.Progress(track_tqdm=True),
|
62 |
):
|
|
|
63 |
input_ids = get_input_ids(LLM_PROMPT.replace("$USER_PROMPT", json.dumps(prompt, ensure_ascii=False)), BOT_PREFIX)
|
64 |
generation_kwargs = dict(input_ids=torch.tensor([input_ids]).to(llm.device), do_sample=True,
|
65 |
max_new_tokens=512, temperature=0.5, top_p=0.85, top_k=50, repetition_penalty=1.05)
|
66 |
llm_result = llm.generate(**generation_kwargs)
|
67 |
llm_result = llm_result.cpu()[0][len(input_ids):]
|
68 |
llm_result = BOT_PREFIX + tokenizer.decode(llm_result, skip_special_tokens=True)
|
69 |
+
print("----------")
|
70 |
+
print(prompt)
|
71 |
print(llm_result)
|
72 |
+
en_prompt = prompt
|
73 |
+
expanded_prompt = prompt
|
74 |
+
try:
|
75 |
+
en_prompt = json.loads(llm_result)["en_description"]
|
76 |
+
expanded_prompt = json.loads(llm_result)["expanded_description"]
|
77 |
+
except:
|
78 |
+
print("error, fallback to original prompt")
|
79 |
+
pass
|
80 |
|
81 |
seed = random.randint(0, 2147483647)
|
82 |
generator = torch.Generator().manual_seed(seed)
|
83 |
|
84 |
images = diffusion_pipe(
|
85 |
+
prompt=[en_prompt, expanded_prompt],
|
86 |
negative_prompt=None,
|
87 |
width=1024,
|
88 |
height=1024,
|
|
|
114 |
container=False,
|
115 |
)
|
116 |
run_button = gr.Button("Run", scale=0)
|
117 |
+
result = gr.Gallery(label="Result", columns=2, rows=1, show_label=False)
|
118 |
|
119 |
gr.on(
|
120 |
triggers=[
|