Dimitre commited on
Commit
bcb83c0
1 Parent(s): 8147de4

Initial version of the app

Browse files
Files changed (6) hide show
  1. configs.yaml +7 -0
  2. requirements.txt +5 -0
  3. src/app.py +132 -0
  4. src/common.py +28 -0
  5. src/hangman.py +35 -0
  6. src/hf_utils.py +109 -0
configs.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ generation_config:
2
+ max_output_tokens: 256
3
+ temperature: 1
4
+ top_p: 1
5
+ top_k: 32
6
+ os_model: google/gemma-2b-it
7
+ device: cpu
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ streamlit
2
+ python-dotenv
3
+ torch
4
+ transformers
5
+ accelerate
src/app.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+
4
+ import streamlit as st
5
+ import torch
6
+ from dotenv import load_dotenv
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+
9
+ from common import CATEGORIES, MAX_TRIES, configs
10
+ from hangman import guess_letter
11
+ from hf_utils import query_hint, query_word
12
+
13
+
14
+ @st.cache_resource()
15
+ def setup(model_id: str, device: str) -> None:
16
+ """Initializes the model and tokenizer.
17
+
18
+ Args:
19
+ model_id (str): Model ID used to load the tokenizer and model.
20
+ """
21
+ logger.info(f"Loading model and tokenizer from model: '{model_id}'")
22
+ tokenizer = AutoTokenizer.from_pretrained(
23
+ model_id,
24
+ token=os.environ["HF_ACCESS_TOKEN"],
25
+ )
26
+ model = AutoModelForCausalLM.from_pretrained(
27
+ model_id,
28
+ torch_dtype=torch.float16,
29
+ token=os.environ["HF_ACCESS_TOKEN"],
30
+ ).to(device)
31
+ logger.info("Setup finished")
32
+ return {"tokenizer": tokenizer, "model": model}
33
+
34
+
35
+ logging.basicConfig(level=logging.INFO)
36
+ logger = logging.getLogger(__file__)
37
+
38
+ st.set_page_config(
39
+ page_title="Gemma Hangman",
40
+ page_icon="🧩",
41
+ )
42
+
43
+ load_dotenv()
44
+ assets = setup(configs["os_model"], configs["device"])
45
+
46
+ tokenizer = assets["tokenizer"]
47
+ model = assets["model"]
48
+
49
+ if not st.session_state:
50
+ st.session_state["word"] = ""
51
+ st.session_state["hint"] = ""
52
+ st.session_state["hangman"] = ""
53
+ st.session_state["missed_letters"] = []
54
+ st.session_state["correct_letters"] = []
55
+
56
+ st.title("Gemini Hangman")
57
+
58
+ st.markdown("## Guess the word based on a hint")
59
+
60
+ col1, col2 = st.columns(2)
61
+
62
+ with col1:
63
+ category = st.selectbox(
64
+ "Choose a category",
65
+ CATEGORIES,
66
+ )
67
+
68
+ with col2:
69
+ start_btn = st.button("Start game")
70
+ reset_btn = st.button("Reset game")
71
+
72
+ if start_btn:
73
+ st.session_state["word"] = query_word(
74
+ category,
75
+ model,
76
+ tokenizer,
77
+ configs["generation_config"],
78
+ configs["device"],
79
+ )
80
+ st.session_state["hint"] = query_hint(
81
+ st.session_state["word"],
82
+ model,
83
+ tokenizer,
84
+ configs["generation_config"],
85
+ configs["device"],
86
+ )
87
+ st.session_state["hangman"] = "_" * len(st.session_state["word"])
88
+ st.session_state["missed_letters"] = []
89
+ st.session_state["correct_letters"] = []
90
+
91
+ if reset_btn:
92
+ st.session_state["word"] = ""
93
+ st.session_state["hint"] = ""
94
+ st.session_state["hangman"] = ""
95
+ st.session_state["missed_letters"] = []
96
+ st.session_state["correct_letters"] = []
97
+
98
+ st.markdown(
99
+ """
100
+ ## Guess the word based on a hint
101
+ Note: you must input whitespaces and special characters.
102
+ """
103
+ )
104
+
105
+ st.markdown(f'### Hint:\n{st.session_state["hint"]}')
106
+
107
+ col3, col4 = st.columns(2)
108
+
109
+ with col3:
110
+ guess = st.text_input(label="Enter letter")
111
+ guess_btn = st.button("Guess letter")
112
+
113
+ if guess_btn:
114
+ st.session_state = guess_letter(guess, st.session_state)
115
+
116
+ with col4:
117
+ hangman = st.text_input(
118
+ label="Hangman",
119
+ value=st.session_state["hangman"],
120
+ )
121
+ st.text_input(
122
+ label=f"Missed letters (max {MAX_TRIES} tries)",
123
+ value=", ".join(st.session_state["missed_letters"]),
124
+ )
125
+
126
+ if st.session_state["word"] == st.session_state["hangman"] != "":
127
+ st.success("You won!")
128
+ st.balloons()
129
+
130
+ if len(st.session_state["missed_letters"]) >= MAX_TRIES:
131
+ st.error(f"""You lost, the correct word was '{st.session_state["word"]}'""")
132
+ st.snow()
src/common.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import pprint
3
+
4
+ import yaml
5
+
6
+
7
+ def parse_configs(configs_path: str) -> dict:
8
+ """Parse configs from the YAML file.
9
+
10
+ Args:
11
+ configs_path (str): Path to the YAML file
12
+
13
+ Returns:
14
+ dict: Parsed configs
15
+ """
16
+ configs = yaml.safe_load(open(configs_path, "r"))
17
+ logger.info(f"Configs: {pprint.pformat(configs)}")
18
+ return configs
19
+
20
+
21
+ CONFIGS_PATH = "configs.yaml"
22
+ MAX_TRIES = 6
23
+ CATEGORIES = ["Country", "Animal", "Food", "Movie"]
24
+
25
+ logging.basicConfig(level=logging.INFO)
26
+ logger = logging.getLogger(__file__)
27
+
28
+ configs = parse_configs(CONFIGS_PATH)
src/hangman.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ from streamlit import session_state
4
+
5
+
6
+ def guess_letter(letter: str, session: session_state) -> session_state:
7
+ """Take a letter and evaluate if it is part of the hangman puzzle
8
+ then updates the session object accordingly.
9
+
10
+ Args:Chosen letter
11
+ letter (str): Streamlit session object
12
+ session (session_state): _description_
13
+
14
+ Returns:
15
+ session_state: Updated session
16
+ """
17
+ logger.info(f"Letter '{letter}' picked")
18
+ if letter in session["word"]:
19
+ session["correct_letters"].append(letter)
20
+ else:
21
+ session["missed_letters"].append(letter)
22
+
23
+ hangman = "".join(
24
+ [
25
+ (letter if letter in session["correct_letters"] else "_")
26
+ for letter in session["word"]
27
+ ]
28
+ )
29
+ session["hangman"] = hangman
30
+ logger.info("Session state updated")
31
+ return session
32
+
33
+
34
+ logging.basicConfig(level=logging.INFO)
35
+ logger = logging.getLogger(__file__)
src/hf_utils.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import re
3
+ import string
4
+
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
6
+
7
+ GEMMA_WORD_PATTERNS = [
8
+ "(?<=\*)(.*?)(?=\*)",
9
+ '(?<=")(.*?)(?=")',
10
+ ]
11
+
12
+
13
+ def query_hf(
14
+ query: str,
15
+ model: AutoModelForCausalLM,
16
+ tokenizer: AutoTokenizer,
17
+ generation_config: dict,
18
+ device: str,
19
+ ) -> str:
20
+ """Queries an LLM model using the Vertex AI API.
21
+
22
+ Args:
23
+ query (str): Query sent to the Vertex API
24
+ model (str): Model target by Vertex
25
+ generation_config (dict): Configurations used by the model
26
+
27
+ Returns:
28
+ str: Vertex AI text response
29
+ """
30
+ generation_config = GenerationConfig(
31
+ do_sample=True,
32
+ max_new_tokens=generation_config["max_output_tokens"],
33
+ top_k=generation_config["top_k"],
34
+ top_p=generation_config["top_p"],
35
+ temperature=generation_config["temperature"],
36
+ )
37
+
38
+ input_ids = tokenizer(query, return_tensors="pt").to(device)
39
+ outputs = model.generate(**input_ids, generation_config=generation_config)
40
+ outputs = tokenizer.decode(outputs[0], skip_special_tokens=True)
41
+ outputs = outputs.replace(query, "")
42
+ return outputs
43
+
44
+
45
+ def query_word(
46
+ category: str,
47
+ model: AutoModelForCausalLM,
48
+ tokenizer: AutoTokenizer,
49
+ generation_config: dict,
50
+ device: str,
51
+ ) -> str:
52
+ """Queries a word to be used for the hangman game.
53
+
54
+ Args:
55
+ category (str): Category used as source sample a word
56
+ model (str): Model target by Vertex
57
+ generation_config (dict): Configurations used by the model
58
+
59
+ Returns:
60
+ str: Queried word
61
+ """
62
+ logger.info(f"Quering word for category: '{category}'...")
63
+ query = f"Name a single existing {category}."
64
+
65
+ matched_word = ""
66
+ while not matched_word:
67
+ word = query_hf(query, model, tokenizer, generation_config, device)
68
+
69
+ # Extract word of interest from Gemma's output
70
+ for pattern in GEMMA_WORD_PATTERNS:
71
+ matched_words = re.findall(rf"{pattern}", word)
72
+ matched_words = [x for x in matched_words if x != ""]
73
+ if matched_words:
74
+ matched_word = matched_words[-1]
75
+
76
+ matched_word = matched_word.translate(str.maketrans("", "", string.punctuation))
77
+ matched_word = matched_word.lower()
78
+
79
+ logger.info("Word queried successful")
80
+ return matched_word
81
+
82
+
83
+ def query_hint(
84
+ word: str,
85
+ model: AutoModelForCausalLM,
86
+ tokenizer: AutoTokenizer,
87
+ generation_config: dict,
88
+ device: str,
89
+ ) -> str:
90
+ """Queries a hint for the hangman game.
91
+
92
+ Args:
93
+ word (str): Word used as source to create the hint
94
+ model (str): Model target by Vertex
95
+ generation_config (dict): Configurations used by the model
96
+
97
+ Returns:
98
+ str: Queried hint
99
+ """
100
+ logger.info(f"Quering hint for word: '{word}'...")
101
+ query = f"Describe the word '{word}' without mentioning it."
102
+ hint = query_hf(query, model, tokenizer, generation_config, device)
103
+ hint = re.sub(re.escape(word), "***", hint, flags=re.IGNORECASE)
104
+ logger.info("Hint queried successful")
105
+ return hint
106
+
107
+
108
+ logging.basicConfig(level=logging.INFO)
109
+ logger = logging.getLogger(__file__)