Commit 610afda · Pls work
1 Parent: 1cfa651
- .env +2 -1
- Dockerfile.fastapi +1 -1
- {tld/img_examples → img_examples}/a beautiful woman with blonde hair in her 50s_cfg_7_seed_11.png +0 -0
- {tld/img_examples → img_examples}/a cute grey great owl_cfg_8_seed_11.png +0 -0
- {tld/img_examples → img_examples}/a lake in mountains in the fall at sunset_cfg_7_seed_11.png +0 -0
- {tld/img_examples → img_examples}/a woman cyborg with red curly hair, 8k_cfg_9.5_seed_11.png +0 -0
- {tld/img_examples → img_examples}/an aerial view of manhattan, isometric view, as pantinted by mondrian_cfg_7_seed_11.png +0 -0
- {tld/img_examples → img_examples}/isometric view of small japanese village with blooming trees_cfg_7_seed_11.png +0 -0
- {tld/img_examples → img_examples}/painting of a cute fox in a suit in a field of poppies_cfg_8_seed_11.png +0 -0
- {tld/img_examples → img_examples}/painting of a cyberpunk market_cfg_7_seed_11.png +0 -0
- {tld/img_examples → img_examples}/watercolor of a cute cat riding a motorcycle_cfg_7_seed_11.png +0 -0
- og readme.md +1 -1
- old/main.py +37 -0
- requirements.txt +1 -0
- start.sh +1 -2
- tests/client.js +15 -0
- tests/test_api.py +16 -17
- tests/test_diffuser.py +2 -2
- tld/app.py +38 -39
- tld/denoiser.py +1 -1
- tld/diffusion.py +57 -57
- tld/gen_img.py +44 -0
- tld/train.py +2 -2
.env
CHANGED
@@ -1 +1,2 @@
-HF_HOME=./models
+HF_HOME=./models
+TRANSFORMERS_CACHE=./cache
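Note: both variables only take effect if they are set before transformers/diffusers are first imported, which is presumably why app.py and gen_img.py below also set them via os.environ at the top of the script. A minimal sketch of the ordering (paths taken from this .env, not verified against the Space):

```python
import os

# must happen before the first transformers/diffusers import,
# or the default cache location has already been resolved
os.environ.setdefault("HF_HOME", "./models")
os.environ.setdefault("TRANSFORMERS_CACHE", "./cache")

from transformers import CLIPModel  # now downloads into ./cache
```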
Dockerfile.fastapi
CHANGED
@@ -11,4 +11,4 @@ RUN pip install --no-cache-dir uvicorn gunicorn fastapi pytest ruff pytest-async
 
 EXPOSE 80
 
-CMD ["uvicorn", "
+CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "80"]
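The old CMD is cut off in the diff view; the new one serves app.py from the working directory on port 80. A quick local smoke test without Docker (a sketch; assumes it is run from the tld/ directory so "app:app" resolves):

```python
# mirrors the container CMD via uvicorn's Python API
import uvicorn

if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=80)
```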
{tld/img_examples → img_examples}/a beautiful woman with blonde hair in her 50s_cfg_7_seed_11.png
RENAMED
File without changes

{tld/img_examples → img_examples}/a cute grey great owl_cfg_8_seed_11.png
RENAMED
File without changes

{tld/img_examples → img_examples}/a lake in mountains in the fall at sunset_cfg_7_seed_11.png
RENAMED
File without changes

{tld/img_examples → img_examples}/a woman cyborg with red curly hair, 8k_cfg_9.5_seed_11.png
RENAMED
File without changes

{tld/img_examples → img_examples}/an aerial view of manhattan, isometric view, as pantinted by mondrian_cfg_7_seed_11.png
RENAMED
File without changes

{tld/img_examples → img_examples}/isometric view of small japanese village with blooming trees_cfg_7_seed_11.png
RENAMED
File without changes

{tld/img_examples → img_examples}/painting of a cute fox in a suit in a field of poppies_cfg_8_seed_11.png
RENAMED
File without changes

{tld/img_examples → img_examples}/painting of a cyberpunk market_cfg_7_seed_11.png
RENAMED
File without changes

{tld/img_examples → img_examples}/watercolor of a cute cat riding a motorcycle_cfg_7_seed_11.png
RENAMED
File without changes
og readme.md
CHANGED
@@ -71,7 +71,7 @@ If you have your own dataset of URLs + captions, the process to train a model on
 ```python
 !wandb login
 import os
-from
+from train import main, DataConfig, ModelConfig
 from accelerate import notebook_launcher
 
 data_config = DataConfig(latent_path='path/to/image_latents.npy',
old/main.py
ADDED
@@ -0,0 +1,37 @@
+import os
+from transformers import CLIPProcessor, CLIPModel
+from PIL import Image
+
+# Get the directory of the script
+script_directory = os.path.dirname(os.path.realpath(__file__))
+# Specify the directory where the cache will be stored (same folder as the script)
+cache_directory = os.path.join(script_directory, "cache")
+# Create the cache directory if it doesn't exist
+os.makedirs(cache_directory, exist_ok=True)
+
+# Load the CLIP processor and model
+clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", cache_dir=cache_directory)
+clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", cache_dir=cache_directory)
+
+# Text description to generate image
+text = "a cat sitting on a table"
+
+# Tokenize text and get features
+inputs = clip_processor(text, return_tensors="pt", padding=True)
+
+# Generate image from text
+generated_image = clip_model.generate(
+    inputs=inputs.input_ids,
+    attention_mask=inputs.attention_mask,
+    visual_input=None,  # We don't provide image input
+    return_tensors="pt"  # Return PyTorch tensor
+)
+
+# Convert the generated image tensor to a NumPy array
+generated_image_np = generated_image[0].cpu().numpy()
+
+# Save the generated image
+output_image_path = "generated_image.png"
+Image.fromarray(generated_image_np).save(output_image_path)
+
+print("Image generated and saved as:", output_image_path)
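Review note: this script was presumably parked in old/ because it cannot run as written — CLIPModel is an encoder and has no generate() method, so clip_model.generate(...) raises AttributeError. What CLIP can do with this checkpoint is score image-text similarity; a minimal working sketch (the image path is a placeholder):

```python
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

image = Image.open("some_image.png")  # placeholder path
inputs = processor(text=["a cat", "a dog"], images=image, return_tensors="pt", padding=True)

outputs = model(**inputs)
probs = outputs.logits_per_image.softmax(dim=1)  # image-text match probabilities
print(probs)
```

Actual text-to-image generation is what tld/diffusion.py does below, using CLIP only as the text encoder.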
requirements.txt
CHANGED
@@ -7,4 +7,5 @@ diffusers
 accelerate
 transformers
 Pillow
+poetry
 git+https://github.com/openai/CLIP.git
start.sh
CHANGED
@@ -1,5 +1,4 @@
 pip install --upgrade pip
 pip install -r requirements.txt
-poetry install --no-root
 
-python
+python tld/gen_img.py
tests/client.js
ADDED
@@ -0,0 +1,15 @@
+const axios = require("axios");
+
+const apiUrl = `http://de-fsn-4.halex.gg:25287/api`;
+
+const postData = {
+  prompt: "Wassup my homie"
+};
+
+axios.post(apiUrl, postData)
+  .then(response => {
+    console.log("Response from API:", response.data);
+  })
+  .catch(error => {
+    console.error("Error:", error.message);
+  });
tests/test_api.py
CHANGED
@@ -1,7 +1,5 @@
-import os
-
 from fastapi.testclient import TestClient
-from
+from app import app
 import PIL
 from PIL import Image
 from io import BytesIO
@@ -9,23 +7,24 @@ from io import BytesIO
 client = TestClient(app)
 
 def test_read_main():
-
-
-
+    response = client.get("/")
+    assert response.status_code == 200
+    assert response.json() == {"message": "Welcome to Image Generator"}
 
 
 def test_generate_image_unauthorized():
-
-
-
+    response = client.post("/generate-image/", json={})
+    assert response.status_code == 401
+    assert response.json() == {"detail": "Not authenticated"}
 
 
 def test_generate_image_authorized():
-
-
-
-
-
-
-
-
+    response = client.post(
+        "/generate-image/", json={"prompt": "a cute cat"}
+    )
+    assert response.status_code == 200
+
+    image = Image.open(BytesIO(response.content))
+    assert type(image) == PIL.JpegImagePlugin.JpegImageFile
+
+test_generate_image_authorized()
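Review note: the bare test_generate_image_authorized() call at module scope runs the test (and loads the model) as soon as pytest imports the file. A guard keeps the module importable without side effects (a suggested tweak, not part of the commit):

```python
# only run when executed directly, not on pytest collection
if __name__ == "__main__":
    test_generate_image_authorized()
```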
tests/test_diffuser.py
CHANGED
@@ -10,8 +10,8 @@ import torchvision.transforms as transforms
 import torchvision.utils as vutils
 from diffusers import AutoencoderKL
 
-from
-from
+from denoiser import Denoiser
+from diffusion import DiffusionGenerator, DiffusionTransformer, LTDConfig
 from PIL.Image import Image
 
 to_pil = transforms.ToPILImage()
tld/app.py
CHANGED
@@ -4,39 +4,37 @@ from typing import Optional
 
 import torch
 import torchvision.transforms as transforms
-from fastapi import
+from fastapi import FastAPI, HTTPException, status
 from fastapi.responses import StreamingResponse
 from fastapi.security import OAuth2PasswordBearer
 from pydantic import BaseModel
 
-from
+from diffusion import DiffusionTransformer, LTDConfig
+
+# Get the directory of the script
+script_directory = os.path.dirname(os.path.realpath(__file__))
+# Specify the directory where the cache will be stored (same folder as the script)
+cache_directory = os.path.join(script_directory, "cache")
+home_directory = os.path.join(script_directory, "home")
+# Create the cache directory if it doesn't exist
+os.makedirs(cache_directory, exist_ok=True)
+os.makedirs(home_directory, exist_ok=True)
+
+os.environ["TRANSFORMERS_CACHE"] = cache_directory
+os.environ["HF_HOME"] = home_directory
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 to_pil = transforms.ToPILImage()
-
 ltdconfig = LTDConfig()
-diffusion_transformer = DiffusionTransformer(ltdconfig)
-
+diffusion_transformer = DiffusionTransformer(ltdconfig)  # Downloads model here
 app = FastAPI()
 
-oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
-
-
-def validate_token(token: str = Depends(oauth2_scheme)):
-    if token != os.getenv("API_TOKEN"):
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail="Invalid authentication credentials",
-            headers={"WWW-Authenticate": "Bearer"},
-        )
-
-
 class ImageRequest(BaseModel):
-
-
-
-
-
+    prompt: str
+    class_guidance: Optional[int] = 6
+    seed: Optional[int] = 11
+    num_imgs: Optional[int] = 1
+    img_size: Optional[int] = 32
 
 
 @app.get("/")
@@ -45,23 +43,24 @@ def read_root():
 
 
 @app.post("/generate-image/")
-async def generate_image(request: ImageRequest
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+async def generate_image(request: ImageRequest):
+    try:
+        img = diffusion_transformer.generate_image_from_text(
+            prompt=request.prompt,
+            class_guidance=request.class_guidance,
+            seed=request.seed,
+            num_imgs=request.num_imgs,
+            img_size=request.img_size,
+        )
+
+        # Convert PIL image to byte stream suitable for HTTP response
+        img_byte_arr = io.BytesIO()
+        img.save(img_byte_arr, format="JPEG")
+        img_byte_arr.seek(0)
+
+        return StreamingResponse(img_byte_arr, media_type="image/jpeg")
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
 
 
 # build job to test and deploy the API on a docker image (maybe in Azure?)
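Review note: with validate_token and the Depends(oauth2_scheme) wiring deleted, /generate-image/ is now unauthenticated, which is at odds with tests/test_api.py still asserting a 401. A minimal client against the new endpoint (a sketch; host and port assume the Dockerfile above):

```python
import requests

resp = requests.post(
    "http://localhost:80/generate-image/",
    json={"prompt": "a cute cat"},
    timeout=600,  # first call may also trigger model downloads
)
resp.raise_for_status()

with open("generated.jpg", "wb") as f:
    f.write(resp.content)  # body is the raw JPEG stream
```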
tld/denoiser.py
CHANGED
@@ -4,7 +4,7 @@ import torch
 from einops.layers.torch import Rearrange
 from torch import nn
 
-from
+from transformer_blocks import DecoderBlock, MLPSepConv, SinusoidalEmbedding
 
 
 class DenoiserTransBlock(nn.Module):
tld/diffusion.py
CHANGED
@@ -10,7 +10,7 @@ from diffusers import AutoencoderKL
 from torch import Tensor
 from tqdm import tqdm
 
-from
+from denoiser import Denoiser
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 to_pil = transforms.ToPILImage()
@@ -124,21 +124,21 @@ class DiffusionGenerator:
 
 @dataclass
 class LTDConfig:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    vae_scale_factor: float = 8
+    img_size: int = 32
+    model_dtype: torch.dtype = torch.float32
+    file_url: str = None  # = "https://huggingface.co/apapiu/small_ldt/resolve/main/state_dict_378000.pth"
+    local_filename: str = "state_dict_378000.pth"
+    vae_name: str = "ByteDance/SDXL-Lightning"
+    clip_model_name: str = "ViT-L/14"
+    denoiser: Denoiser = Denoiser(
+        image_size=32,
+        noise_embed_dims=256,
+        patch_size=2,
+        embed_dim=256,
+        dropout=0,
+        n_layers=4,
+    )
 
 
 def download_file(url, filename):
@@ -151,48 +151,48 @@ def download_file(url, filename):
 
 @torch.no_grad()
 def encode_text(label, model):
-
-
-
+    text_tokens = clip.tokenize(label, truncate=True).to(device)
+    text_encoding = model.encode_text(text_tokens)
+    return text_encoding.cpu()
 
 
 class DiffusionTransformer:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def __init__(self, config: LTDConfig):
+        denoiser = config.denoiser.to(config.model_dtype)
+
+        if config.file_url is not None:
+            print(f"Downloading model from {config.file_url}")
+            download_file(config.file_url, config.local_filename)
+            state_dict = torch.load(config.local_filename, map_location=torch.device("cpu"))
+            denoiser.load_state_dict(state_dict)
+
+        denoiser = denoiser.to(device)
+
+        vae = AutoencoderKL.from_pretrained(config.vae_name, torch_dtype=config.model_dtype).to(device)
+
+        self.clip_model, preprocess = clip.load(config.clip_model_name)
+        self.clip_model = self.clip_model.to(device)
+
+        self.diffuser = DiffusionGenerator(denoiser, vae, device, config.model_dtype)
+
+    def generate_image_from_text(
+        self, prompt: str, class_guidance=6, seed=11, num_imgs=1, img_size=32, n_iter=15
+    ):
+        nrow = int(np.sqrt(num_imgs))
+
+        cur_prompts = [prompt] * num_imgs
+        labels = encode_text(cur_prompts, self.clip_model)
+        out, out_latent = self.diffuser.generate(
+            labels=labels,
+            num_imgs=num_imgs,
+            class_guidance=class_guidance,
+            seed=seed,
+            n_iter=n_iter,
+            exponent=1,
+            scale_factor=8,
+            sharp_f=0,
+            bright_f=0,
+        )
 
-
-
+        out = to_pil((vutils.make_grid((out + 1) / 2, nrow=nrow, padding=4)).float().clip(0, 1))
+        return out
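Two things worth flagging in this rewrite: file_url defaults to None, so a bare LTDConfig() keeps the Denoiser's random initialization unless the commented-out checkpoint URL is passed in, and Denoiser(...) as a dataclass field default is evaluated once at import time, so every LTDConfig shares the same denoiser instance (a default_factory would avoid that). A minimal usage sketch, assuming tld/ is on the import path as in the Space:

```python
from diffusion import DiffusionTransformer, LTDConfig

# pass file_url explicitly to load trained weights (URL from the config comment above)
config = LTDConfig(file_url="https://huggingface.co/apapiu/small_ldt/resolve/main/state_dict_378000.pth")
dt = DiffusionTransformer(config)

img = dt.generate_image_from_text(prompt="a cute cat", num_imgs=1)
img.save("sample.png")
```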
tld/gen_img.py
ADDED
@@ -0,0 +1,44 @@
+import io
+import asyncio
+import os
+from PIL import Image
+import numpy as np
+
+import torch
+import torchvision.transforms as transforms
+from torchvision import utils as vutils
+
+from diffusion import DiffusionTransformer, LTDConfig
+
+# Get the directory of the script
+script_directory = os.path.dirname(os.path.realpath(__file__))
+# Specify the directory where the cache will be stored (same folder as the script)
+cache_directory = os.path.join(script_directory, "cache")
+home_directory = os.path.join(script_directory, "home")
+# Create the cache directory if it doesn't exist
+os.makedirs(cache_directory, exist_ok=True)
+os.makedirs(home_directory, exist_ok=True)
+
+os.environ["TRANSFORMERS_CACHE"] = cache_directory
+os.environ["HF_HOME"] = home_directory
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+to_pil = transforms.ToPILImage()
+ltdconfig = LTDConfig()
+diffusion_transformer = DiffusionTransformer(ltdconfig)  # Downloads model here
+
+async def generate_image(prompt):
+    try:
+        img = diffusion_transformer.generate_image_from_text(
+            prompt=prompt,
+            class_guidance=6,
+            seed=11,
+            num_imgs=1,
+            img_size=32,
+        )
+
+        img.save("generated_img.png")
+    except Exception as e:
+        print(e)
+
+asyncio.run(generate_image("a cute cat"))
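Review note: generate_image_from_text is a plain synchronous function, so the async wrapper and asyncio.run add nothing here, and the except/print swallows failures that should crash a startup script. A simpler equivalent (a sketch, same names as the file above):

```python
# direct call; let exceptions propagate so the Space shows the real error
img = diffusion_transformer.generate_image_from_text(prompt="a cute cat")
img.save("generated_img.png")
```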
tld/train.py
CHANGED
@@ -15,8 +15,8 @@ from torch import Tensor, nn
 from torch.utils.data import DataLoader, TensorDataset
 from tqdm import tqdm
 
-from
-from
+from denoiser import Denoiser
+from diffusion import DiffusionGenerator
 
 
 def eval_gen(diffuser: DiffusionGenerator, labels: Tensor) -> Image: