rynmurdock commited on
Commit
b1772c8
1 Parent(s): 2f230d1

maybe fixed :)

Browse files
Files changed (1) hide show
  1. app.py +7 -14
app.py CHANGED
@@ -6,8 +6,7 @@ from sklearn.svm import LinearSVC
6
  from sklearn import preprocessing
7
  import pandas as pd
8
 
9
- from transformers import CLIPVisionModelWithProjection
10
- from diffusers import LCMScheduler, AutoencoderTiny, EulerDiscreteScheduler, UNet2DConditionModel, StableDiffusionXLPipeline
11
  from diffusers.models import ImageProjection
12
  from patch_sdxl import SDEmb
13
  import torch
@@ -40,16 +39,10 @@ unet.load_state_dict(load_file(hf_hub_download(sdxl_lightening, ckpt), device="c
40
  pipe = SDEmb.from_pretrained(model_id, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
41
  pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16)
42
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
43
-
44
- image_preproc = CLIPImageProcessor.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='feature_extractor')
45
- image_encoder = CLIPVisionModelWithProjection.from_pretrained('h94/IP-Adapter', subfolder='sdxl_models/image_encoder', torch_dtype=torch.float16)
46
- pipe.unet._load_ip_adapter_weights(torch.load(hf_hub_download('h94/IP-Adapter', 'sdxl_models/ip-adapter_sdxl.bin'), map_location="cpu"))
47
- pipe.register_modules(image_encoder = image_encoder)
48
- pipe.register_modules(feature_extractor = image_preproc)
49
-
50
  pipe.to(device='cuda')
 
51
 
52
- output_hidden_state = True
53
  #######################
54
 
55
  @spaces.GPU
@@ -64,7 +57,7 @@ def predict(
64
  im_emb = torch.zeros(1, 1280, dtype=torch.float16, device='cuda')
65
  image = pipe(
66
  prompt=prompt,
67
- ip_adapter_emb=[im_emb.to('cuda')],
68
  height=1024,
69
  width=1024,
70
  num_inference_steps=2,
@@ -124,7 +117,7 @@ def next_image(embs, ys, calibrate_prompts):
124
 
125
  w = 1# if len(embs) % 2 == 0 else 0
126
  im_emb = w * lin_class.coef_.to(device=DEVICE, dtype=torch.float16)
127
- prompt= '' if glob_idx % 2 == 0 else rng_prompt
128
  print(prompt)
129
  image, im_emb = predict(prompt, im_emb)
130
  embs.append(im_emb)
@@ -153,9 +146,9 @@ def start(_, embs, ys, calibrate_prompts):
153
 
154
 
155
  def choose(choice, embs, ys, calibrate_prompts):
156
- if choice == 'Like':
157
  choice = 1
158
- elif choice == 'Neither':
159
  _ = embs.pop(-1)
160
  img, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
161
  return img, embs, ys, calibrate_prompts
 
6
  from sklearn import preprocessing
7
  import pandas as pd
8
 
9
+ from diffusers import LCMScheduler, AutoencoderTiny, EulerDiscreteScheduler, UNet2DConditionModel
 
10
  from diffusers.models import ImageProjection
11
  from patch_sdxl import SDEmb
12
  import torch
 
39
  pipe = SDEmb.from_pretrained(model_id, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
40
  pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16)
41
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
 
 
 
 
 
 
 
42
  pipe.to(device='cuda')
43
+ pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
44
 
45
+ output_hidden_state = False
46
  #######################
47
 
48
  @spaces.GPU
 
57
  im_emb = torch.zeros(1, 1280, dtype=torch.float16, device='cuda')
58
  image = pipe(
59
  prompt=prompt,
60
+ ip_adapter_emb=im_emb.to('cuda'),
61
  height=1024,
62
  width=1024,
63
  num_inference_steps=2,
 
117
 
118
  w = 1# if len(embs) % 2 == 0 else 0
119
  im_emb = w * lin_class.coef_.to(device=DEVICE, dtype=torch.float16)
120
+ prompt= 'an image' if glob_idx % 2 == 0 else rng_prompt
121
  print(prompt)
122
  image, im_emb = predict(prompt, im_emb)
123
  embs.append(im_emb)
 
146
 
147
 
148
  def choose(choice, embs, ys, calibrate_prompts):
149
+ if choice == 'Like (L)':
150
  choice = 1
151
+ elif choice == 'Neither (Space)':
152
  _ = embs.pop(-1)
153
  img, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
154
  return img, embs, ys, calibrate_prompts