# Last commit (aafaa22) by Dimitre: "Enabling image hints only"
import copy
import logging
import os
from typing import Any

import pandas as pd
import streamlit as st
from countryinfo import CountryInfo
from dotenv import load_dotenv

from common import HintType, configs, get_distance
from hint import AudioHint, ImageHint, TextHint
def setup_models(_cache: Any, configs: dict) -> None:
    """Set up the model for every selected hint type that has none yet.

    Walks ``_cache["hint_types"]`` and, for each type whose slot in
    ``_cache["model"]`` is still ``None``, builds the matching model with the
    corresponding ``setup_*_hint`` function. Slots that already hold a model
    are left untouched, as are hint types without a known builder.

    Args:
        _cache (st.session_state): Streamlit cache object
        configs (dict): Configurations used by the models
    """
    builders = {
        HintType.TEXT.value: setup_text_hint,
        HintType.IMAGE.value: setup_image_hint,
        HintType.AUDIO.value: setup_audio_hint,
    }
    models = _cache["model"]
    for hint_type in _cache["hint_types"]:
        if models[hint_type] is not None:
            continue  # already set up earlier in this session
        builder = builders.get(hint_type)
        if builder is not None:
            models[hint_type] = builder(configs)
@st.cache_resource()
def setup_text_hint(configs: dict) -> TextHint:
    """Set up and initialize the text hint model.

    Decorated with ``st.cache_resource``, so the initialized model is reused
    by later calls made with equal ``configs``.

    Args:
        configs (dict): Configurations used by the model. The text section is
            read from ``configs["local"]`` under the lower-cased text hint
            type. ``configs`` itself is not modified.

    Returns:
        TextHint: Hint model

    Raises:
        KeyError: If the ``HF_ACCESS_TOKEN`` environment variable is not set.
    """
    with st.spinner("Loading text model..."):
        # Work on a shallow copy. Writing the token into the shared section
        # would leak the secret into the global ``configs``. It would also
        # change the argument that st.cache_resource hashes, causing a
        # needless model reload on a later call.
        model_configs = copy.copy(configs["local"][HintType.TEXT.value.lower()])
        model_configs["hf_access_token"] = os.environ["HF_ACCESS_TOKEN"]
        text_hint = TextHint(configs=model_configs)
        text_hint.initialize()
    return text_hint
@st.cache_resource()
def setup_image_hint(configs: dict) -> ImageHint:
    """Create and initialize the image hint model.

    Decorated with ``st.cache_resource``, so the initialized model is reused
    by later calls made with equal ``configs``.

    Args:
        configs (dict): Application configuration; the model reads its own
            section from ``configs["local"]`` keyed by the lower-cased image
            hint type.

    Returns:
        ImageHint: The initialized hint model.
    """
    section_name = HintType.IMAGE.value.lower()
    with st.spinner("Loading image model..."):
        image_model = ImageHint(configs=configs["local"][section_name])
        image_model.initialize()
    return image_model
@st.cache_resource()
def setup_audio_hint(configs: dict) -> AudioHint:
    """Create and initialize the audio hint model.

    Decorated with ``st.cache_resource``, so the initialized model is reused
    by later calls made with equal ``configs``.

    Args:
        configs (dict): Application configuration; the model reads its own
            section from ``configs["local"]`` keyed by the lower-cased audio
            hint type.

    Returns:
        AudioHint: The initialized hint model.
    """
    section_name = HintType.AUDIO.value.lower()
    with st.spinner("Loading audio model..."):
        audio_model = AudioHint(configs=configs["local"][section_name])
        audio_model.initialize()
    return audio_model
@st.cache_resource()
def get_country_list() -> pd.DataFrame:
    """Builds a database of countries and their surface area.

    Returns:
        pd.DataFrame: One row per country with columns ``country`` (the key
        used by ``CountryInfo().all()``) and ``area``. Countries whose record
        has no ``area`` entry are left out.
    """
    log = logging.getLogger(__name__)
    rows = []
    # Read every record from a single CountryInfo instance. The previous
    # version built CountryInfo(country) per country inside a bare
    # ``except: pass``. That hid every error (even KeyboardInterrupt) and,
    # assuming each construction reloads the bundled dataset (NOTE(review):
    # confirm for the pinned countryinfo version), did quadratic work.
    for country, info in CountryInfo().all().items():
        if "area" not in info:
            log.debug("Skipping %r: no area data", country)
            continue
        rows.append((country, info["area"]))
    return pd.DataFrame(rows, columns=["country", "area"])
def pick_country(country_df: pd.DataFrame) -> str:
    """Draw one country at random, weighted by its ``area`` column.

    A country's chance of being drawn is proportional to its area, so larger
    countries come up more often.

    Args:
        country_df (pd.DataFrame): Country database with ``country`` and
            ``area`` columns.

    Returns:
        str: Name of the drawn country.
    """
    drawn_row = country_df.sample(n=1, weights="area")
    return drawn_row["country"].iat[0]
def reset_cache() -> None:
    """Reset the Streamlit APP cache.

    Draws a new secret country and restores every per-session game value:
    the guessable country list, an empty hint-type selection, one hint, the
    game marked as not started, and no model for any hint type.
    """
    countries = get_country_list()
    state = st.session_state
    state["country_list"] = countries["country"].tolist()
    state["country"] = pick_country(countries)
    state["hint_types"] = []
    state["n_hints"] = 1
    state["game_started"] = False
    state["model"] = {
        hint.value: None
        for hint in (HintType.TEXT, HintType.IMAGE, HintType.AUDIO)
    }
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configure the browser tab before any other Streamlit element is drawn.
st.set_page_config(page_icon="🌎", page_title="Gen AI GeoGuesser")

# An empty session state means this browser session has not been set up yet:
# load environment variables (e.g. HF_ACCESS_TOKEN) and start a fresh game.
if len(st.session_state) == 0:
    load_dotenv()
    reset_cache()
st.title("Generative AI GeoGuesser 🌎")
st.markdown("### Guess the country based on hints generated by AI")
st.markdown("(Only working with image hints for performance reasons)")

# Hint-type selection (left) and number of hints per type (right). Both
# widgets are seeded from, and written back to, the session state so that
# reset_cache() can restore their defaults.
col1, col2 = st.columns([2, 1])
with col1:
    st.session_state["hint_types"] = st.multiselect(
        "Choose which hint types you want",
        # Only image hints are offered, for performance reasons; pass
        # [x.value for x in HintType] here to re-enable every hint type.
        [HintType.IMAGE.value],
        default=st.session_state["hint_types"],
    )
with col2:
    st.session_state["n_hints"] = st.slider(
        "Number of hints",
        min_value=1,
        max_value=5,
        value=st.session_state["n_hints"],
    )
# Starting a game sets up the selected models (if needed) and generates
# n_hints hints of each selected type for the secret country.
start_btn = st.button("Start game")
if start_btn:
    if not st.session_state["hint_types"]:
        st.error("Pick at least one hint type")
        reset_cache()
    else:
        # Route through the configured logger instead of print() so the
        # message honours the logging setup (level, handlers).
        logger.info('Chosen country "%s"', st.session_state["country"])
        setup_models(st.session_state, configs)
        # NOTE(review): the model objects come from st.cache_resource, so they
        # are presumably shared by every session; if generate_hint stores its
        # hints on the model, concurrent players may overwrite each other's
        # hints. Confirm against hint.py.
        for hint_type in st.session_state["hint_types"]:
            with st.spinner(f"Generating {hint_type} hint..."):
                st.session_state["model"][hint_type].generate_hint(
                    st.session_state["country"],
                    st.session_state["n_hints"],
                )
        st.session_state["game_started"] = True
# While a game is running: a country picker plus "guess" and "reset" buttons.
if st.session_state["game_started"]:
    guess_col, guess_btn_col, reset_col = st.columns([2, 1, 1])
    with guess_col:
        guess = st.selectbox(
            "Country guess", [""] + st.session_state["country_list"]
        )
    with guess_btn_col:
        guess_btn = st.button("Make a guess")
    with reset_col:
        reset_btn = st.button("Reset game")

    if guess_btn:
        answer = st.session_state["country"]
        if guess == answer:
            st.success("Correct guess you won!")
            st.balloons()
        elif not guess:
            # The empty first option means nothing was picked yet.
            st.error("Pick a country.")
        else:
            # Report how far off the guess was, using the distance between
            # the two countries' lat/long points as computed by get_distance.
            answer_latlng = CountryInfo(answer).latlng()
            guess_latlng = CountryInfo(guess).latlng()
            distance = int(get_distance(answer_latlng, guess_latlng))
            st.error(
                f"\nWrong guess, you missed the correct country by {distance} KM."
                f"\nThe correct answer was {answer}.\n"
            )

    if reset_btn:
        reset_cache()
# One tab per selected hint type, listing the hints generated for it.
# The hint-type multiselect stays editable while a game is running, so the
# selection can be empty here. st.tabs rejects an empty list of labels,
# hence the extra guard.
if st.session_state["game_started"] and st.session_state["hint_types"]:
    selected_types = st.session_state["hint_types"]
    tabs = st.tabs([f"{x} hint" for x in selected_types])
    for hint_type, tab in zip(selected_types, tabs):
        with tab:
            model = st.session_state["model"][hint_type]
            # No model yet, e.g. for a type selected after the game started.
            if not model:
                continue
            for hint_idx, hint in enumerate(model.hints, start=1):
                st.markdown(f"#### Hint #{hint_idx}")
                if hint_type == HintType.TEXT.value:
                    st.write(hint["text"])
                elif hint_type == HintType.IMAGE.value:
                    st.image(hint["image"])
                elif hint_type == HintType.AUDIO.value:
                    st.audio(hint["audio"], sample_rate=hint["sample_rate"])