File size: 3,122 Bytes
2efc095
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
'''
python colorflow_cli.py \
  --input_image ./input.jpg \
  --reference_images ./ref1.jpg ./ref2.jpg \
  --output_dir ./results \
  --input_style Sketch \
  --resolution 640x640 \
  --seed 123 \
  --num_inference_steps 20
'''

# colorflow_cli.py
from app_func import *
import argparse
import torch
from PIL import Image
import os
import logging

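# Required imports and function definitions from the original file (the core logic of the original file must be kept)
# ... [keep the original file's model loading, extract_line_image, colorize_image, and other functions] ...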

def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the ColorFlow colorization CLI.

    Args:
        argv: Optional sequence of argument strings to parse instead of the
            process command line. When ``None`` (the default), argparse
            falls back to ``sys.argv[1:]``, which is what the original
            zero-argument call did.

    Returns:
        The parsed namespace with the attributes ``input_image``,
        ``reference_images`` (a list with at least one entry),
        ``output_dir``, ``input_style``, ``resolution``, ``seed`` and
        ``num_inference_steps``.

    Raises:
        SystemExit: Raised by argparse when a required option is missing or
            a value is not among the allowed choices.
    """
    parser = argparse.ArgumentParser(description="ColorFlow命令行图像上色工具")
    parser.add_argument("--input_image", type=str, required=True, help="输入图像路径")
    parser.add_argument("--reference_images", type=str, nargs='+', required=True, help="参考图像路径列表")
    parser.add_argument("--output_dir", type=str, default="./output", help="输出目录")
    parser.add_argument("--input_style", type=str, default="GrayImage(ScreenStyle)",
                        choices=["GrayImage(ScreenStyle)", "Sketch"], help="输入样式类型")
    parser.add_argument("--resolution", type=str, default="640x640",
                        choices=["640x640", "512x800", "800x512"], help="分辨率设置")
    parser.add_argument("--seed", type=int, default=0, help="随机种子")
    parser.add_argument("--num_inference_steps", type=int, default=10, help="推理步数")
    # Passing the list explicitly lets callers and tests drive the parser
    # without mutating sys.argv.
    return parser.parse_args(argv)

def save_image(image: Image.Image, path: str, format: str = "PNG") -> None:
    """Save ``image`` to ``path`` and log whether it worked.

    Args:
        image: Image to write; its ``save`` method is called.
        path: Destination file path.
        format: Image format passed through to ``image.save``
            (default ``"PNG"``).

    Raises:
        Exception: Any error raised while saving is logged at ERROR level
            and then re-raised unchanged.
    """
    try:
        image.save(path, format=format)
        success_msg = f"成功保存图像至: {path}"
        logging.info(success_msg)
    except Exception as exc:
        failure_msg = f"保存图像失败: {str(exc)}"
        logging.error(failure_msg)
        # Bare raise keeps the original exception and traceback for the caller.
        raise

def main():
    """Run one colorization pass from the command line.

    Parses the CLI arguments, creates the output directory if needed, calls
    ``load_ckpt`` for the requested input style, preprocesses the input
    image with ``extract_line_image``, runs ``colorize_image`` against the
    reference image paths, and writes the four returned images as PNG
    files into ``args.output_dir``.

    Raises:
        Any exception raised while opening the input image, by the
        ``app_func`` helpers, or by ``save_image`` propagates to the
        caller; nothing is caught here.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # Initialise the model.
    # NOTE(review): ``global`` rebinds names in *this* module only. If
    # ``load_ckpt`` / ``colorize_image`` (star-imported from app_func) track
    # the current style in app_func's own globals, this reset does not reach
    # them — confirm against app_func.
    global cur_input_style, pipeline, MultiResNetModel
    cur_input_style = None
    load_ckpt(args.input_style)

    # Preprocess the input image. The context manager closes the underlying
    # file as soon as the RGB copy has been made, instead of leaving the
    # handle open until garbage collection.
    with Image.open(args.input_image) as source_img:
        input_img = source_img.convert("RGB")
    input_context, extracted_line, _ = extract_line_image(input_img, args.input_style, args.resolution)

    # Run colorization and collect every returned image.
    high_res_img, up_img, raw_output, preprocessed_bw = colorize_image(
        VAE_input=extracted_line,
        input_context=input_context,
        reference_images=args.reference_images,
        resolution=args.resolution,
        seed=args.seed,
        input_style=args.input_style,
        num_inference_steps=args.num_inference_steps
    )

    # Save all results.
    save_image(high_res_img, os.path.join(args.output_dir, "colorized_result.png"))
    save_image(up_img, os.path.join(args.output_dir, "upsampled_intermediate.png"))
    save_image(raw_output, os.path.join(args.output_dir, "raw_generated_output.png"))
    save_image(preprocessed_bw, os.path.join(args.output_dir, "preprocessed_bw.png"))

if __name__ == "__main__":
    # Configure root logging first so the INFO/ERROR messages emitted by
    # save_image are shown with a timestamp and level.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s",
        level=logging.INFO,
    )
    main()