dtm95 commited on
Commit
5df603e
Β·
verified Β·
1 Parent(s): cc6a5c4

Updated mistral model

Browse files
Files changed (1) hide show
  1. app.py +22 -5
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import streamlit as st
2
- from transformers import pipeline
 
3
 
4
  # Set up the Streamlit app
5
  st.title("πŸ§šβ€β™€οΈ Magic Story Buddy πŸ“š")
@@ -8,9 +9,11 @@ st.markdown("Let's create a magical story just for you!")
8
  # Initialize the model
9
  @st.cache_resource
10
  def load_model():
11
- return pipeline("text-generation", model="ajibawa-2023/Young-Children-Storyteller-Mistral-7B")
 
 
12
 
13
- model = load_model()
14
 
15
  # User input
16
  child_name = st.text_input("What's your name, young storyteller?")
@@ -21,15 +24,29 @@ story_theme = st.selectbox("What would you like your story to be about?",
21
  story_length = st.slider("How long should the story be?", 50, 200, 100)
22
  include_moral = st.checkbox("Include a moral lesson?")
23
 
 
 
 
 
 
24
  if st.button("Create My Story!"):
25
  if child_name and story_theme:
26
  # Construct the prompt
27
- prompt = f"[CHILDREN'S STORY] Once upon a time, in a {story_theme.lower()}, there was a brave child named {child_name}. "
 
 
 
 
 
 
 
 
 
28
  if include_moral:
29
  prompt += "This story teaches us that "
30
 
31
  # Generate the story
32
- story = model(prompt, max_length=story_length, num_return_sequences=1)[0]['generated_text']
33
 
34
  # Display the story
35
  st.markdown("## Your Magical Story")
 
1
  import streamlit as st
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
 
5
  # Set up the Streamlit app
6
  st.title("πŸ§šβ€β™€οΈ Magic Story Buddy πŸ“š")
 
9
  # Initialize the model
10
  @st.cache_resource
11
  def load_model():
12
+ model = AutoModelForCausalLM.from_pretrained("ajibawa-2023/Young-Children-Storyteller-Mistral-7B", torch_dtype=torch.float16)
13
+ tokenizer = AutoTokenizer.from_pretrained("ajibawa-2023/Young-Children-Storyteller-Mistral-7B")
14
+ return model, tokenizer
15
 
16
+ model, tokenizer = load_model()
17
 
18
  # User input
19
  child_name = st.text_input("What's your name, young storyteller?")
 
24
  story_length = st.slider("How long should the story be?", 50, 200, 100)
25
  include_moral = st.checkbox("Include a moral lesson?")
26
 
27
+ def generate_story(prompt, max_length=500):
28
+ inputs = tokenizer(prompt, return_tensors="pt")
29
+ outputs = model.generate(**inputs, max_length=max_length, num_return_sequences=1, do_sample=True, temperature=0.7)
30
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
31
+
32
  if st.button("Create My Story!"):
33
  if child_name and story_theme:
34
  # Construct the prompt
35
+ prompt = f"""Create a short children's story with the following details:
36
+ - Main character: {child_name}
37
+ - Theme: {story_theme}
38
+ - Length: About {story_length} words
39
+ - Audience: Children aged 5-10
40
+ - Tone: Friendly, educational, and imaginative
41
+
42
+ Story:
43
+ Once upon a time, in a {story_theme.lower()}, there was a brave child named {child_name}. """
44
+
45
  if include_moral:
46
  prompt += "This story teaches us that "
47
 
48
  # Generate the story
49
+ story = generate_story(prompt, max_length=story_length)
50
 
51
  # Display the story
52
  st.markdown("## Your Magical Story")