Upload 2 files
- app.py +98 -0
- caption_index.parquet +3 -0
app.py
ADDED
@@ -0,0 +1,98 @@
import gradio as gr
import numpy as np
import pandas as pd
from concurrent.futures import ThreadPoolExecutor

# Kaomoji tags that must keep their underscores when captions are prettified.
kaomojis = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    "<o>_<o>",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]

index_file = './caption_index.parquet'

# The index is keyed by the full caption string and carries 'name' and 'score' columns.
df = pd.read_parquet(index_file)


def process_input(user_input):
    # Normalize the query: spaces inside a tag become underscores to match the index.
    user_tags = set(tag.replace(' ', '_') for tag in user_input.split(', '))

    def match_tags(caption, tags):
        caption_set = set(caption.split(', '))
        return tags.issubset(caption_set)

    def process_chunk(chunk):
        chunk = chunk.copy()
        chunk['match'] = chunk.index.to_series().apply(lambda x: match_tags(x, user_tags))
        return chunk[chunk['match']]

    # Scan the index in chunks across a small thread pool.
    chunk_size = 100000
    chunks = [df.iloc[i:i + chunk_size] for i in range(0, df.shape[0], chunk_size)]

    with ThreadPoolExecutor(max_workers=8) as executor:
        results = executor.map(process_chunk, chunks)

    filtered_df = pd.concat(results)

    def calculate_weight(score):
        # Posts scored 5 or lower (or with a non-numeric score) get zero weight.
        try:
            return max(float(score) - 5, 0)
        except ValueError:
            return 0

    filtered_df['weight'] = filtered_df['score'].apply(calculate_weight)

    random_seed = np.random.randint(0, 1000000)
    np.random.seed(random_seed)

    sample_size = min(5, len(filtered_df))

    if sample_size > 0:
        weights = filtered_df['weight'].to_numpy(dtype=float)
        if weights.sum() > 0 and np.count_nonzero(weights) >= sample_size:
            weights /= weights.sum()
        else:
            # Fall back to uniform sampling when too few rows carry positive weight,
            # which would otherwise make np.random.choice fail.
            weights = np.full(len(filtered_df), 1 / len(filtered_df))
        sampled_indices = np.random.choice(filtered_df.index, size=sample_size, p=weights, replace=False)
        sampled_df = filtered_df.loc[sampled_indices]
    else:
        sampled_df = filtered_df

    output = []
    for index, row in sampled_df.iterrows():
        tags = index.split(', ')
        # Prettify tags: underscores become spaces except for kaomoji, and parentheses are escaped.
        processed_tags = [tag.replace('_', ' ') if tag not in kaomojis else tag for tag in tags]
        processed_tags = [tag.replace("(", "\\(").replace(")", "\\)") for tag in processed_tags]
        processed_caption = ', '.join(processed_tags)
        # 'name' values like 'danbooru_1234567' become links to the source post.
        post_url = row['name'].replace('danbooru_', 'https://danbooru.donmai.us/posts/')
        output.append(f"<a href='{post_url}' target='_blank'>{post_url}</a>: {processed_caption}<br>")

    return ''.join(output), len(filtered_df)


iface = gr.Interface(
    fn=process_input,
    inputs=gr.Textbox(label="Input tags separated by ', '"),
    outputs=[
        gr.HTML(),
        gr.Number(label="Matched Images Count"),
    ],
    title="Prompt Sampling",
    flagging_mode='never',
)
iface.launch()
caption_index.parquet
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f6ce8a0c716655604d747131a76984a35dc6f15487e038242d640367a8df66db
size 86522444
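
The index itself is committed only as a Git LFS pointer, so its contents are not visible in this diff. For running the Space locally without the real file, a small stand-in parquet can be built from the shape app.py reads: the DataFrame index holds the comma-separated caption string, and the columns hold the post identifier ('name') and its score ('score'). The column names come from the code above; the example rows below are made up, not data from the real index.

import pandas as pd

# Minimal stand-in for caption_index.parquet matching the schema app.py expects.
# Index = comma-separated caption; columns = 'name' (post id) and 'score'.
# These rows are hypothetical placeholders for local testing only.
captions = [
    "1girl, solo, smile, ^_^",
    "1girl, outdoors, smile",
]
test_df = pd.DataFrame(
    {
        "name": ["danbooru_1000001", "danbooru_1000002"],
        "score": [12, 7],
    },
    index=captions,
)
test_df.to_parquet("./caption_index.parquet")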