Spaces:
Sleeping
Sleeping
import logging | |
import os | |
from typing import Any | |
import pandas as pd | |
import streamlit as st | |
from countryinfo import CountryInfo | |
from dotenv import load_dotenv | |
from common import HintType, configs, get_distance | |
from hint import AudioHint, ImageHint, TextHint | |
def setup_models(_cache: Any, configs: dict) -> None:
    """Load every selected hint model that has not been loaded yet.

    Args:
        _cache (st.session_state): Streamlit session state (any mapping with
            the same keys works). ``_cache["hint_types"]`` lists the selected
            hint type values and ``_cache["model"]`` maps each hint type value
            to its model, or ``None`` while it is not loaded.
        configs (dict): Configuration forwarded to the per-type setup
            functions.

    Hint type values that match no known ``HintType`` are ignored.
    """
    for hint_type in _cache["hint_types"]:
        models = _cache["model"]
        # Only load missing models; an existing one is reused as-is.
        if models[hint_type] is not None:
            continue
        if hint_type == HintType.TEXT.value:
            loader = setup_text_hint
        elif hint_type == HintType.IMAGE.value:
            loader = setup_image_hint
        elif hint_type == HintType.AUDIO.value:
            loader = setup_audio_hint
        else:
            continue
        models[hint_type] = loader(configs)
def setup_text_hint(configs: dict) -> TextHint:
    """Build and initialize the text hint model.

    The model receives its configuration section plus the Hugging Face access
    token read from the ``HF_ACCESS_TOKEN`` environment variable.

    Args:
        configs (dict): Application configuration; the text model's settings
            are read from ``configs["local"][HintType.TEXT.value.lower()]``.
            This mapping is not modified.

    Returns:
        TextHint: The initialized text hint model.

    Raises:
        KeyError: If ``HF_ACCESS_TOKEN`` is not set, or the text model's
            configuration section is missing.
    """
    with st.spinner("Loading text model..."):
        # Build a copy rather than assigning into the section: ``configs`` is
        # the shared application config, and the previous in-place write left
        # the secret token stored in it for the rest of the process.
        model_configs = {
            **configs["local"][HintType.TEXT.value.lower()],
            "hf_access_token": os.environ["HF_ACCESS_TOKEN"],
        }
        text_hint = TextHint(configs=model_configs)
        text_hint.initialize()
    return text_hint
def setup_image_hint(configs: dict) -> ImageHint:
    """Build and initialize the image hint model.

    Args:
        configs (dict): Application configuration; the image model's settings
            are read from ``configs["local"][HintType.IMAGE.value.lower()]``.

    Returns:
        ImageHint: The initialized image hint model.
    """
    with st.spinner("Loading image model..."):
        section = HintType.IMAGE.value.lower()
        hint_model = ImageHint(configs=configs["local"][section])
        hint_model.initialize()
    return hint_model
def setup_audio_hint(configs: dict) -> AudioHint:
    """Build and initialize the audio hint model.

    Args:
        configs (dict): Application configuration; the audio model's settings
            are read from ``configs["local"][HintType.AUDIO.value.lower()]``.

    Returns:
        AudioHint: The initialized audio hint model.
    """
    with st.spinner("Loading audio model..."):
        section = HintType.AUDIO.value.lower()
        hint_model = AudioHint(configs=configs["local"][section])
        hint_model.initialize()
    return hint_model
def get_country_list() -> pd.DataFrame:
    """Build a table of the known countries and their areas.

    Countries come from ``CountryInfo().all()``. A country whose record has no
    ``area`` entry is skipped, because ``pick_country`` uses the area as a
    sampling weight.

    Returns:
        pd.DataFrame: One row per country, with columns ``country`` (the key
        used by ``CountryInfo().all()``) and ``area``.
    """
    # Read every record from one CountryInfo().all() call. The previous code
    # constructed CountryInfo(name) once per country on top of that; countryinfo
    # appears to reload its bundled data in every constructor call (verify
    # against the installed version), which made that loop quadratic.
    records = CountryInfo().all()
    areas = {}
    for name, record in records.items():
        try:
            areas[name] = record["area"]
        except (KeyError, TypeError):
            # Still best effort, but only for missing or malformed data. The
            # old bare ``except`` also swallowed KeyboardInterrupt/SystemExit.
            logging.getLogger(__name__).debug("Skipping %s: no area data", name)
    return pd.DataFrame(list(areas.items()), columns=["country", "area"])
def pick_country(country_df: pd.DataFrame) -> str:
    """Draw one country at random, weighted by its area.

    The ``area`` column is passed to ``DataFrame.sample`` as the sampling
    weights, so larger countries are proportionally more likely to be drawn.

    Args:
        country_df (pd.DataFrame): Table with ``country`` and ``area`` columns.

    Returns:
        str: Name of the drawn country.
    """
    drawn = country_df.sample(n=1, weights="area")
    return drawn.iloc[0]["country"]
def reset_cache() -> None:
    """Reset the game state held in Streamlit's ``session_state``.

    The state is reset as follows:

    - The country table is rebuilt and a new country is drawn.
    - The hint selection goes back to no types and one hint.
    - The game is marked as not started.
    - Every hint model is forgotten.
    """
    countries = get_country_list()
    state = st.session_state
    state["country_list"] = countries["country"].to_list()
    state["country"] = pick_country(countries)
    state["hint_types"] = []
    state["n_hints"] = 1
    state["game_started"] = False
    # None marks a model as not loaded; setup_models loads it on demand.
    state["model"] = {
        hint.value: None
        for hint in (HintType.TEXT, HintType.IMAGE, HintType.AUDIO)
    }
# Configure the root logger at INFO level and get this module's logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
# Browser tab title and icon.
st.set_page_config(page_icon="π", page_title="Gen AI GeoGuesser")

# An empty session_state means this is the first run of this browser session:
# load variables from a .env file (presumably including HF_ACCESS_TOKEN, which
# setup_text_hint reads) and seed the game state once.
if not st.session_state:
    load_dotenv()
    reset_cache()

# Page header.
st.title("Generative AI GeoGuesser π")
st.markdown("### Guess the country based on hints generated by AI")
st.markdown("(Only working with image hints for performance reasons)")
# Game settings: which hint types to generate, and how many hints of each.
# Both widgets write their value back into session_state so the choice
# survives Streamlit reruns.
col1, col2 = st.columns([2, 1])
with col1:
    # Only image hints are offered; text and audio hints are disabled for
    # performance reasons (as the page subtitle says). Offering every
    # HintType value here would re-enable them.
    st.session_state["hint_types"] = st.multiselect(
        "Choose which hint types you want",
        [HintType.IMAGE.value],
        default=st.session_state["hint_types"],
    )
with col2:
    st.session_state["n_hints"] = st.slider(
        "Number of hints",
        min_value=1,
        max_value=5,
        value=st.session_state["n_hints"],
    )
# Start a round: load the selected hint models and generate the hints for the
# country already drawn into session_state.
start_btn = st.button("Start game")
if start_btn:
    if not st.session_state["hint_types"]:
        st.error("Pick at least one hint type")
        reset_cache()
    else:
        # Record the hidden answer through the module logger rather than a
        # bare print to stdout.
        logger.info('Chosen country "%s"', st.session_state["country"])
        setup_models(st.session_state, configs)
        for hint_type in st.session_state["hint_types"]:
            with st.spinner(f"Generating {hint_type} hint..."):
                st.session_state["model"][hint_type].generate_hint(
                    st.session_state["country"],
                    st.session_state["n_hints"],
                )
        st.session_state["game_started"] = True
# Guess controls are shown only once a game has started.
if st.session_state["game_started"]:
    guess_col, guess_btn_col, reset_btn_col = st.columns([2, 1, 1])
    with guess_col:
        # The leading empty option lets the select box start with no country.
        guess = st.selectbox("Country guess", [""] + st.session_state["country_list"])
    with guess_btn_col:
        guess_btn = st.button("Make a guess")
    with reset_btn_col:
        reset_btn = st.button("Reset game")

    if guess_btn:
        answer = st.session_state["country"]
        if answer == guess:
            st.success("Correct guess you won!")
            st.balloons()
        elif not guess:
            st.error("Pick a country.")
        else:
            # The get_distance result is truncated to an int and shown as KM.
            answer_latlng = CountryInfo(answer).latlng()
            guess_latlng = CountryInfo(guess).latlng()
            distance = int(get_distance(answer_latlng, guess_latlng))
            st.error(
                f"""
Wrong guess, you missed the correct country by {distance} KM.
The correct answer was {answer}.
"""
            )

    if reset_btn:
        reset_cache()
# Show the generated hints, one tab per selected hint type. The hint_types
# check is required: st.tabs rejects an empty list of labels, and the player
# can clear the multiselect after the game has started.
if st.session_state["game_started"] and st.session_state["hint_types"]:
    hint_types = st.session_state["hint_types"]
    tabs = st.tabs([f"{x} hint" for x in hint_types])
    for hint_type, tab in zip(hint_types, tabs):
        with tab:
            model = st.session_state["model"][hint_type]
            # No model typically means this hint type was selected after the
            # game started, so no hints were generated for it.
            if model:
                for hint_number, hint in enumerate(model.hints, start=1):
                    st.markdown(f"#### Hint #{hint_number}")
                    if hint_type == HintType.TEXT.value:
                        st.write(hint["text"])
                    elif hint_type == HintType.IMAGE.value:
                        st.image(hint["image"])
                    elif hint_type == HintType.AUDIO.value:
                        st.audio(hint["audio"], sample_rate=hint["sample_rate"])