henryu commited on
Commit
458b51c
·
1 Parent(s): 1f04f88

Create clip_superior

Browse files
Files changed (1) hide show
  1. clip_superior +99 -0
clip_superior ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ import torch
4
+ from clip_interrogator import Config, Interrogator, list_caption_models, list_clip_models
5
+
6
+ try:
7
+ import gradio as gr
8
+ except ImportError:
9
+ print("Gradio is not installed, please install it with 'pip install gradio'")
10
+ exit(1)
11
+
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument("--lowvram", action='store_true', help="Optimize settings for low VRAM")
14
+ parser.add_argument('-s', '--share', action='store_true', help='Create a public link')
15
+ args = parser.parse_args()
16
+
17
+ if not torch.cuda.is_available():
18
+ print("CUDA is not available, using CPU. Warning: this will be very slow!")
19
+
20
+ config = Config(cache_path="cache")
21
+ if args.lowvram:
22
+ config.apply_low_vram_defaults()
23
+ ci = Interrogator(config)
24
+
25
+ def image_analysis(image, clip_model_name):
26
+ if clip_model_name != ci.config.clip_model_name:
27
+ ci.config.clip_model_name = clip_model_name
28
+ ci.load_clip_model()
29
+
30
+ image = image.convert('RGB')
31
+ image_features = ci.image_to_features(image)
32
+
33
+ top_mediums = ci.mediums.rank(image_features, 5)
34
+ top_artists = ci.artists.rank(image_features, 5)
35
+ top_movements = ci.movements.rank(image_features, 5)
36
+ top_trendings = ci.trendings.rank(image_features, 5)
37
+ top_flavors = ci.flavors.rank(image_features, 5)
38
+
39
+ medium_ranks = {medium: sim for medium, sim in zip(top_mediums, ci.similarities(image_features, top_mediums))}
40
+ artist_ranks = {artist: sim for artist, sim in zip(top_artists, ci.similarities(image_features, top_artists))}
41
+ movement_ranks = {movement: sim for movement, sim in zip(top_movements, ci.similarities(image_features, top_movements))}
42
+ trending_ranks = {trending: sim for trending, sim in zip(top_trendings, ci.similarities(image_features, top_trendings))}
43
+ flavor_ranks = {flavor: sim for flavor, sim in zip(top_flavors, ci.similarities(image_features, top_flavors))}
44
+
45
+ return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks
46
+
47
+ def image_to_prompt(image, mode, clip_model_name, blip_model_name):
48
+ if blip_model_name != ci.config.caption_model_name:
49
+ ci.config.caption_model_name = blip_model_name
50
+ ci.load_caption_model()
51
+
52
+ if clip_model_name != ci.config.clip_model_name:
53
+ ci.config.clip_model_name = clip_model_name
54
+ ci.load_clip_model()
55
+
56
+ image = image.convert('RGB')
57
+ if mode == 'best':
58
+ return ci.interrogate(image)
59
+ elif mode == 'classic':
60
+ return ci.interrogate_classic(image)
61
+ elif mode == 'fast':
62
+ return ci.interrogate_fast(image)
63
+ elif mode == 'negative':
64
+ return ci.interrogate_negative(image)
65
+
66
+ def prompt_tab():
67
+ with gr.Column():
68
+ with gr.Row():
69
+ image = gr.Image(type='pil', label="Image")
70
+ with gr.Column():
71
+ mode = gr.Radio(['best', 'fast', 'classic', 'negative'], label='Mode', value='best')
72
+ clip_model = gr.Dropdown(list_clip_models(), value=ci.config.clip_model_name, label='CLIP Model')
73
+ blip_model = gr.Dropdown(list_caption_models(), value=ci.config.caption_model_name, label='Caption Model')
74
+ prompt = gr.Textbox(label="Prompt")
75
+ button = gr.Button("Generate prompt")
76
+ button.click(image_to_prompt, inputs=[image, mode, clip_model, blip_model], outputs=prompt)
77
+
78
+ def analyze_tab():
79
+ with gr.Column():
80
+ with gr.Row():
81
+ image = gr.Image(type='pil', label="Image")
82
+ model = gr.Dropdown(list_clip_models(), value='ViT-L-14/openai', label='CLIP Model')
83
+ with gr.Row():
84
+ medium = gr.Label(label="Medium", num_top_classes=5)
85
+ artist = gr.Label(label="Artist", num_top_classes=5)
86
+ movement = gr.Label(label="Movement", num_top_classes=5)
87
+ trending = gr.Label(label="Trending", num_top_classes=5)
88
+ flavor = gr.Label(label="Flavor", num_top_classes=5)
89
+ button = gr.Button("Analyze")
90
+ button.click(image_analysis, inputs=[image, model], outputs=[medium, artist, movement, trending, flavor])
91
+
92
+ with gr.Blocks() as ui:
93
+ gr.Markdown("# <center>🕵️‍♂️ CLIP Interrogator 🕵️‍♂️</center>")
94
+ with gr.Tab("Prompt"):
95
+ prompt_tab()
96
+ with gr.Tab("Analyze"):
97
+ analyze_tab()
98
+
99
+ ui.launch(show_api=False, debug=True, share=args.share)