"""Generate one image per prompt of a prompt-set file with a diffusers text-to-image pipeline.

Intended to be run as a command-line tool through ``fire``, e.g.::

    python <script>.py --model_id=runwayml/stable-diffusion-v1-5 --dtype=fp16
"""
import csv
import os

import torch

# Maps the ``dtype`` option of :func:`main` to the torch dtype the weights are loaded in.
_TORCH_DTYPES = {
    'fp32': torch.float32,
    'fp16': torch.float16,
    'bf16': torch.bfloat16,
}


def _read_vilg300(path):
    """Read the ViLG-300 prompt CSV into a dict keyed by prompt.

    The file must have the columns ``Prompt``, ``文本`` (text), ``类别``
    (category) and ``来源`` (source). It is opened with the ``utf-8-sig``
    codec so a leading byte-order mark, if present, is stripped rather than
    being glued onto the first header name.

    Returns:
        dict mapping each prompt to ``{'prompt', 'text', 'category', 'source'}``.
        Keys are in first-occurrence file order. A prompt that appears more
        than once keeps the metadata of its last occurrence.
    """
    rows = {}
    # newline='' lets the csv module handle quoted embedded newlines correctly.
    with open(path, 'r', encoding='utf-8-sig', newline='') as csv_file:
        for row in csv.DictReader(csv_file, delimiter=','):
            prompt = row['Prompt']
            rows[prompt] = {
                'prompt': prompt,
                'text': row['文本'],
                'category': row['类别'],
                'source': row['来源'],
            }
    return rows


def load_prompts(path):
    """Return the unique prompts stored in *path*, in file order.

    Only the ViLG-300 CSV, recognised by the file name ``ViLG-300.csv``, is
    currently supported.

    Raises:
        NotImplementedError: if *path* is not a supported prompt file.
    """
    if os.path.basename(path) == 'ViLG-300.csv':
        return list(_read_vilg300(path))
    raise NotImplementedError(f'unsupported prompt file: {path!r}')


def main(
    model_id="runwayml/stable-diffusion-v1-5",
    prompt_path="assets/ViLG-300.csv",
    save_path=None,
    dtype='fp16',
    variant=None,
    device='cuda',
):
    """Generate one image per prompt and save prompt ``i`` as ``<save_path>/{i}.jpg``.

    Args:
        model_id: model id or path passed to
            ``AutoPipelineForText2Image.from_pretrained``.
        prompt_path: prompt file understood by :func:`load_prompts`.
        save_path: output directory. Defaults to
            ``saved/<model_id with '/' replaced by '_'>``.
        dtype: weight precision, one of ``'fp32'``, ``'fp16'``, ``'bf16'``.
        variant: weight variant forwarded to ``from_pretrained``.
        device: device the pipeline is moved to.

    Raises:
        ValueError: if *dtype* is not one of the supported precisions.
        NotImplementedError: if *prompt_path* is not a supported prompt file.
    """
    # Imported lazily so the prompt-loading helpers work without diffusers installed.
    from diffusers import AutoPipelineForText2Image

    try:
        torch_dtype = _TORCH_DTYPES[dtype]
    except KeyError:
        raise ValueError(
            f'unsupported dtype {dtype!r}; expected one of {sorted(_TORCH_DTYPES)}'
        ) from None

    if save_path is None:
        save_path = os.path.join('saved', model_id.replace('/', '_'))
    # Validate/load prompts first so a bad prompt file does not leave an empty output dir.
    prompts = load_prompts(prompt_path)
    os.makedirs(save_path, exist_ok=True)

    pipeline = AutoPipelineForText2Image.from_pretrained(
        model_id, variant=variant, torch_dtype=torch_dtype
    )
    pipeline.to(device=device)
    # Disable the safety checker so generated images are saved without filtering.
    pipeline.safety_checker = None

    total = len(prompts)
    for i, prompt in enumerate(prompts):
        print(f'{i}|{total}: {prompt}')
        image = pipeline(prompt).images[0]
        image.save(os.path.join(save_path, f'{i}.jpg'))


if __name__ == '__main__':
    import fire

    fire.Fire(main)