ssboost committed on
Commit
9d23144
·
verified ·
1 Parent(s): b22c1c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -2
app.py CHANGED
@@ -2,7 +2,7 @@ import spaces
2
  import random
3
  import torch
4
  from huggingface_hub import snapshot_download
5
- from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
6
  from kolors.pipelines import pipeline_stable_diffusion_xl_chatglm_256_ipadapter, pipeline_stable_diffusion_xl_chatglm_256
7
  from kolors.models.modeling_chatglm import ChatGLMModel
8
  from kolors.models.tokenization_chatglm import ChatGLMTokenizer
@@ -10,7 +10,6 @@ from kolors.models import unet_2d_condition
10
  from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
11
  import gradio as gr
12
  import numpy as np
13
- from huggingface_hub import InferenceClient
14
  import os
15
 
16
  device = "cuda"
@@ -27,6 +26,11 @@ image_encoder = CLIPVisionModelWithProjection.from_pretrained(f'{ckpt_IPA_dir}/i
27
  ip_img_size = 336
28
  clip_image_processor = CLIPImageProcessor(size=ip_img_size, crop_size=ip_img_size)
29
 
 
 
 
 
 
30
  pipe_t2i = pipeline_stable_diffusion_xl_chatglm_256.StableDiffusionXLPipeline(
31
  vae=vae,
32
  text_encoder=text_encoder,
@@ -103,6 +107,13 @@ def infer(prompt,
103
  image.save("generated_image.jpg") # 파일 확장자를 .jpg로 변경
104
  return image, "generated_image.jpg"
105
 
 
 
 
 
 
 
 
106
  css="""
107
  #col-left {
108
  margin: 0 auto;
@@ -184,6 +195,7 @@ with gr.Blocks(css=css) as Kolors:
184
  with gr.Column(elem_id="col-right"):
185
  result = gr.Image(label="Result", show_label=False)
186
  download_button = gr.File(label="Download Image")
 
187
 
188
  # 이미지 생성 및 다운로드 파일 경로 설정
189
  run_button.click(
@@ -192,4 +204,11 @@ with gr.Blocks(css=css) as Kolors:
192
  outputs=[result, download_button]
193
  )
194
 
 
 
 
 
 
 
 
195
  Kolors.queue().launch(debug=True)
 
2
  import random
3
  import torch
4
  from huggingface_hub import snapshot_download
5
+ from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor, CLIPModel, CLIPTokenizer
6
  from kolors.pipelines import pipeline_stable_diffusion_xl_chatglm_256_ipadapter, pipeline_stable_diffusion_xl_chatglm_256
7
  from kolors.models.modeling_chatglm import ChatGLMModel
8
  from kolors.models.tokenization_chatglm import ChatGLMTokenizer
 
10
  from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
11
  import gradio as gr
12
  import numpy as np
 
13
  import os
14
 
15
  device = "cuda"
 
26
  ip_img_size = 336
27
  clip_image_processor = CLIPImageProcessor(size=ip_img_size, crop_size=ip_img_size)
28
 
29
+ # CLIP 모델 및 토크나이저 로드
30
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
31
+ clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
32
+ clip_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
33
+
34
  pipe_t2i = pipeline_stable_diffusion_xl_chatglm_256.StableDiffusionXLPipeline(
35
  vae=vae,
36
  text_encoder=text_encoder,
 
107
  image.save("generated_image.jpg") # 파일 확장자를 .jpg로 변경
108
  return image, "generated_image.jpg"
109
 
110
# Fallback label set used when the caller does not supply candidate labels.
_DEFAULT_IMAGE_LABELS = (
    "a person",
    "an animal",
    "a landscape",
    "a building",
    "food",
    "a vehicle",
    "a plant",
    "an everyday object",
    "a document with text",
    "an illustration or painting",
)


def describe_image(image, candidate_labels=None):
    """Return the candidate label that best matches ``image`` according to CLIP.

    The previous implementation took ``argmax`` over the CLIP image embedding
    and decoded that index as a tokenizer id.  An embedding dimension index is
    not a vocabulary id, so the result was an arbitrary token unrelated to the
    picture.  This version embeds the image and a list of text prompts with
    the same CLIP model and returns the label with the highest cosine
    similarity.

    Args:
        image: Image accepted by ``clip_processor`` (whatever the Gradio image
            input delivers), or ``None`` when no image is set.
        candidate_labels: Optional iterable of label strings to choose from.
            Defaults to ``_DEFAULT_IMAGE_LABELS``.

    Returns:
        The best-matching label as a string, or ``""`` when ``image`` is
        ``None``.
    """
    # The ``change`` event wired to this function can deliver ``None`` when
    # the image input is cleared; there is nothing to describe in that case.
    if image is None:
        return ""

    labels = list(candidate_labels) if candidate_labels else list(_DEFAULT_IMAGE_LABELS)
    prompts = [f"a photo of {label}" for label in labels]

    image_inputs = clip_processor(images=image, return_tensors="pt").to(device)
    text_inputs = clip_tokenizer(prompts, padding=True, return_tensors="pt").to(device)

    with torch.no_grad():
        image_features = clip_model.get_image_features(**image_inputs)
        text_features = clip_model.get_text_features(**text_inputs)

    # Normalise so the dot product below is a cosine similarity.
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    similarity = image_features @ text_features.T
    best_index = int(similarity.argmax(dim=-1)[0])
    return labels[best_index]
116
+
117
  css="""
118
  #col-left {
119
  margin: 0 auto;
 
195
  with gr.Column(elem_id="col-right"):
196
  result = gr.Image(label="Result", show_label=False)
197
  download_button = gr.File(label="Download Image")
198
+ image_description = gr.Textbox(label="Image Description", placeholder="이미지 분석 결과가 여기에 표시됩니다.", interactive=False)
199
 
200
  # 이미지 생성 및 다운로드 파일 경로 설정
201
  run_button.click(
 
204
  outputs=[result, download_button]
205
  )
206
 
207
+ # 이미지 설명 생성
208
+ ip_adapter_image.change(
209
+ fn=describe_image,
210
+ inputs=[ip_adapter_image],
211
+ outputs=[image_description]
212
+ )
213
+
214
  Kolors.queue().launch(debug=True)