julien-c HF staff committed on
Commit
82612ce
·
verified ·
1 Parent(s): 739a14e

adapt handler from https://huggingface.co/spaces/cfahlgren1/Emoji-Generator-by-fofr

Browse files
Files changed (2) hide show
  1. handler.py +40 -16
  2. requirements.txt +88 -0
handler.py CHANGED
@@ -1,23 +1,37 @@
1
- from typing import Dict, List, Any
2
  import torch
3
  from torch import autocast
4
- from diffusers import StableDiffusionPipeline
 
5
  import base64
6
  from io import BytesIO
 
7
 
8
 
9
- # set device
10
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
 
12
- if device.type != 'cuda':
13
- raise ValueError("need to run on GPU")
14
 
15
- class EndpointHandler():
16
  def __init__(self, path=""):
17
- # load the optimized model
18
- self.pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
19
- self.pipe = self.pipe.to(device)
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
23
  """
@@ -28,15 +42,25 @@ class EndpointHandler():
28
  A :obj:`dict`:. base64 encoded image
29
  """
30
  inputs = data.pop("inputs", data)
31
-
32
- # run inference pipeline
33
- with autocast(device.type):
34
- image = self.pipe(inputs, guidance_scale=7.5)["sample"][0]
35
-
 
 
 
36
  # encode image as base 64
37
  buffered = BytesIO()
38
  image.save(buffered, format="JPEG")
39
  img_str = base64.b64encode(buffered.getvalue())
40
 
41
  # postprocess the prediction
42
- return {"image": img_str.decode()}
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
  import torch
3
  from torch import autocast
4
+ from huggingface_hub import hf_hub_download
5
+ from diffusers import DiffusionPipeline
6
  import base64
7
  from io import BytesIO
8
+ from cog_sdxl.dataset_and_utils import TokenEmbeddingsHandler
9
 
10
 
11
# Pick the compute device once at import time: CUDA when torch reports an
# available GPU, CPU otherwise. The choice is echoed to stdout.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("device ~>", device)
13
 
 
 
14
 
15
class EndpointHandler:
    """Endpoint handler that turns a text prompt into a base64-encoded JPEG.

    ``__init__`` loads the diffusion pipeline, the LoRA weights and the
    embeddings file once; ``__call__`` then runs the pipeline for a single
    prompt per request.
    """

    # Hub id of the base pipeline.
    BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
    # Hub repo that holds both ``lora.safetensors`` and ``embeddings.pti``.
    LORA_REPO = "SvenN/sdxl-emoji"
    # Text placed in front of every prompt so the trigger tokens are present.
    PROMPT_PREFIX = "A <s0><s1> "
    # Value forwarded to the pipeline as ``cross_attention_kwargs["scale"]``.
    LORA_SCALE = 0.8

    def __init__(self, path=""):
        """Load the pipeline onto the module-level ``device``.

        Args:
            path: Only printed. The models are always fetched by the hub ids
                declared on the class, not read from ``path``.
        """
        print("path ~>", path)

        self.pipe = DiffusionPipeline.from_pretrained(
            self.BASE_MODEL,
            # Half precision only on GPU; on CPU no dtype is forced.
            torch_dtype=torch.float16 if device.type == "cuda" else None,
            variant="fp16",
        ).to(device)

        self.pipe.load_lora_weights(self.LORA_REPO, weight_name="lora.safetensors")

        text_encoders = [self.pipe.text_encoder, self.pipe.text_encoder_2]
        tokenizers = [self.pipe.tokenizer, self.pipe.tokenizer_2]

        embedding_path = hf_hub_download(
            repo_id=self.LORA_REPO, filename="embeddings.pti", repo_type="model"
        )
        # NOTE(review): presumably this is what makes the <s0>/<s1> trigger
        # tokens known to both text encoders -- confirm against cog_sdxl.
        embhandler = TokenEmbeddingsHandler(text_encoders, tokenizers)
        embhandler.load_embeddings(embedding_path)

    def __call__(self, data: Any) -> Dict[str, str]:
        """Generate one image for the request and return it base64-encoded.

        Args:
            data: Request payload. A dict is expected to carry the prompt
                under ``"inputs"``; a dict without that key, or any non-dict
                value (e.g. a bare string), is itself used as the prompt.

        Returns:
            A dict ``{"image": <base64 text of the JPEG bytes>}``.
        """
        # Read the prompt without mutating the caller's payload.
        if isinstance(data, dict):
            inputs = data.get("inputs", data)
        else:
            inputs = data

        # Automatically add trigger tokens to the beginning of the prompt
        full_prompt = f"{self.PROMPT_PREFIX}{inputs}"
        images = self.pipe(
            full_prompt,
            cross_attention_kwargs={"scale": self.LORA_SCALE},
        ).images
        image = images[0]

        # encode image as base 64
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue())

        # postprocess the prediction
        return {"image": img_str.decode()}
60
+
61
+
62
if __name__ == "__main__":
    # Manual smoke test: build the handler, run one sample prompt through it,
    # and print both the handler object and the returned payload.
    sample_request = {"inputs": "emoji of a tiger face, white background"}
    endpoint = EndpointHandler()
    print(endpoint)
    result = endpoint(sample_request)
    print(result)
requirements.txt ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.31.0
2
+ aiofiles==23.2.1
3
+ altair==5.3.0
4
+ annotated-types==0.7.0
5
+ anyio==4.4.0
6
+ attrs==23.2.0
7
+ certifi==2024.6.2
8
+ charset-normalizer==3.3.2
9
+ click==8.1.7
10
+ contourpy==1.2.1
11
+ cycler==0.12.1
12
+ diffusers==0.29.1
13
+ dnspython==2.6.1
14
+ email_validator==2.2.0
15
+ exceptiongroup==1.2.1
16
+ fastapi==0.111.0
17
+ fastapi-cli==0.0.4
18
+ ffmpy==0.3.2
19
+ filelock==3.15.4
20
+ fonttools==4.53.0
21
+ fsspec==2024.6.0
22
+ gradio==4.36.1
23
+ gradio_client==1.0.1
24
+ h11==0.14.0
25
+ httpcore==1.0.5
26
+ httptools==0.6.1
27
+ httpx==0.27.0
28
+ huggingface-hub==0.23.4
29
+ idna==3.7
30
+ importlib_metadata==7.2.0
31
+ importlib_resources==6.4.0
32
+ Jinja2==3.1.4
33
+ jsonschema==4.22.0
34
+ jsonschema-specifications==2023.12.1
35
+ kiwisolver==1.4.5
36
+ markdown-it-py==3.0.0
37
+ MarkupSafe==2.1.5
38
+ matplotlib==3.9.0
39
+ mdurl==0.1.2
40
+ mpmath==1.3.0
41
+ networkx==3.3
42
+ numpy==1.23.5
43
+ orjson==3.10.5
44
+ packaging==24.1
45
+ pandas==2.2.2
46
+ peft==0.11.1
47
+ pillow==10.3.0
48
+ psutil==5.9.8
49
+ pydantic==2.7.4
50
+ pydantic_core==2.18.4
51
+ pydub==0.25.1
52
+ Pygments==2.18.0
53
+ pyparsing==3.1.2
54
+ python-dateutil==2.9.0.post0
55
+ python-dotenv==1.0.1
56
+ python-multipart==0.0.9
57
+ pytz==2024.1
58
+ PyYAML==6.0.1
59
+ referencing==0.35.1
60
+ regex==2024.5.15
61
+ requests==2.32.3
62
+ rich==13.7.1
63
+ rpds-py==0.18.1
64
+ ruff==0.4.10
65
+ safetensors==0.4.3
66
+ semantic-version==2.10.0
67
+ shellingham==1.5.4
68
+ six==1.16.0
69
+ sniffio==1.3.1
70
+ spaces==0.28.3
71
+ starlette==0.37.2
72
+ sympy==1.12.1
73
+ tokenizers==0.19.1
74
+ tomlkit==0.12.0
75
+ toolz==0.12.1
76
+ torch==2.2.0
77
+ tqdm==4.66.4
78
+ transformers==4.41.2
79
+ typer==0.12.3
80
+ typing_extensions==4.12.2
81
+ tzdata==2024.1
82
+ ujson==5.10.0
83
+ urllib3==2.2.2
84
+ uvicorn==0.30.1
85
+ uvloop==0.19.0
86
+ watchfiles==0.22.0
87
+ websockets==11.0.3
88
+ zipp==3.19.2