Dimitre commited on
Commit
1a21137
1 Parent(s): 1fd4dc6

removing unused files

Browse files
Files changed (5) hide show
  1. app.py +0 -1
  2. src/app.py +0 -132
  3. src/common.py +0 -28
  4. src/hangman.py +0 -35
  5. src/hf_utils.py +0 -109
app.py CHANGED
@@ -43,7 +43,6 @@ def setup(model_id: str, device: str) -> None:
43
  model_id,
44
  torch_dtype=torch.float16,
45
  token=os.environ["HF_ACCESS_TOKEN"],
46
- device_map="auto",
47
  ).to(device)
48
  logger.info("Setup finished")
49
  return {"tokenizer": tokenizer, "model": model}
 
43
  model_id,
44
  torch_dtype=torch.float16,
45
  token=os.environ["HF_ACCESS_TOKEN"],
 
46
  ).to(device)
47
  logger.info("Setup finished")
48
  return {"tokenizer": tokenizer, "model": model}
src/app.py DELETED
@@ -1,132 +0,0 @@
1
- import logging
2
- import os
3
-
4
- import streamlit as st
5
- import torch
6
- from dotenv import load_dotenv
7
- from transformers import AutoModelForCausalLM, AutoTokenizer
8
-
9
- from common import CATEGORIES, MAX_TRIES, configs
10
- from hangman import guess_letter
11
- from hf_utils import query_hint, query_word
12
-
13
-
14
- @st.cache_resource()
15
- def setup(model_id: str, device: str) -> None:
16
- """Initializes the model and tokenizer.
17
-
18
- Args:
19
- model_id (str): Model ID used to load the tokenizer and model.
20
- """
21
- logger.info(f"Loading model and tokenizer from model: '{model_id}'")
22
- tokenizer = AutoTokenizer.from_pretrained(
23
- model_id,
24
- token=os.environ["HF_ACCESS_TOKEN"],
25
- )
26
- model = AutoModelForCausalLM.from_pretrained(
27
- model_id,
28
- torch_dtype=torch.float16,
29
- token=os.environ["HF_ACCESS_TOKEN"],
30
- ).to(device)
31
- logger.info("Setup finished")
32
- return {"tokenizer": tokenizer, "model": model}
33
-
34
-
35
- logging.basicConfig(level=logging.INFO)
36
- logger = logging.getLogger(__file__)
37
-
38
- st.set_page_config(
39
- page_title="Gemma Hangman",
40
- page_icon="🧩",
41
- )
42
-
43
- load_dotenv()
44
- assets = setup(configs["os_model"], configs["device"])
45
-
46
- tokenizer = assets["tokenizer"]
47
- model = assets["model"]
48
-
49
- if not st.session_state:
50
- st.session_state["word"] = ""
51
- st.session_state["hint"] = ""
52
- st.session_state["hangman"] = ""
53
- st.session_state["missed_letters"] = []
54
- st.session_state["correct_letters"] = []
55
-
56
- st.title("Gemini Hangman")
57
-
58
- st.markdown("## Guess the word based on a hint")
59
-
60
- col1, col2 = st.columns(2)
61
-
62
- with col1:
63
- category = st.selectbox(
64
- "Choose a category",
65
- CATEGORIES,
66
- )
67
-
68
- with col2:
69
- start_btn = st.button("Start game")
70
- reset_btn = st.button("Reset game")
71
-
72
- if start_btn:
73
- st.session_state["word"] = query_word(
74
- category,
75
- model,
76
- tokenizer,
77
- configs["generation_config"],
78
- configs["device"],
79
- )
80
- st.session_state["hint"] = query_hint(
81
- st.session_state["word"],
82
- model,
83
- tokenizer,
84
- configs["generation_config"],
85
- configs["device"],
86
- )
87
- st.session_state["hangman"] = "_" * len(st.session_state["word"])
88
- st.session_state["missed_letters"] = []
89
- st.session_state["correct_letters"] = []
90
-
91
- if reset_btn:
92
- st.session_state["word"] = ""
93
- st.session_state["hint"] = ""
94
- st.session_state["hangman"] = ""
95
- st.session_state["missed_letters"] = []
96
- st.session_state["correct_letters"] = []
97
-
98
- st.markdown(
99
- """
100
- ## Guess the word based on a hint
101
- Note: you must input whitespaces and special characters.
102
- """
103
- )
104
-
105
- st.markdown(f'### Hint:\n{st.session_state["hint"]}')
106
-
107
- col3, col4 = st.columns(2)
108
-
109
- with col3:
110
- guess = st.text_input(label="Enter letter")
111
- guess_btn = st.button("Guess letter")
112
-
113
- if guess_btn:
114
- st.session_state = guess_letter(guess, st.session_state)
115
-
116
- with col4:
117
- hangman = st.text_input(
118
- label="Hangman",
119
- value=st.session_state["hangman"],
120
- )
121
- st.text_input(
122
- label=f"Missed letters (max {MAX_TRIES} tries)",
123
- value=", ".join(st.session_state["missed_letters"]),
124
- )
125
-
126
- if st.session_state["word"] == st.session_state["hangman"] != "":
127
- st.success("You won!")
128
- st.balloons()
129
-
130
- if len(st.session_state["missed_letters"]) >= MAX_TRIES:
131
- st.error(f"""You lost, the correct word was '{st.session_state["word"]}'""")
132
- st.snow()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/common.py DELETED
@@ -1,28 +0,0 @@
1
- import logging
2
- import pprint
3
-
4
- import yaml
5
-
6
-
7
- def parse_configs(configs_path: str) -> dict:
8
- """Parse configs from the YAML file.
9
-
10
- Args:
11
- configs_path (str): Path to the YAML file
12
-
13
- Returns:
14
- dict: Parsed configs
15
- """
16
- configs = yaml.safe_load(open(configs_path, "r"))
17
- logger.info(f"Configs: {pprint.pformat(configs)}")
18
- return configs
19
-
20
-
21
- CONFIGS_PATH = "configs.yaml"
22
- MAX_TRIES = 6
23
- CATEGORIES = ["Country", "Animal", "Food", "Movie"]
24
-
25
- logging.basicConfig(level=logging.INFO)
26
- logger = logging.getLogger(__file__)
27
-
28
- configs = parse_configs(CONFIGS_PATH)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/hangman.py DELETED
@@ -1,35 +0,0 @@
1
- import logging
2
-
3
- from streamlit import session_state
4
-
5
-
6
- def guess_letter(letter: str, session: session_state) -> session_state:
7
- """Take a letter and evaluate if it is part of the hangman puzzle
8
- then updates the session object accordingly.
9
-
10
- Args:Chosen letter
11
- letter (str): Streamlit session object
12
- session (session_state): _description_
13
-
14
- Returns:
15
- session_state: Updated session
16
- """
17
- logger.info(f"Letter '{letter}' picked")
18
- if letter in session["word"]:
19
- session["correct_letters"].append(letter)
20
- else:
21
- session["missed_letters"].append(letter)
22
-
23
- hangman = "".join(
24
- [
25
- (letter if letter in session["correct_letters"] else "_")
26
- for letter in session["word"]
27
- ]
28
- )
29
- session["hangman"] = hangman
30
- logger.info("Session state updated")
31
- return session
32
-
33
-
34
- logging.basicConfig(level=logging.INFO)
35
- logger = logging.getLogger(__file__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/hf_utils.py DELETED
@@ -1,109 +0,0 @@
1
- import logging
2
- import re
3
- import string
4
-
5
- from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
6
-
7
- GEMMA_WORD_PATTERNS = [
8
- "(?<=\*)(.*?)(?=\*)",
9
- '(?<=")(.*?)(?=")',
10
- ]
11
-
12
-
13
- def query_hf(
14
- query: str,
15
- model: AutoModelForCausalLM,
16
- tokenizer: AutoTokenizer,
17
- generation_config: dict,
18
- device: str,
19
- ) -> str:
20
- """Queries an LLM model using the Vertex AI API.
21
-
22
- Args:
23
- query (str): Query sent to the Vertex API
24
- model (str): Model target by Vertex
25
- generation_config (dict): Configurations used by the model
26
-
27
- Returns:
28
- str: Vertex AI text response
29
- """
30
- generation_config = GenerationConfig(
31
- do_sample=True,
32
- max_new_tokens=generation_config["max_output_tokens"],
33
- top_k=generation_config["top_k"],
34
- top_p=generation_config["top_p"],
35
- temperature=generation_config["temperature"],
36
- )
37
-
38
- input_ids = tokenizer(query, return_tensors="pt").to(device)
39
- outputs = model.generate(**input_ids, generation_config=generation_config)
40
- outputs = tokenizer.decode(outputs[0], skip_special_tokens=True)
41
- outputs = outputs.replace(query, "")
42
- return outputs
43
-
44
-
45
- def query_word(
46
- category: str,
47
- model: AutoModelForCausalLM,
48
- tokenizer: AutoTokenizer,
49
- generation_config: dict,
50
- device: str,
51
- ) -> str:
52
- """Queries a word to be used for the hangman game.
53
-
54
- Args:
55
- category (str): Category used as source sample a word
56
- model (str): Model target by Vertex
57
- generation_config (dict): Configurations used by the model
58
-
59
- Returns:
60
- str: Queried word
61
- """
62
- logger.info(f"Quering word for category: '{category}'...")
63
- query = f"Name a single existing {category}."
64
-
65
- matched_word = ""
66
- while not matched_word:
67
- word = query_hf(query, model, tokenizer, generation_config, device)
68
-
69
- # Extract word of interest from Gemma's output
70
- for pattern in GEMMA_WORD_PATTERNS:
71
- matched_words = re.findall(rf"{pattern}", word)
72
- matched_words = [x for x in matched_words if x != ""]
73
- if matched_words:
74
- matched_word = matched_words[-1]
75
-
76
- matched_word = matched_word.translate(str.maketrans("", "", string.punctuation))
77
- matched_word = matched_word.lower()
78
-
79
- logger.info("Word queried successful")
80
- return matched_word
81
-
82
-
83
- def query_hint(
84
- word: str,
85
- model: AutoModelForCausalLM,
86
- tokenizer: AutoTokenizer,
87
- generation_config: dict,
88
- device: str,
89
- ) -> str:
90
- """Queries a hint for the hangman game.
91
-
92
- Args:
93
- word (str): Word used as source to create the hint
94
- model (str): Model target by Vertex
95
- generation_config (dict): Configurations used by the model
96
-
97
- Returns:
98
- str: Queried hint
99
- """
100
- logger.info(f"Quering hint for word: '{word}'...")
101
- query = f"Describe the word '{word}' without mentioning it."
102
- hint = query_hf(query, model, tokenizer, generation_config, device)
103
- hint = re.sub(re.escape(word), "***", hint, flags=re.IGNORECASE)
104
- logger.info("Hint queried successful")
105
- return hint
106
-
107
-
108
- logging.basicConfig(level=logging.INFO)
109
- logger = logging.getLogger(__file__)