# oh-my-dear-ai's picture
# refactor(app): update task handling and Base64 conversion functions
# ad54bec
# raw
# history blame
# 11.4 kB
import base64
import datetime
import gradio as gr
import pandas as pd
import pytz
import plotly.graph_objects as go
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO
# import logging
# import plotly.io as pio
# pio.renderers.default = "browser"
# logger = logging.getLogger(__name__)
# logger.setLevel(logging.INFO)
# Checklist source table, read once at import time. The rest of this file
# reads its "plant" and "name" columns plus the columns named in TASKS.
df = pd.read_csv("herbologist_almanac_checklist_data.csv")

# Checkable columns per plant: four task columns followed by six color
# columns. Every plant contributes len(TASKS) bits to the recovery token.
TASKS = [f"task{i}" for i in range(1, 5)] + [f"color{i}" for i in range(1, 7)]

# Distinct plant identifiers, in order of first appearance in the CSV.
PLANTS = [*df["plant"].unique()]
def bin_ls2base64(ls):
    """Pack an iterable of 0/1 flags into a Base64 string.

    The flags are concatenated into a binary numeral (first flag is the most
    significant bit), stored big-endian in ``ceil(n / 8)`` bytes and then
    Base64-encoded. Leading zero flags are only preserved up to the byte
    boundary, so a decoder has to re-pad to the expected number of flags.

    Args:
        ls: Iterable (a list or a one-shot iterator such as ``reversed(...)``)
            whose items stringify to ``"0"`` or ``"1"``.

    Returns:
        The Base64 text, or ``""`` when ``ls`` yields no flags.

    Raises:
        ValueError: If an item does not stringify to a binary digit.
    """
    binary_str = "".join(str(n) for n in ls)
    if not binary_str:
        # int("", 2) raises ValueError; zero flags simply pack into zero bytes.
        return ""
    byte_count = (len(binary_str) + 7) // 8
    byte_str = int(binary_str, 2).to_bytes(byte_count, byteorder="big")
    return base64.b64encode(byte_str).decode("utf-8")
def base64_to_binary(base64_str):
    """Decode a Base64 string into a string of binary digits.

    The decoded bytes are read as one big-endian unsigned integer, rendered
    in base 2 and left-padded with zeros to ``len(TASKS) * len(PLANTS)``
    digits. Longer values are not truncated.

    Args:
        base64_str: Base64 text to decode.

    Returns:
        A ``str`` made of ``"0"`` and ``"1"`` characters.

    Raises:
        TypeError: If ``base64_str`` is not a ``str``.
    """
    if not isinstance(base64_str, str):
        raise TypeError(
            "Invalid input type. Expected str, got {}".format(type(base64_str))
        )

    raw_bytes = base64.b64decode(base64_str)
    as_integer = int.from_bytes(raw_bytes, byteorder="big")
    digits = format(as_integer, "b")
    expected_width = len(TASKS) * len(PLANTS)
    return digits.zfill(expected_width)
def parse_token(token):
    """Decode a recovery token into per-plant completion flags.

    Args:
        token: Base64 recovery token. A falsy value (empty string, ``None``)
            is treated as "nothing checked".

    Returns:
        A dict mapping each ``df["plant"]`` value to a list of ``len(TASKS)``
        integers (0 or 1), one per column in ``TASKS``.

    Raises:
        TypeError: Propagated from ``base64_to_binary`` for a truthy non-str
            token.
    """
    if not token:
        # "\x00" is outside the Base64 alphabet, so b64decode discards it and
        # yields no bytes; the decoded bit string is then all zeros.
        token = "\x00"

    almanac_data = [int(n) for n in base64_to_binary(token)]
    tasks_per_plant = len(TASKS)

    parsed_dict = {}
    for _, row in df.iterrows():
        # The token is produced from the reversed flag list, so popping from
        # the end hands the flags back in their original order.
        parsed_dict[row["plant"]] = [
            almanac_data.pop() for _ in range(tasks_per_plant)
        ]
    return parsed_dict
def process_data(*args):
    """Build the checklist image and the recovery token from the form inputs.

    Args:
        *args: One selection per plant, in ``PLANTS`` order, followed by the
            plot library name as the last positional argument. Each selection
            is an iterable of integer indices into ``TASKS`` (the checkbox
            groups are created with ``type="index"``).

    Returns:
        A ``(image, token)`` tuple: the rendered table image and the Base64
        recovery token encoding which tasks are ticked.
    """
    plot_library = args[-1]
    almanac_dict = dict(zip(PLANTS, args[:-1]))
    almanac_df = df.filter(items=["plant", "name"] + TASKS)

    almanac_bin_ls = []
    for pl in PLANTS:
        plant_tasks_done = [0] * len(TASKS)
        for i in almanac_dict[pl]:
            # Index by the selected task index itself, not by its position in
            # the selection list, so the token bit matches the column that is
            # ticked in the table below.
            plant_tasks_done[i] = 1
            almanac_df.loc[almanac_df["plant"] == pl, TASKS[i]] = "✔"
        almanac_bin_ls += plant_tasks_done

    # The flag list is reversed before encoding; parse_token pops flags from
    # the end of the decoded list to undo this.
    almanac_reverse_64 = bin_ls2base64(reversed(almanac_bin_ls))

    # The first CSV column is not shown in the rendered table.
    table_df = almanac_df.drop(columns=df.columns[0])
    if plot_library == "plotly":
        image = generate_img_by_plotly(table_df)
    else:
        image = generate_img_by_matplotlib(table_df)
    return image, almanac_reverse_64
def show_checkbox_groups(token):
    """Create one checkbox group per row of ``df``, pre-ticked from a token.

    Args:
        token: Recovery token passed to ``parse_token``; a falsy token yields
            groups with nothing ticked.

    Returns:
        A list of ``gr.CheckboxGroup`` components, one per row of ``df`` in
        row order. Each group reports its selection as indices
        (``type="index"``).
    """
    parsed_dict = parse_token(token)
    # Compute the null mask once instead of rebuilding it for every cell.
    has_value = df.notnull()

    checklist_inputs = []
    for index, row in df.iterrows():
        # Choices are the non-empty task cells, with the first letter
        # capitalised.
        tasks = [
            row[col][0].upper() + row[col][1:]
            for col in TASKS
            if has_value.at[index, col]
        ]
        # NOTE(review): tasks[i] below uses a TASKS index on the filtered
        # list, which presumably relies on empty cells only trailing the
        # filled ones in each row -- confirm against the CSV.
        selected = [
            tasks[i]
            for i, v in enumerate(parsed_dict[row["plant"]])
            if v == 1 and has_value.at[index, TASKS[i]]
        ]
        with gr.Row():
            checkbox = gr.CheckboxGroup(
                tasks,
                label=f"{row['name']}",
                value=selected,
                type="index",
            )
        checklist_inputs.append(checkbox)
    return checklist_inputs
"""使用matplotlib生成图片"""
def wrap_text(text, max_width=20):
    """Wrap ``text`` on spaces so lines stay within ``max_width`` characters.

    Words are never split: a word longer than ``max_width`` is placed on a
    line of its own. The width check counts the separator that follows each
    word, so a line holds at most ``max_width - 1`` visible characters unless
    it is a single over-long word.

    Args:
        text: The string to wrap.
        max_width: Character budget per line (including one trailing space).

    Returns:
        The wrapped text with lines joined by ``"\\n"``.
    """
    wrapped_lines = []
    line = ""
    for word in text.split(" "):
        if len(line) + len(word) + 1 <= max_width:
            line += word + " "
            continue
        # Only flush a line that holds something; otherwise a first word that
        # does not fit would produce a spurious leading blank line.
        if line:
            wrapped_lines.append(line.strip())
        line = word + " "
    wrapped_lines.append(line.strip())
    return "\n".join(wrapped_lines)
def generate_img_by_matplotlib(df):
    """Render ``df`` as a table with matplotlib and return it as a PIL image.

    Args:
        df: DataFrame whose values and column labels fill the table. The
            relative column-width list below has 11 entries.

    Returns:
        A ``PIL.Image`` opened from the in-memory PNG of the figure.
    """
    fig, ax = plt.subplots(figsize=(12, 16))
    ax.axis("off")  # only the table should be visible

    column_count = len(df.columns)
    # Header gets a light grey background; widths are relative shares.
    table = ax.table(
        cellText=df.values,
        colLabels=df.columns,
        loc="center",
        colColours=["#f2f2f2"] * column_count,
        colWidths=[share / column_count for share in [1] + [2] * 4 + [1] * 6],
    )
    table.auto_set_font_size(False)
    table.set_fontsize(10)

    for (row_idx, col_idx), cell in table.get_celld().items():
        # Missing values are rendered as the text "nan"; blank them out.
        if cell.get_text().get_text() == "nan":
            cell.set_text_props(text="", ha="center")
        # Row 0 is the header row: bold and centred.
        if row_idx == 0:
            cell.set_text_props(weight="bold", ha="center")

        current_text = cell.get_text().get_text()
        if col_idx == 0:
            # First column is narrow, so wrap more aggressively.
            cell.set_text_props(
                text=wrap_text(current_text, 9), ha="center", linespacing=1
            )
        else:
            cell.set_text_props(
                text=wrap_text(current_text),
                ha="center",
                fontsize=10,
                linespacing=1,
            )

    # Stretch rows vertically to make room for wrapped text.
    table.scale(1, 4)
    plt.tight_layout(pad=0.5)

    # Serialise the figure to PNG in memory and reopen it as a PIL image.
    png_buffer = BytesIO()
    fig.savefig(png_buffer, format="png", pad_inches=0.2)
    plt.close(fig)  # release the figure's memory
    png_buffer.seek(0)
    return Image.open(png_buffer)
"""使用plotly生成图片"""
def color_mapping(color):
    """Return the fill color for a color name.

    Args:
        color: A color name such as ``"red"`` or ``"turquoise"``.

    Returns:
        The hex code for a known name, otherwise the string ``"white"``.
    """
    palette = dict(
        turquoise="#28b6aa",
        chartreuse="#dcde82",
        red="#981c05",
        yellow="#f4ca3a",
        pink="#fd8d9b",
        blue="#4194bd",
        white="#ffffff",
        black="#b7b7b7",
        orange="#d97413",
        purple="#8a659a",
        viridian="#a1c42a",
    )
    if color in palette:
        return palette[color]
    # Anything unrecognised (empty cells, tick marks, ...) gets a white fill.
    return "white"
def styled_header(header):
    """Build the ``header`` settings dict for a plotly table.

    Args:
        header: Iterable of column titles (strings).

    Returns:
        A dict with one bold, upper-cased single-item list per title under
        ``values`` plus fixed line, fill, alignment and font settings.
    """
    titles = []
    for attr in header:
        titles.append([f"<b>{attr.upper()}</b>"])
    return {
        "values": titles,
        "line_color": "darkslategray",
        "fill_color": "royalblue",
        "align": ["center"],
        "font": {"color": "white", "size": 10},
    }
def styled_cells(cells):
    """Build the ``cells`` settings dict for a plotly table.

    Args:
        cells: List of columns, each a list of cell values.

    Returns:
        A dict holding the values and fixed styling. The first five columns
        get a flat fill (lavender, then white); every later column is filled
        per cell via ``color_mapping`` applied to the cell's string form.
    """
    flat_fills = ["lavender"] + ["white"] * 4
    per_cell_fills = []
    for column in cells[5:]:
        per_cell_fills.append([color_mapping(str(value)) for value in column])
    return {
        "values": cells,
        "line_color": "darkslategray",
        "fill_color": flat_fills + per_cell_fills,
        "align": "center",
        "font_size": 10,
        "height": 30,
    }
def handle_element(el):
    """Decorate a table cell value with an emoji where one applies.

    Args:
        el: A cell value of any type.

    Returns:
        ``"✅"`` for the tick mark ``"✔"``; for other strings, the string with
        the emoji of the first matching keyword appended (case-insensitive
        substring match, in the order listed below); anything else unchanged.
    """
    if not isinstance(el, str):
        return el
    if el == "✔":
        return "✅"

    emoji_mapping = {
        "harvest": "🌱",
        "light": "🌞",
        "moisture": "💧",
        "mood": "💗",
        "sell": "💲",
        "collect": "🖐️",
        "hygiene": "🧽",
        "pest": "🐛",
        "overgrowth": "🌿",
        "show": "👩🏻‍🌾",
    }
    lowered = el.lower()
    for keyword, emoji in emoji_mapping.items():
        if keyword in lowered:
            return el + emoji
    return el
def generate_img_by_plotly(df):
    """Render ``df`` as a table with plotly and return it as a PIL image.

    Args:
        df: DataFrame to render. It is not modified; missing values are
            blanked on a copy. The column-width list below has 11 entries.

    Returns:
        A ``PIL.Image`` opened from the in-memory PNG of the figure.
    """
    # Work on a filled copy so the caller's DataFrame is left untouched.
    df = df.fillna("")

    header_values = list(df.columns)
    # Column-major cell values, decorated with emoji where applicable.
    cells_values = [
        [handle_element(el) for el in df[col].to_list()] for col in df.columns
    ]

    fig = go.Figure(
        data=[
            go.Table(
                header=styled_header(header_values),
                cells=styled_cells(cells_values),
                columnwidth=[80, 160, 160, 160, 160, 80, 80, 80, 80, 80, 80],
            )
        ]
    )
    # Footer stamp with the generation time in UTC, centred at the bottom.
    fig.add_annotation(
        text=f'Herbology Almanac Checklist Generated on(GMT): {datetime.datetime.now(tz=pytz.utc).strftime("%Y-%m-%d %H:%M:%S")}',
        xref="paper",
        yref="paper",
        x=0.5,
        y=0,
        showarrow=False,
    )
    fig.update_layout(
        width=1200,
        height=1200,
        margin=dict(l=20, r=20, t=20, b=20),
    )

    # Serialise the figure to PNG in memory and reopen it as a PIL image.
    buffer = BytesIO()
    fig.write_image(buffer, format="png")
    buffer.seek(0)
    return Image.open(buffer)
# Assemble the Gradio UI. Components are laid out in creation order; the event
# listeners at the end wire the form to process_data and show_checkbox_groups.
with gr.Blocks() as app:
    gr.Markdown(
        """
<center><font size=8>👩🏻‍🌾Herbology Almanac Checklist Generator📝</font></center>
This is a simple web app that generates an almanac checklist for your plants.
"""
    )
    gr.Markdown(
        """
# RECOVERY TOKEN
"""
    )
    # Holds the token: written by the submit handler and, when edited,
    # triggers a refresh of the checkbox groups.
    recovery_token = gr.Textbox(
        value="",
        label="Recovery Token",
        info="Save this token or paste your saved one here",
        placeholder="Keep this token to restore your previous input".upper(),
        interactive=True,
    )
    gr.Markdown(
        """
# YOUR RESEARCH TASKS
"""
    )
    # Initial render of the checkbox groups from the textbox's initial value
    # (an empty string, so nothing is ticked).
    checklist_inputs = show_checkbox_groups(recovery_token.value)
    gr.Markdown(
        """
# IMAGE STYLE
"""
    )
    plot_library = gr.Radio(
        ["plotly", "matplotlib"],
        label="Plot Library",
        value="plotly",
        info="Choose your plot library",
    )
    submit_button = gr.Button("Generate Image and Token")
    # df_out = gr.Dataframe(label="Output Dataframe", interactive=False)
    generated_img = gr.Image(label="Generated Image", format="png", type="pil")
    logs = gr.Markdown(
        """
# CHANGELOG
- 2024/08/31: Initial release
- 2024/09/03: Fix a mistake in the tasks of mimbulus
- 2024/09/04: Correct Radiant count for water hyacinth
- 2024/09/05: Support image generated by plotly
- 2024/09/13: Update Water Lily tasks and color options
"""
    )
    # process_data receives every checkbox group's selection followed by the
    # plot library choice, and returns (image, token).
    submit_button.click(
        process_data,
        inputs=checklist_inputs + [plot_library],
        outputs=[generated_img, recovery_token],
    )
    # Whenever the token text changes, rebuild the checkbox groups from it.
    recovery_token.change(
        show_checkbox_groups,
        inputs=[recovery_token],
        outputs=checklist_inputs,
    )
    # generate_button.click(
    # generate_img,
    # inputs=[df_out],
    # outputs=[generated_img],
    # )
# Enable request queuing, then start the server.
app.queue()
app.launch()