davidizzle committed
Commit 1efea19
0 Parent(s):

Added Gemma model, model selection, inference, sliders and UI
Files changed (5)
  1. .gitignore +4 -0
  2. README.md +15 -0
  3. app.py +99 -0
  4. requirements.txt +4 -0
  5. utils.py +0 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
+ __pycache__/
+ *.pyc
+ .env
+ assets/*.gif
README.md ADDED
@@ -0,0 +1,15 @@
+ # 💎 Gemma 💎 HF Spaces Demo
+
+ An interactive [Streamlit](https://streamlit.io) app to test [Gemma](https://huggingface.co/google/gemma-2b) models directly in your browser.
+
+ ## Features 🚀
+
+ - Chat with the Gemma model (default: `google/gemma-2b`)
+ - Fast deploy to Hugging Face Spaces
+ - Easy to customize & extend
+
+ ## Setup 📦
+
+ ```bash
+ pip install -r requirements.txt
+ streamlit run app.py
app.py ADDED
@@ -0,0 +1,99 @@
+ import streamlit as st
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import torch
+ import base64
+
+ st.set_page_config(page_title="Gemma Demo", layout="wide")
+ # Model selection (STUBBED behavior)
+ model_option = st.selectbox(
+     "Choose a Gemma to reveal hidden truths:",
+     ["gemma-2b-it (Instruct)", "gemma-2b", "gemma-7b", "gemma-7b-it"],
+     index=0,
+     help="Stubbed selection – only gemma-2b-it will load for now."
+ )
+ st.markdown("<h1 style='text-align: center;'>Portal to Gemma</h1>", unsafe_allow_html=True)
+
+ # Load both GIFs in base64 format
+ def load_gif_base64(path):
+     with open(path, "rb") as f:
+         return base64.b64encode(f.read()).decode("utf-8")
+
+ still_gem_b64 = load_gif_base64("assets/stillGem.gif")
+ rotating_gem_b64 = load_gif_base64("assets/rotatingGem.gif")
+
+ # Placeholder for GIF HTML
+ gif_html = st.empty()
+ caption = st.empty()
+
+ # Initially show still gem
+ gif_html.markdown(
+     f"<div style='text-align:center;'><img src='data:image/gif;base64,{still_gem_b64}' width='300'></div>",
+     unsafe_allow_html=True,
+ )
+
+ @st.cache_resource
+ def load_model():
+     model_id = "google/gemma-2b-it"
+     tokenizer = AutoTokenizer.from_pretrained(model_id, token=True)
+     model = AutoModelForCausalLM.from_pretrained(
+         model_id,
+         device_map=None,
+         torch_dtype=torch.float32
+     )
+     model.to("cpu")
+     return tokenizer, model
+
+ tokenizer, model = load_model()
+ prompt = st.text_area("Enter your prompt:", "What is Gemma?")
+ # # Example prompt selector
+ # examples = {
+ #     "🧠 Summary": "Summarize the history of AI in 5 bullet points.",
+ #     "💻 Code": "Write a Python function to sort a list using bubble sort.",
+ #     "📜 Poem": "Write a haiku about large language models.",
+ #     "🤖 Explain": "Explain what a transformer is in simple terms.",
+ #     "🔍 Fact": "Who won the FIFA World Cup in 2022?"
+ # }
+
+ # selected_example = st.selectbox("Choose a Gemma to consult:", list(examples.keys()) + ["✍️ Custom input"])
+ # Add before generation
+ col1, col2, col3 = st.columns(3)
+
+ with col1:
+     temperature = st.slider("Temperature", 0.1, 1.5, 1.0)
+
+ with col2:
+     max_tokens = st.slider("Max tokens", 50, 500, 100)
+
+ with col3:
+     top_p = st.slider("Top-p (nucleus sampling)", 0.1, 1.0, 0.95)
+ # if selected_example != "✍️ Custom input":
+ #     prompt = examples[selected_example]
+ # else:
+ #     prompt = st.text_area("Enter your prompt:")
+
+ if st.button("Generate"):
+     # Swap to rotating GIF
+     gif_html.markdown(
+         f"<div style='text-align:center;'><img src='data:image/gif;base64,{rotating_gem_b64}' width='300'></div>",
+         unsafe_allow_html=True,
+     )
+     caption.markdown("<p style='text-align: center;'>Gemma is thinking... 🌀</p>", unsafe_allow_html=True)
+
+
+     # Generate text
+
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+     with torch.no_grad():
+         outputs = model.generate(**inputs, max_new_tokens=max_tokens, temperature=temperature, top_p=top_p)
+
+     # Back to still
+     gif_html.markdown(
+         f"<div style='text-align:center;'><img src='data:image/gif;base64,{still_gem_b64}' width='300'></div>",
+         unsafe_allow_html=True,
+     )
+     caption.empty()
+
+
+     result = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     st.markdown("### ✨ Output:")
+     st.write(result)
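
For illustration, a minimal sketch of how the stubbed selector above could drive actual model loading, and how the sliders could take effect during generation. The `MODEL_IDS` mapping and the parameterized `load_model(model_id)` are assumptions for this sketch, not part of this commit; `model_option`, `prompt`, `temperature`, `max_tokens`, and `top_p` refer to the variables defined in app.py above. Note that `transformers` applies `temperature` and `top_p` only when `do_sample=True` is passed to `generate`; the committed call omits it, so decoding stays greedy.

```python
# Hypothetical extension of app.py, not part of this commit.
# Maps the selectbox label to a Hugging Face model id and enables sampling
# so the temperature / top-p sliders actually influence the output.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import streamlit as st

# Assumed mapping from the labels used in the selectbox above to model ids.
MODEL_IDS = {
    "gemma-2b-it (Instruct)": "google/gemma-2b-it",
    "gemma-2b": "google/gemma-2b",
    "gemma-7b": "google/gemma-7b",
    "gemma-7b-it": "google/gemma-7b-it",
}

@st.cache_resource
def load_model(model_id: str):
    # Gemma repos are gated: an access token must be available,
    # e.g. via `huggingface-cli login` or an HF_TOKEN secret on the Space.
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=True)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
    model.to("cpu")
    return tokenizer, model

tokenizer, model = load_model(MODEL_IDS[model_option])

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=True,        # required for temperature / top_p to apply
        temperature=temperature,
        top_p=top_p,
    )
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
```

Because the loader is cached per `model_id`, switching models in the UI would load each checkpoint once per session; the 7B variants in float32 will likely exceed the memory of a free CPU Space, so the 2B checkpoints are the practical choices there.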
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ streamlit
+ transformers
+ torch
+ accelerate
utils.py ADDED
File without changes