File size: 2,052 Bytes
bb0a0a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import torch
import gradio as gr

from inference import CityClassifierMultiModelPipeline, get_model_path

# Access token read from the HFS_TOKEN environment variable (None when unset);
# passed to get_model_path when resolving model/config files.
TOKEN  = os.environ.get("HFS_TOKEN")
# Hub repository passed to get_model_path when resolving model/config files.
HFREPO = "City96/AnimeClassifiers"
# Models to load; the UI shows one output Label per entry, in this order.
MODELS = [
	"CCAnime-ChromaticAberration-v1.16",
]

# Markdown passed to the Interface as its `article` text.
article = """\
These are classifiers meant to work with anime images.

For more information, you can check out the [Hugging Face Hub](https://huggingface.co/city96/AnimeClassifiers) or [GitHub page](https://github.com/city96/CityClassifiers).
"""

# Help texts passed as `info=` to the corresponding input controls.
info_default = """\
Include default class (unknown/negative) in output results.
"""
info_tiling = """\
Divide the image into parts and run classifier on each part separately.
Greatly improves accuracy but slows down inference.
"""
info_tiling_combine = """\
How to combine the confidence scores of the different tiles.
Mean averages confidence over all tiles. Median takes the value in the middle.
Max/min take the score from the tile with the highest/lowest confidence respectively, but can result in multiple labels having very high/very low confidence scores.
"""

# With CUDA available, run on the GPU with a float16 clip_dtype.
# Otherwise no extra keyword arguments are passed, so the pipeline's own defaults apply.
if torch.cuda.is_available():
	pipeline_args = {
		"device"     : "cuda",
		"clip_dtype" : torch.float16,
	}
else:
	pipeline_args = {}

# Resolve each model's file path and its config.json path via get_model_path.
# Then build one multi-model pipeline from them.
model_files  = [get_model_path(name, HFREPO, TOKEN) for name in MODELS]
config_files = [get_model_path(name, HFREPO, TOKEN, extension="config.json") for name in MODELS]
pipeline = CityClassifierMultiModelPipeline(
	model_paths  = model_files,
	config_paths = config_files,
	**pipeline_args,
)
# Input controls. Gradio passes their values to `pipeline` positionally, in
# this order: image, include-default flag, tiling flag, tile-combine strategy.
input_components = [
	gr.Image(label="Input image", type="pil"),
	gr.Checkbox(label="Include default", value=True, info=info_default),
	gr.Checkbox(label="Tiling", value=True, info=info_tiling),
	gr.Dropdown(
		label   = "Tiling combine strategy",
		choices = ["mean", "median", "max", "min"],
		value   = "mean",
		type    = "value",
		info    = info_tiling_combine,
	),
]
# One Label output per model, titled with the model name.
output_components = [gr.Label(label=name) for name in MODELS]
# Examples are optional: only used when an ./examples directory exists
# (relative to the current working directory).
examples_dir = "./examples" if os.path.isdir("./examples") else None

demo = gr.Interface(
	fn                = pipeline,
	title             = "CityClassifiers demo",
	article           = article,
	inputs            = input_components,
	outputs           = output_components,
	examples          = examples_dir,
	allow_flagging    = "never",
	analytics_enabled = False,
)
demo.launch()