# picpilot-server / scripts / sdxl_lora_inference.py
# Last commit f1b2ef2 by VikramSingh178:
#   "Add new endpoints for product diffusion API and SDXL-LoRA inference"
from wandb.integration.diffusers import autolog
from diffusers import DiffusionPipeline
import torch
from config import PROJECT_NAME
# Turn on Weights & Biases autologging for diffusers pipelines. The ``init``
# mapping is presumably forwarded to ``wandb.init`` (confirm in the wandb
# diffusers integration docs), so runs land in the configured project.
autolog(init={"project": PROJECT_NAME})
class SDXLLoraInference:
    """Generate product photographs with SDXL base weights plus a LoRA adapter.

    The diffusion pipeline is loaded once in ``__init__`` and reused by every
    call to :meth:`run_inference`.

    Args:
        num_inference_steps: Number of denoising steps per generation.
        guidance_scale: Guidance scale passed to the pipeline on every call.
        model_path: Hugging Face repo id (or local path) of the LoRA weights.
            Defaults to the product-caption fine-tune this class always used.
        base_model: Repo id (or local path) of the base diffusion model.
            Defaults to SDXL base 1.0.
        device: Device the pipeline is moved to. Defaults to ``"cuda"``.
    """

    DEFAULT_LORA_PATH = "VikramSingh178/sdxl-lora-finetune-product-caption"
    DEFAULT_BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"

    def __init__(
        self,
        num_inference_steps: int,
        guidance_scale: float,
        *,
        model_path: str = DEFAULT_LORA_PATH,
        base_model: str = DEFAULT_BASE_MODEL,
        device: str = "cuda",
    ) -> None:
        self.model_path = model_path
        self.base_model = base_model
        self.device = device
        # Weights are loaded in half precision, as before.
        # NOTE(review): float16 may be slow or unsupported on CPU; confirm
        # before passing device="cpu".
        self.pipe = DiffusionPipeline.from_pretrained(
            base_model, torch_dtype=torch.float16
        )
        self.pipe.to(device)
        # Apply the fine-tuned LoRA adapter on top of the base weights.
        self.pipe.load_lora_weights(model_path)
        self.num_inference_steps = num_inference_steps
        self.guidance_scale = guidance_scale

    def run_inference(self, prompt):
        """Generate product photograph(s) for ``prompt``.

        Args:
            prompt: Text prompt, passed to the pipeline as its first
                positional argument.

        Returns:
            The ``images`` attribute of the pipeline output. Presumably this
            is a list of PIL images (the diffusers default output type);
            confirm against the diffusers docs.
        """
        output = self.pipe(
            prompt,
            num_inference_steps=self.num_inference_steps,
            guidance_scale=self.guidance_scale,
        )
        return output.images
if __name__ == "__main__":
    # Demo run. It is guarded so that importing this module (e.g. from an API
    # server) does not load SDXL onto the GPU and run a 100-step generation
    # as an import side effect.
    inference = SDXLLoraInference(num_inference_steps=100, guidance_scale=2.5)
    # The images are not saved here. They are presumably captured by the wandb
    # diffusers autolog hook enabled at module level; verify in the W&B run.
    images = inference.run_inference(
        prompt="A stunning 4k Shot of a Balenciaga X Anime Hoodie with a person wearing it in a party"
    )