Updated mistral model
Browse files
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import streamlit as st
|
2 |
-
from transformers import
|
|
|
3 |
|
4 |
# Set up the Streamlit app
|
5 |
st.title("π§ββοΈ Magic Story Buddy π")
|
@@ -8,9 +9,11 @@ st.markdown("Let's create a magical story just for you!")
|
|
8 |
# Initialize the model
|
9 |
@st.cache_resource
|
10 |
def load_model():
|
11 |
-
|
|
|
|
|
12 |
|
13 |
-
model = load_model()
|
14 |
|
15 |
# User input
|
16 |
child_name = st.text_input("What's your name, young storyteller?")
|
@@ -21,15 +24,29 @@ story_theme = st.selectbox("What would you like your story to be about?",
|
|
21 |
story_length = st.slider("How long should the story be?", 50, 200, 100)
|
22 |
include_moral = st.checkbox("Include a moral lesson?")
|
23 |
|
|
|
|
|
|
|
|
|
|
|
24 |
if st.button("Create My Story!"):
|
25 |
if child_name and story_theme:
|
26 |
# Construct the prompt
|
27 |
-
prompt = f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
if include_moral:
|
29 |
prompt += "This story teaches us that "
|
30 |
|
31 |
# Generate the story
|
32 |
-
story =
|
33 |
|
34 |
# Display the story
|
35 |
st.markdown("## Your Magical Story")
|
|
|
1 |
import streamlit as st
|
2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
3 |
+
import torch
|
4 |
|
5 |
# Set up the Streamlit app
|
6 |
st.title("π§ββοΈ Magic Story Buddy π")
|
|
|
9 |
# Initialize the model
|
10 |
@st.cache_resource
|
11 |
def load_model():
|
12 |
+
model = AutoModelForCausalLM.from_pretrained("ajibawa-2023/Young-Children-Storyteller-Mistral-7B", torch_dtype=torch.float16)
|
13 |
+
tokenizer = AutoTokenizer.from_pretrained("ajibawa-2023/Young-Children-Storyteller-Mistral-7B")
|
14 |
+
return model, tokenizer
|
15 |
|
16 |
+
model, tokenizer = load_model()
|
17 |
|
18 |
# User input
|
19 |
child_name = st.text_input("What's your name, young storyteller?")
|
|
|
24 |
story_length = st.slider("How long should the story be?", 50, 200, 100)
|
25 |
include_moral = st.checkbox("Include a moral lesson?")
|
26 |
|
27 |
+
def generate_story(prompt, max_length=500):
|
28 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
29 |
+
outputs = model.generate(**inputs, max_length=max_length, num_return_sequences=1, do_sample=True, temperature=0.7)
|
30 |
+
return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
31 |
+
|
32 |
if st.button("Create My Story!"):
|
33 |
if child_name and story_theme:
|
34 |
# Construct the prompt
|
35 |
+
prompt = f"""Create a short children's story with the following details:
|
36 |
+
- Main character: {child_name}
|
37 |
+
- Theme: {story_theme}
|
38 |
+
- Length: About {story_length} words
|
39 |
+
- Audience: Children aged 5-10
|
40 |
+
- Tone: Friendly, educational, and imaginative
|
41 |
+
|
42 |
+
Story:
|
43 |
+
Once upon a time, in a {story_theme.lower()}, there was a brave child named {child_name}. """
|
44 |
+
|
45 |
if include_moral:
|
46 |
prompt += "This story teaches us that "
|
47 |
|
48 |
# Generate the story
|
49 |
+
story = generate_story(prompt, max_length=story_length)
|
50 |
|
51 |
# Display the story
|
52 |
st.markdown("## Your Magical Story")
|