slliac committed on
Commit
6256080
·
verified ·
1 Parent(s): fc91334

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -2
app.py CHANGED
@@ -61,8 +61,8 @@ def text2story(text):
61
  genre = labels[predicted_class]
62
 
63
  # Then, generate a story based on the predicted genre
64
- story_generator = AutoModelForCausalLM.from_pretrained("gpt2")
65
- story_tokenizer = AutoTokenizer.from_pretrained("gpt2")
66
 
67
  # Create a genre-specific prompt
68
  prompt = f"Write a {genre.lower()} story about: {text}\n\nStory:"
@@ -85,6 +85,25 @@ def text2story(text):
85
  story = story_tokenizer.decode(outputs[0], skip_special_tokens=True)
86
  final_story = story.replace(prompt, "").strip()
87
  return final_story
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  def load_css(css_file):
90
  with open(css_file) as f:
 
61
  genre = labels[predicted_class]
62
 
63
  # Then, generate a story based on the predicted genre
64
+ story_generator = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
65
+ story_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
66
 
67
  # Create a genre-specific prompt
68
  prompt = f"Write a {genre.lower()} story about: {text}\n\nStory:"
 
85
  story = story_tokenizer.decode(outputs[0], skip_special_tokens=True)
86
  final_story = story.replace(prompt, "").strip()
87
  return final_story
88
+ except Exception as e:
89
+ # Fallback: use openai-community/gpt2 if the advanced model fails
90
+ fallback_generator = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
91
+ fallback_prompt = f"{text}"
92
+ tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
93
+ inputs = tokenizer(fallback_prompt, return_tensors="pt")
94
+ fallback_story = fallback_generator.generate(
95
+ inputs.input_ids,
96
+ min_length=50,
97
+ max_new_tokens=100,
98
+ temperature=0.7,
99
+ top_p=0.9,
100
+ top_k=40,
101
+ repetition_penalty=1.2,
102
+ do_sample=True,
103
+ pad_token_id=tokenizer.eos_token_id
104
+ )
105
+ fallback_story = tokenizer.decode(fallback_story[0], skip_special_tokens=True)
106
+ return fallback_story
107
 
108
  def load_css(css_file):
109
  with open(css_file) as f: