BeveledCube committed
Commit 610afda · 1 Parent(s): 1cfa651
.env CHANGED
@@ -1 +1,2 @@
-HF_HOME=./models
+HF_HOME=./models
+TRANSFORMERS_CACHE=./cache
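A quick sketch of what these variables do (values are the repo-relative paths set above; note that newer transformers releases treat `TRANSFORMERS_CACHE` as deprecated in favor of `HF_HOME`/`HF_HUB_CACHE`):

```python
import os

# Must be set before transformers/huggingface_hub are imported to take effect.
os.environ["HF_HOME"] = "./models"            # root for Hugging Face config, tokens, and default caches
os.environ["TRANSFORMERS_CACHE"] = "./cache"  # legacy override for downloaded model weights

from transformers import CLIPModel

# Weights now land under ./cache instead of ~/.cache/huggingface.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
```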
Dockerfile.fastapi CHANGED
@@ -11,4 +11,4 @@ RUN pip install --no-cache-dir uvicorn gunicorn fastapi pytest ruff pytest-async
 
 EXPOSE 80
 
-CMD ["uvicorn", "tld.app:app", "--host", "0.0.0.0", "--port", "80"]
+CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "80"]
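The new CMD matches the repo layout change: the module path drops the `tld.` prefix, so `app.py` must sit at the container's working-directory root. For reference, a programmatic equivalent of the command:

```python
# Equivalent of: uvicorn app:app --host 0.0.0.0 --port 80
import uvicorn

if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=80)
```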
{tld/img_examples β†’ img_examples}/a beautiful woman with blonde hair in her 50s_cfg_7_seed_11.png RENAMED
File without changes
{tld/img_examples β†’ img_examples}/a cute grey great owl_cfg_8_seed_11.png RENAMED
File without changes
{tld/img_examples β†’ img_examples}/a lake in mountains in the fall at sunset_cfg_7_seed_11.png RENAMED
File without changes
{tld/img_examples β†’ img_examples}/a woman cyborg with red curly hair, 8k_cfg_9.5_seed_11.png RENAMED
File without changes
{tld/img_examples β†’ img_examples}/an aerial view of manhattan, isometric view, as pantinted by mondrian_cfg_7_seed_11.png RENAMED
File without changes
{tld/img_examples β†’ img_examples}/isometric view of small japanese village with blooming trees_cfg_7_seed_11.png RENAMED
File without changes
{tld/img_examples β†’ img_examples}/painting of a cute fox in a suit in a field of poppies_cfg_8_seed_11.png RENAMED
File without changes
{tld/img_examples β†’ img_examples}/painting of a cyberpunk market_cfg_7_seed_11.png RENAMED
File without changes
{tld/img_examples β†’ img_examples}/watercolor of a cute cat riding a motorcycle_cfg_7_seed_11.png RENAMED
File without changes
og readme.md CHANGED
@@ -71,7 +71,7 @@ If you have your own dataset of URLs + captions, the process to train a model on
 ```python
 !wandb login
 import os
-from tld.train import main, DataConfig, ModelConfig
+from train import main, DataConfig, ModelConfig
 from accelerate import notebook_launcher
 
 data_config = DataConfig(latent_path='path/to/image_latents.npy',
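The hunk above ends mid-snippet; for context, `main` is typically handed to accelerate's `notebook_launcher` a few lines later. A hedged sketch of that call (the argument tuple and process count are assumptions, not shown in this hunk; `model_config` would be a `ModelConfig` built alongside `data_config`):

```python
# Hypothetical launch call; main's exact signature lives in train.py.
notebook_launcher(main, args=(model_config, data_config), num_processes=1)
```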
old/main.py ADDED
@@ -0,0 +1,37 @@
+import os
+from transformers import CLIPProcessor, CLIPModel
+from PIL import Image
+
+# Get the directory of the script
+script_directory = os.path.dirname(os.path.realpath(__file__))
+# Specify the directory where the cache will be stored (same folder as the script)
+cache_directory = os.path.join(script_directory, "cache")
+# Create the cache directory if it doesn't exist
+os.makedirs(cache_directory, exist_ok=True)
+
+# Load the CLIP processor and model
+clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", cache_dir=cache_directory)
+clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", cache_dir=cache_directory)
+
+# Text description to generate image
+text = "a cat sitting on a table"
+
+# Tokenize text and get features
+inputs = clip_processor(text, return_tensors="pt", padding=True)
+
+# Generate image from text
+generated_image = clip_model.generate(
+    inputs=inputs.input_ids,
+    attention_mask=inputs.attention_mask,
+    visual_input=None,  # We don't provide an image input
+    return_tensors="pt"  # Return PyTorch tensor
+)
+
+# Convert the generated image tensor to a NumPy array
+generated_image_np = generated_image[0].cpu().numpy()
+
+# Save the generated image
+output_image_path = "generated_image.png"
+Image.fromarray(generated_image_np).save(output_image_path)
+
+print("Image generated and saved as:", output_image_path)
requirements.txt CHANGED
@@ -7,4 +7,5 @@ diffusers
 accelerate
 transformers
 Pillow
+poetry
 git+https://github.com/openai/CLIP.git
start.sh CHANGED
@@ -1,5 +1,4 @@
 pip install --upgrade pip
 pip install -r requirements.txt
-poetry install --no-root
 
-python main.py
+python tld/gen_img.py
tests/client.js ADDED
@@ -0,0 +1,15 @@
+const axios = require("axios");
+
+const apiUrl = `http://de-fsn-4.halex.gg:25287/api`;
+
+const postData = {
+  prompt: "Wassup my homie"
+};
+
+axios.post(apiUrl, postData)
+  .then(response => {
+    console.log("Response from API:", response.data);
+  })
+  .catch(error => {
+    console.error("Error:", error.message);
+  });
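For parity with the repo's main language, a Python equivalent of this client (same placeholder host and payload; note that `tld/app.py` serves `/generate-image/`, so the `/api` path here may need adjusting):

```python
import requests

api_url = "http://de-fsn-4.halex.gg:25287/api"  # same placeholder endpoint as the JS client
post_data = {"prompt": "Wassup my homie"}

try:
    response = requests.post(api_url, json=post_data)
    print("Response from API:", response.status_code, response.headers.get("content-type"))
except requests.RequestException as error:
    print("Error:", error)
```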
tests/test_api.py CHANGED
@@ -1,7 +1,5 @@
-import os
-
 from fastapi.testclient import TestClient
-from tld.app import app
+from app import app
 import PIL
 from PIL import Image
 from io import BytesIO
@@ -9,23 +7,24 @@ from io import BytesIO
 client = TestClient(app)
 
 def test_read_main():
-    response = client.get("/")
-    assert response.status_code == 200
-    assert response.json() == {"message": "Welcome to Image Generator"}
+    response = client.get("/")
+    assert response.status_code == 200
+    assert response.json() == {"message": "Welcome to Image Generator"}
 
 
 def test_generate_image_unauthorized():
-    response = client.post("/generate-image/", json={})
-    assert response.status_code == 401
-    assert response.json() == {"detail": "Not authenticated"}
+    response = client.post("/generate-image/", json={})
+    assert response.status_code == 401
+    assert response.json() == {"detail": "Not authenticated"}
 
 
 def test_generate_image_authorized():
-    api_token = os.getenv("API_TOKEN")
-    response = client.post(
-        "/generate-image/", json={"prompt": "a cute cat"}, headers={"Authorization": f"Bearer {api_token}"}
-    )
-    assert response.status_code == 200
-
-    image = Image.open(BytesIO(response.content))
-    assert type(image) == PIL.JpegImagePlugin.JpegImageFile
+    response = client.post(
+        "/generate-image/", json={"prompt": "a cute cat"}
+    )
+    assert response.status_code == 200
+
+    image = Image.open(BytesIO(response.content))
+    assert type(image) == PIL.JpegImagePlugin.JpegImageFile
+
+test_generate_image_authorized()
tests/test_diffuser.py CHANGED
@@ -10,8 +10,8 @@ import torchvision.transforms as transforms
 import torchvision.utils as vutils
 from diffusers import AutoencoderKL
 
-from tld.denoiser import Denoiser
-from tld.diffusion import DiffusionGenerator, DiffusionTransformer, LTDConfig
+from denoiser import Denoiser
+from diffusion import DiffusionGenerator, DiffusionTransformer, LTDConfig
 from PIL.Image import Image
 
 to_pil = transforms.ToPILImage()
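These flat imports (here and in the files below) only resolve when `tld/` itself is on `sys.path`, e.g. when scripts run from inside that directory. One way a test file could guarantee that, as a sketch (not part of this commit):

```python
import os
import sys

# Make tld/ importable no matter where pytest is invoked from (illustrative).
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "tld"))

from denoiser import Denoiser  # now resolves without the tld. prefix
```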
tld/app.py CHANGED
@@ -4,39 +4,37 @@ from typing import Optional
 
 import torch
 import torchvision.transforms as transforms
-from fastapi import Depends, FastAPI, HTTPException, status
+from fastapi import FastAPI, HTTPException, status
 from fastapi.responses import StreamingResponse
 from fastapi.security import OAuth2PasswordBearer
 from pydantic import BaseModel
 
-from tld.diffusion import DiffusionTransformer, LTDConfig
+from diffusion import DiffusionTransformer, LTDConfig
+
+# Get the directory of the script
+script_directory = os.path.dirname(os.path.realpath(__file__))
+# Specify the directory where the cache will be stored (same folder as the script)
+cache_directory = os.path.join(script_directory, "cache")
+home_directory = os.path.join(script_directory, "home")
+# Create the cache directory if it doesn't exist
+os.makedirs(cache_directory, exist_ok=True)
+os.makedirs(home_directory, exist_ok=True)
+
+os.environ["TRANSFORMERS_CACHE"] = cache_directory
+os.environ["HF_HOME"] = home_directory
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 to_pil = transforms.ToPILImage()
-
 ltdconfig = LTDConfig()
-diffusion_transformer = DiffusionTransformer(ltdconfig)
-
+diffusion_transformer = DiffusionTransformer(ltdconfig)  # Downloads model here
 app = FastAPI()
 
-oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
-
-
-def validate_token(token: str = Depends(oauth2_scheme)):
-    if token != os.getenv("API_TOKEN"):
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail="Invalid authentication credentials",
-            headers={"WWW-Authenticate": "Bearer"},
-        )
-
-
 class ImageRequest(BaseModel):
-    prompt: str
-    class_guidance: Optional[int] = 6
-    seed: Optional[int] = 11
-    num_imgs: Optional[int] = 1
-    img_size: Optional[int] = 32
+    prompt: str
+    class_guidance: Optional[int] = 6
+    seed: Optional[int] = 11
+    num_imgs: Optional[int] = 1
+    img_size: Optional[int] = 32
 
 
 @app.get("/")
@@ -45,23 +43,24 @@ def read_root():
 
 
 @app.post("/generate-image/")
-async def generate_image(request: ImageRequest, token: str = Depends(validate_token)):
-    try:
-        img = diffusion_transformer.generate_image_from_text(
-            prompt=request.prompt,
-            class_guidance=request.class_guidance,
-            seed=request.seed,
-            num_imgs=request.num_imgs,
-            img_size=request.img_size,
-        )
-        # Convert PIL image to byte stream suitable for HTTP response
-        img_byte_arr = io.BytesIO()
-        img.save(img_byte_arr, format="JPEG")
-        img_byte_arr.seek(0)
-
-        return StreamingResponse(img_byte_arr, media_type="image/jpeg")
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
+async def generate_image(request: ImageRequest):
+    try:
+        img = diffusion_transformer.generate_image_from_text(
+            prompt=request.prompt,
+            class_guidance=request.class_guidance,
+            seed=request.seed,
+            num_imgs=request.num_imgs,
+            img_size=request.img_size,
+        )
+
+        # Convert PIL image to byte stream suitable for HTTP response
+        img_byte_arr = io.BytesIO()
+        img.save(img_byte_arr, format="JPEG")
+        img_byte_arr.seek(0)
+
+        return StreamingResponse(img_byte_arr, media_type="image/jpeg")
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
 
 
 # build job to test and deploy the API on a docker image (maybe in Azure?)
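With `validate_token` gone, the endpoint accepts bare POSTs (the now-unused `OAuth2PasswordBearer` import is left behind). A sketch of a call against a local run, assuming the server listens on port 80 as in the Dockerfile:

```python
from io import BytesIO

import requests
from PIL import Image

resp = requests.post(
    "http://localhost:80/generate-image/",
    json={"prompt": "a cute cat", "class_guidance": 6, "seed": 11},
)
resp.raise_for_status()

img = Image.open(BytesIO(resp.content))  # the endpoint streams back a JPEG
img.save("out.jpg")
```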
tld/denoiser.py CHANGED
@@ -4,7 +4,7 @@ import torch
 from einops.layers.torch import Rearrange
 from torch import nn
 
-from tld.transformer_blocks import DecoderBlock, MLPSepConv, SinusoidalEmbedding
+from transformer_blocks import DecoderBlock, MLPSepConv, SinusoidalEmbedding
 
 
 class DenoiserTransBlock(nn.Module):
tld/diffusion.py CHANGED
@@ -10,7 +10,7 @@ from diffusers import AutoencoderKL
 from torch import Tensor
 from tqdm import tqdm
 
-from tld.denoiser import Denoiser
+from denoiser import Denoiser
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 to_pil = transforms.ToPILImage()
@@ -124,21 +124,21 @@ class DiffusionGenerator:
 
 @dataclass
 class LTDConfig:
-    vae_scale_factor: float = 8
-    img_size: int = 32
-    model_dtype: torch.dtype = torch.float32
-    file_url: str = None  # = "https://huggingface.co/apapiu/small_ldt/resolve/main/state_dict_378000.pth"
-    local_filename: str = "state_dict_378000.pth"
-    vae_name: str = "madebyollin/sdxl-vae-fp16-fix"
-    clip_model_name: str = "ViT-L/14"
-    denoiser: Denoiser = Denoiser(
-        image_size=32,
-        noise_embed_dims=256,
-        patch_size=2,
-        embed_dim=256,
-        dropout=0,
-        n_layers=4,
-    )
+    vae_scale_factor: float = 8
+    img_size: int = 32
+    model_dtype: torch.dtype = torch.float32
+    file_url: str = None  # = "https://huggingface.co/apapiu/small_ldt/resolve/main/state_dict_378000.pth"
+    local_filename: str = "state_dict_378000.pth"
+    vae_name: str = "ByteDance/SDXL-Lightning"
+    clip_model_name: str = "ViT-L/14"
+    denoiser: Denoiser = Denoiser(
+        image_size=32,
+        noise_embed_dims=256,
+        patch_size=2,
+        embed_dim=256,
+        dropout=0,
+        n_layers=4,
+    )
 
 
 def download_file(url, filename):
@@ -151,48 +151,48 @@ def download_file(url, filename):
 
 @torch.no_grad()
 def encode_text(label, model):
-    text_tokens = clip.tokenize(label, truncate=True).to(device)
-    text_encoding = model.encode_text(text_tokens)
-    return text_encoding.cpu()
+    text_tokens = clip.tokenize(label, truncate=True).to(device)
+    text_encoding = model.encode_text(text_tokens)
+    return text_encoding.cpu()
 
 
 class DiffusionTransformer:
-    def __init__(self, config: LTDConfig):
-        denoiser = config.denoiser.to(config.model_dtype)
-
-        if config.file_url is not None:
-            print(f"Downloading model from {config.file_url}")
-            download_file(config.file_url, config.local_filename)
-            state_dict = torch.load(config.local_filename, map_location=torch.device("cpu"))
-            denoiser.load_state_dict(state_dict)
-
-        denoiser = denoiser.to(device)
-
-        vae = AutoencoderKL.from_pretrained(config.vae_name, torch_dtype=config.model_dtype).to(device)
-
-        self.clip_model, preprocess = clip.load(config.clip_model_name)
-        self.clip_model = self.clip_model.to(device)
-
-        self.diffuser = DiffusionGenerator(denoiser, vae, device, config.model_dtype)
-
-    def generate_image_from_text(
-        self, prompt: str, class_guidance=6, seed=11, num_imgs=1, img_size=32, n_iter=15
-    ):
-        nrow = int(np.sqrt(num_imgs))
-
-        cur_prompts = [prompt] * num_imgs
-        labels = encode_text(cur_prompts, self.clip_model)
-        out, out_latent = self.diffuser.generate(
-            labels=labels,
-            num_imgs=num_imgs,
-            class_guidance=class_guidance,
-            seed=seed,
-            n_iter=n_iter,
-            exponent=1,
-            scale_factor=8,
-            sharp_f=0,
-            bright_f=0,
-        )
+    def __init__(self, config: LTDConfig):
+        denoiser = config.denoiser.to(config.model_dtype)
+
+        if config.file_url is not None:
+            print(f"Downloading model from {config.file_url}")
+            download_file(config.file_url, config.local_filename)
+            state_dict = torch.load(config.local_filename, map_location=torch.device("cpu"))
+            denoiser.load_state_dict(state_dict)
+
+        denoiser = denoiser.to(device)
+
+        vae = AutoencoderKL.from_pretrained(config.vae_name, torch_dtype=config.model_dtype).to(device)
+
+        self.clip_model, preprocess = clip.load(config.clip_model_name)
+        self.clip_model = self.clip_model.to(device)
+
+        self.diffuser = DiffusionGenerator(denoiser, vae, device, config.model_dtype)
+
+    def generate_image_from_text(
+        self, prompt: str, class_guidance=6, seed=11, num_imgs=1, img_size=32, n_iter=15
+    ):
+        nrow = int(np.sqrt(num_imgs))
+
+        cur_prompts = [prompt] * num_imgs
+        labels = encode_text(cur_prompts, self.clip_model)
+        out, out_latent = self.diffuser.generate(
+            labels=labels,
+            num_imgs=num_imgs,
+            class_guidance=class_guidance,
+            seed=seed,
+            n_iter=n_iter,
+            exponent=1,
+            scale_factor=8,
+            sharp_f=0,
+            bright_f=0,
+        )
 
-        out = to_pil((vutils.make_grid((out + 1) / 2, nrow=nrow, padding=4)).float().clip(0, 1))
-        return out
+        out = to_pil((vutils.make_grid((out + 1) / 2, nrow=nrow, padding=4)).float().clip(0, 1))
+        return out
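One caveat on the `vae_name` change: `AutoencoderKL.from_pretrained` expects a repository (or subfolder) holding VAE weights. The previous value, `madebyollin/sdxl-vae-fp16-fix`, is such a repo; `ByteDance/SDXL-Lightning` hosts distilled UNet/LoRA checkpoints, so loading it this way will likely fail. The original load, for comparison:

```python
import torch
from diffusers import AutoencoderKL

# The repo id must point at standalone VAE weights, not a full pipeline or UNet checkpoint.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float32)
```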
tld/gen_img.py ADDED
@@ -0,0 +1,44 @@
+import io
+import asyncio
+import os
+from PIL import Image
+import numpy as np
+
+import torch
+import torchvision.transforms as transforms
+from torchvision import utils as vutils
+
+from diffusion import DiffusionTransformer, LTDConfig
+
+# Get the directory of the script
+script_directory = os.path.dirname(os.path.realpath(__file__))
+# Specify the directory where the cache will be stored (same folder as the script)
+cache_directory = os.path.join(script_directory, "cache")
+home_directory = os.path.join(script_directory, "home")
+# Create the cache directory if it doesn't exist
+os.makedirs(cache_directory, exist_ok=True)
+os.makedirs(home_directory, exist_ok=True)
+
+os.environ["TRANSFORMERS_CACHE"] = cache_directory
+os.environ["HF_HOME"] = home_directory
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+to_pil = transforms.ToPILImage()
+ltdconfig = LTDConfig()
+diffusion_transformer = DiffusionTransformer(ltdconfig)  # Downloads model here
+
+async def generate_image(prompt):
+    try:
+        img = diffusion_transformer.generate_image_from_text(
+            prompt=prompt,
+            class_guidance=6,
+            seed=11,
+            num_imgs=1,
+            img_size=32,
+        )
+
+        img.save("generated_img.png")
+    except Exception as e:
+        print(e)
+
+asyncio.run(generate_image("a cute cat"))
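A small note on this script: `generate_image_from_text` is synchronous, so the `async` wrapper and `asyncio.run` add nothing; a plain call behaves identically. Sketch (reusing the module's `diffusion_transformer`):

```python
# Equivalent synchronous entry point.
img = diffusion_transformer.generate_image_from_text(
    prompt="a cute cat", class_guidance=6, seed=11, num_imgs=1, img_size=32
)
img.save("generated_img.png")
```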
tld/train.py CHANGED
@@ -15,8 +15,8 @@ from torch import Tensor, nn
 from torch.utils.data import DataLoader, TensorDataset
 from tqdm import tqdm
 
-from tld.denoiser import Denoiser
-from tld.diffusion import DiffusionGenerator
+from denoiser import Denoiser
+from diffusion import DiffusionGenerator
 
 
 def eval_gen(diffuser: DiffusionGenerator, labels: Tensor) -> Image: