Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -61,8 +61,8 @@ def text2story(text):
|
|
61 |
genre = labels[predicted_class]
|
62 |
|
63 |
# Then, generate a story based on the predicted genre
|
64 |
-
story_generator = AutoModelForCausalLM.from_pretrained("gpt2")
|
65 |
-
story_tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
66 |
|
67 |
# Create a genre-specific prompt
|
68 |
prompt = f"Write a {genre.lower()} story about: {text}\n\nStory:"
|
@@ -85,6 +85,25 @@ def text2story(text):
|
|
85 |
story = story_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
86 |
final_story = story.replace(prompt, "").strip()
|
87 |
return final_story
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
def load_css(css_file):
|
90 |
with open(css_file) as f:
|
|
|
61 |
genre = labels[predicted_class]
|
62 |
|
63 |
# Then, generate a story based on the predicted genre
|
64 |
+
story_generator = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
65 |
+
story_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
66 |
|
67 |
# Create a genre-specific prompt
|
68 |
prompt = f"Write a {genre.lower()} story about: {text}\n\nStory:"
|
|
|
85 |
story = story_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
86 |
final_story = story.replace(prompt, "").strip()
|
87 |
return final_story
|
88 |
+
except Exception as e:
|
89 |
+
# just Fallback only used openai-community/gpt2 if the advanced one fails
|
90 |
+
fallback_generator = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
91 |
+
fallback_prompt = f"{text}"
|
92 |
+
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
93 |
+
inputs = tokenizer(fallback_prompt, return_tensors="pt")
|
94 |
+
fallback_story = fallback_generator.generate(
|
95 |
+
inputs.input_ids,
|
96 |
+
min_length=50,
|
97 |
+
max_new_tokens=100,
|
98 |
+
temperature=0.7,
|
99 |
+
top_p=0.9,
|
100 |
+
top_k=40,
|
101 |
+
repetition_penalty=1.2,
|
102 |
+
do_sample=True,
|
103 |
+
pad_token_id=tokenizer.eos_token_id
|
104 |
+
)
|
105 |
+
fallback_story = tokenizer.decode(fallback_story[0], skip_special_tokens=True)
|
106 |
+
return fallback_story
|
107 |
|
108 |
def load_css(css_file):
|
109 |
with open(css_file) as f:
|