"""Command-line entry point for image-to-art search.

Subcommands:
    train      fine-tune the ViT model
    interface  launch the Gradio image-to-art search UI
    gallery    create gallery embeddings from a path
"""

import argparse
from typing import Optional, Sequence

import gradio as gr

from img2art_search.models.compute_embeddings import create_gallery_embeddings
from img2art_search.models.predict import predict
from img2art_search.models.train import fine_tune_vit


def make_interface() -> None:
    """Build and launch the Gradio image-to-art search interface.

    The uploaded image is passed to ``predict`` as a PIL image, and whatever
    ``predict`` returns is shown in a gallery labelled "Most similar images".
    """
    interface = gr.Interface(
        fn=predict,
        inputs=gr.Image(type="pil"),
        outputs=gr.Gallery(label="Most similar images", height=256 * 3),
        live=True,  # re-run predict automatically when the input changes
    )
    # share=True requests a public Gradio link in addition to the local server.
    interface.launch(share=True)


def train(epochs: int, batch_size: int) -> None:
    """Fine-tune the ViT model for ``epochs`` epochs with the given batch size."""
    fine_tune_vit(epochs, batch_size)


def create_gallery(gallery_path: str) -> None:
    """Create gallery embeddings from the images found under ``gallery_path``."""
    create_gallery_embeddings(gallery_path)


def _positive_int(value: str) -> int:
    """argparse ``type=`` callable accepting only integers strictly greater than 0."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number <= 0:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def _build_parser() -> argparse.ArgumentParser:
    """Construct the top-level parser with the train/interface/gallery subcommands."""
    parser = argparse.ArgumentParser(
        description="Train or infer the ViT model for image-to-art search."
    )
    subparsers = parser.add_subparsers(dest="command")

    # Subparser for training
    train_parser = subparsers.add_parser("train", help="Fine-tune the ViT model")
    train_parser.add_argument(
        "--epochs", type=_positive_int, default=50, help="Number of training epochs"
    )
    train_parser.add_argument(
        "--batch_size", type=_positive_int, default=32, help="Batch size for training"
    )

    # Subparser for inference (takes no extra arguments)
    subparsers.add_parser(
        "interface", help="Perform image-to-art search using the fine-tuned model"
    )

    # Subparser for building the search gallery
    gallery_parser = subparsers.add_parser(
        "gallery", help="Create new gallery from a path"
    )
    gallery_parser.add_argument(
        "--gallery_path",
        type=str,
        default="data/wikiart",
        help="Path to the images used to build the gallery (default: data/wikiart)",
    )
    return parser


def main(argv: Optional[Sequence[str]] = None) -> int:
    """Parse command-line arguments and dispatch to the selected subcommand.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        0 when a subcommand ran, 1 when no subcommand was given (help is
        printed in that case so the caller sees the available commands).
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    if args.command == "train":
        train(args.epochs, args.batch_size)
    elif args.command == "interface":
        make_interface()
    elif args.command == "gallery":
        create_gallery(args.gallery_path)
    else:
        # No subcommand: this is a usage error, so report a non-zero status
        # instead of silently succeeding.
        parser.print_help()
        return 1
    return 0


if __name__ == "__main__":
    raise SystemExit(main())