loooooong commited on
Commit
cf87904
1 Parent(s): 0a6a78c

fix import

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -19,7 +19,7 @@ from os.path import join as opj
19
  token = os.getenv("ACCESS_TOKEN")
20
  os.system(f"python -m pip install git+https://{token}@github.com/logn-2024/StableGarment.git")
21
 
22
- from stablegarment.models import AppearanceEncoderModel,ControlNetModel
23
  from stablegarment.piplines import StableGarmentPipeline,StableGarmentControlNetPipeline
24
 
25
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -32,7 +32,7 @@ vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(dtype=torch_
32
  scheduler = UniPCMultistepScheduler.from_pretrained("runwayml/stable-diffusion-v1-5",subfolder="scheduler")
33
 
34
  pretrained_garment_encoder_path = "loooooong/StableGarment_text2img"
35
- garment_encoder = AppearanceEncoderModel.from_pretrained(pretrained_garment_encoder_path,torch_dtype=torch_dtype,subfolder="garment_encoder")
36
  garment_encoder = garment_encoder.to(device=device,dtype=torch_dtype)
37
 
38
  pipeline_t2i = StableGarmentPipeline.from_pretrained(base_model_path, vae=vae, torch_dtype=torch_dtype,).to(device=device) # variant="fp16"
 
19
  token = os.getenv("ACCESS_TOKEN")
20
  os.system(f"python -m pip install git+https://{token}@github.com/logn-2024/StableGarment.git")
21
 
22
+ from stablegarment.models import GarmentEncoderModel,ControlNetModel
23
  from stablegarment.piplines import StableGarmentPipeline,StableGarmentControlNetPipeline
24
 
25
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
32
  scheduler = UniPCMultistepScheduler.from_pretrained("runwayml/stable-diffusion-v1-5",subfolder="scheduler")
33
 
34
  pretrained_garment_encoder_path = "loooooong/StableGarment_text2img"
35
+ garment_encoder = GarmentEncoderModel.from_pretrained(pretrained_garment_encoder_path,torch_dtype=torch_dtype,subfolder="garment_encoder")
36
  garment_encoder = garment_encoder.to(device=device,dtype=torch_dtype)
37
 
38
  pipeline_t2i = StableGarmentPipeline.from_pretrained(base_model_path, vae=vae, torch_dtype=torch_dtype,).to(device=device) # variant="fp16"