DarwinAnim8or commited on
Commit
ffac7c5
1 Parent(s): fca49da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -12,7 +12,7 @@ model = AutoModelForCausalLM.from_pretrained(model_name)
12
  if tokenizer.pad_token_id is None:
13
  tokenizer.pad_token_id = tokenizer.eos_token_id
14
 
15
- def generate_story(prompt, max_length=200):
16
  """Generates a story continuation from a given prompt."""
17
  input_ids = tokenizer.encode(prompt, return_tensors="pt")
18
 
@@ -25,7 +25,7 @@ def generate_story(prompt, max_length=200):
25
  do_sample=True,
26
  top_k=50,
27
  top_p=0.95,
28
- temperature=0.8, # Control randomness (higher = more creative)
29
  )
30
 
31
  # Decode the generated story
@@ -34,16 +34,17 @@ def generate_story(prompt, max_length=200):
34
 
35
  # Gradio Interface
36
  with gr.Blocks() as demo:
37
- gr.Markdown("## Storyteller: Generate a story from a prompt!")
38
  prompt_input = gr.Textbox(label="Enter your story prompt:")
39
  story_output = gr.Textbox(label="Generated story:")
40
  max_length_slider = gr.Slider(minimum=50, maximum=500, value=200, step=10, label="Max Story Length")
 
41
  generate_button = gr.Button("Generate Story")
42
 
43
  # Event handling
44
  generate_button.click(
45
  fn=generate_story,
46
- inputs=[prompt_input, max_length_slider],
47
  outputs=story_output
48
  )
49
 
 
12
  if tokenizer.pad_token_id is None:
13
  tokenizer.pad_token_id = tokenizer.eos_token_id
14
 
15
+ def generate_story(prompt, max_length=200, temp=0.3):
16
  """Generates a story continuation from a given prompt."""
17
  input_ids = tokenizer.encode(prompt, return_tensors="pt")
18
 
 
25
  do_sample=True,
26
  top_k=50,
27
  top_p=0.95,
28
+ temperature=temp, # Control randomness (higher = more creative)
29
  )
30
 
31
  # Decode the generated story
 
34
 
35
  # Gradio Interface
36
  with gr.Blocks() as demo:
37
+ gr.Markdown("## 'NoSleep' Storyteller: Generate a story from a prompt!")
38
  prompt_input = gr.Textbox(label="Enter your story prompt:")
39
  story_output = gr.Textbox(label="Generated story:")
40
  max_length_slider = gr.Slider(minimum=50, maximum=500, value=200, step=10, label="Max Story Length")
41
+ temp_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1, label="Temperature (randomness)")
42
  generate_button = gr.Button("Generate Story")
43
 
44
  # Event handling
45
  generate_button.click(
46
  fn=generate_story,
47
+ inputs=[prompt_input, max_length_slider, temp_slider],
48
  outputs=story_output
49
  )
50