File size: 6,711 Bytes
04e9ddb 50b2941 04e9ddb 50b2941 04e9ddb 50b2941 04e9ddb 50b2941 04e9ddb 50b2941 04e9ddb 50b2941 04e9ddb 50b2941 04e9ddb 50b2941 04e9ddb 50b2941 04e9ddb 50b2941 04e9ddb 50b2941 04e9ddb 50b2941 04e9ddb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
import gradio as gr
import pixeltable as pxt
import numpy as np
from datetime import datetime
from pixeltable.functions.huggingface import sentence_transformer
from pixeltable.functions import openai
import os
import getpass
# Store API keys: prompt for the OpenAI key only when it is not already set in
# the environment. getpass blocks waiting for interactive input, so on a
# non-interactive host the variable should be exported before launch.
_OPENAI_KEY_VAR = 'OPENAI_API_KEY'
if _OPENAI_KEY_VAR not in os.environ:
    os.environ[_OPENAI_KEY_VAR] = getpass.getpass('OpenAI API key:')

# Initialize Pixeltable with a fresh 'story_builder' directory.
# NOTE: force=True drops the directory together with all tables in it, so every
# start of this script discards previously stored contributions.
_STORY_DIR = 'story_builder'
pxt.drop_dir(_STORY_DIR, force=True)
pxt.create_dir(_STORY_DIR)
# Create embedding function
@pxt.expr_udf
def embed_text(text: str) -> np.ndarray:
    """Embed `text` with the `all-MiniLM-L6-v2` sentence-transformer model.

    Used as the `string_embed` function of the embedding index on the
    `content` column of the contributions table.
    """
    minilm_model = 'all-MiniLM-L6-v2'
    return sentence_transformer(text, model_id=minilm_model)
# Create a table to store story contributions. Each row holds one submission
# together with the full story text as of that submission (`cumulative_story`).
_contribution_columns = {
    'contributor': pxt.StringType(),
    'content': pxt.StringType(),
    'timestamp': pxt.TimestampType(),
    'cumulative_story': pxt.StringType(),
}
story_table = pxt.create_table('story_builder.contributions', _contribution_columns)

# Add an embedding index to the content column so contributions can be ranked
# with `story_table.content.similarity(...)`.
story_table.add_embedding_index('content', string_embed=embed_text)
@pxt.udf
def generate_summary(story: str) -> list[dict]:
    """Build the chat messages that ask the model to summarize `story`.

    Returns a two-element list: a system message that sets up the summarizer
    role, followed by a user message that carries the story text.
    """
    summarizer_role = "You are an expert summarizer. Provide a concise summary of the given story, highlighting key plot points and themes."
    request = f"Story: {story}\n\nSummarize this story:"
    system_turn = dict(role='system', content=summarizer_role)
    user_turn = dict(role='user', content=request)
    return [system_turn, user_turn]
# Computed columns: build a summary prompt from each row's cumulative story and
# send it to OpenAI (gpt-3.5-turbo, replies capped at 200 tokens).
_summary_messages = generate_summary(story_table.cumulative_story)
story_table['summary_prompt'] = _summary_messages
story_table['summary_response'] = openai.chat_completions(
    model='gpt-3.5-turbo',
    messages=story_table.summary_prompt,
    max_tokens=200,
)
@pxt.udf
def generate_continuation(context: str) -> list[dict]:
    """Build the chat messages that ask the model to continue the story.

    Args:
        context: The story text so far.

    Returns:
        A system message that frames the model as a creative writer, followed
        by a user message that carries `context`.
    """
    # The instruction is sent verbatim to the model, so it should be
    # grammatical ("a paragraph that logically follows ...").
    system_msg = "You are a creative writer. Continue the story based on the given context. Write a paragraph that logically follows the provided content."
    user_msg = f"Context: {context}\n\nContinue the story:"
    return [
        {'role': 'system', 'content': system_msg},
        {'role': 'user', 'content': user_msg},
    ]
# Computed columns: build a continuation prompt from each row's cumulative story
# and ask OpenAI for the next paragraph.
story_table['continuation_prompt'] = generate_continuation(story_table.cumulative_story)
story_table['continuation_response'] = openai.chat_completions(
    messages=story_table.continuation_prompt,
    model='gpt-3.5-turbo',
    # The prompt asks for a whole paragraph. A 50-token cap routinely truncated
    # replies mid-sentence, so leave room for a complete paragraph.
    max_tokens=200
)
# Function to get the current cumulative story
def get_current_story():
    """Return the story text as of the most recent contribution.

    Reads `cumulative_story` from the last row of `story_table` and returns an
    empty string while the table has no rows.
    """
    newest_rows = story_table.tail(1)
    if len(newest_rows) == 0:
        return ""
    return newest_rows['cumulative_story'][0]
# Functions for Gradio interface
def add_contribution(contributor, content):
    """Append a contribution to the shared story and store it.

    Args:
        contributor: Name entered in the "Your Name" box.
        content: Text of the new contribution.

    Returns:
        A ``(status_message, cumulative_story)`` tuple for the Status and
        Current Story text boxes. A blank contribution is rejected without
        inserting a row, and the current story is returned unchanged.
    """
    current_story = get_current_story()
    # Reject empty / whitespace-only submissions. Inserting one would store an
    # empty row and trigger the embedding and OpenAI computed columns, while
    # only appending stray whitespace to the story.
    if not content or not content.strip():
        return "Please enter some text before submitting.", current_story
    new_cumulative_story = f"{current_story} {content}" if current_story else content
    story_table.insert([{
        'contributor': contributor,
        'content': content,
        'timestamp': datetime.now(),
        'cumulative_story': new_cumulative_story,
    }])
    return "Contribution added successfully!", new_cumulative_story
def get_similar_parts(query, num_results=5):
    """Find the stored contributions most similar to `query`.

    Args:
        query: Free-text search string, compared against the embedding index
            on the `content` column.
        num_results: Maximum number of rows to return. It is coerced to an int
            of at least 1, because Gradio sliders may deliver numbers as floats
            and the query's `limit()` expects an integer.

    Returns:
        A pandas DataFrame of content, contributor and similarity score,
        ordered from most to least similar.
    """
    top_k = max(1, int(num_results))
    sim = story_table.content.similarity(query)
    results = (
        story_table.order_by(sim, asc=False)
        .limit(top_k)
        .select(story_table.content, story_table.contributor, sim=sim)
        .collect()
    )
    return results.to_pandas()
def generate_next_part():
    """Return the AI-generated continuation for the latest contribution.

    The text is read from the first choice of the `continuation_response`
    column on the most recent row. Before anything has been contributed, a
    short explanatory message is returned instead of raising IndexError.
    """
    latest = story_table.select(
        continuation=story_table.continuation_response.choices[0].message.content
    ).tail(1)
    if len(latest) == 0:
        return "No contributions yet. Add to the story first to generate a continuation."
    return latest['continuation'][0]
def summarize_story():
    """Return the AI-generated summary of the story as of the latest contribution.

    The text is read from the first choice of the `summary_response` column on
    the most recent row. Before anything has been contributed, a short
    explanatory message is returned instead of raising IndexError.
    """
    latest = story_table.select(
        summary=story_table.summary_response.choices[0].message.content
    ).tail(1)
    if len(latest) == 0:
        return "No contributions yet. Add to the story first to generate a summary."
    return latest['summary'][0]
# Gradio interface
with gr.Blocks(theme=gr.themes.Base()) as demo:
    # Header: Pixeltable logo.
    gr.HTML(
        """
<div style="text-align: left; margin-bottom: 1rem;">
<img src="https://raw.githubusercontent.com/pixeltable/pixeltable/main/docs/source/data/pixeltable-logo-large.png" alt="Pixeltable" style="max-width: 150px;" />
</div>
"""
    )
    gr.Markdown(
        """
# 📚 Collaborative Story Builder
Welcome to the Collaborative Story Builder! This app allows multiple users to contribute to a story,
building it incrementally. Pixeltable manages the data, enables similarity search, and helps generate
continuations and summaries.
"""
    )
    with gr.Tabs():
        # Tab 1: submit a contribution and see the updated story.
        with gr.TabItem("Contribute"):
            with gr.Row():
                with gr.Column(scale=2):
                    contributor = gr.Textbox(label="Your Name")
                    content = gr.Textbox(label="Your Contribution", lines=5)
                    submit_btn = gr.Button("Submit Contribution", variant="primary")
                with gr.Column(scale=3):
                    status = gr.Textbox(label="Status")
                    current_story = gr.Textbox(label="Current Story", lines=10, interactive=False)
        # Tab 2: similarity search over contributions, plus AI continuation.
        with gr.TabItem("Search & Generate"):
            with gr.Row():
                with gr.Column():
                    search_query = gr.Textbox(label="Search Current Contributions")
                    num_results = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Number of Results")
                    search_btn = gr.Button("Search", variant="secondary")
                    search_results = gr.Dataframe(
                        headers=["Content", "Contributor", "Similarity"],
                        label="Similar Parts"
                    )
                with gr.Column():
                    generate_btn = gr.Button("Generate Next Part", variant="primary")
                    generated_part = gr.Textbox(label="Generated Continuation", lines=5)
        # Tab 3: AI summary of the story so far.
        with gr.TabItem("Summary"):
            summarize_btn = gr.Button("Summarize Story", variant="primary")
            summary = gr.Textbox(label="Story Summary", lines=8)

    # Wire the buttons to their handler functions.
    submit_btn.click(add_contribution, inputs=[contributor, content], outputs=[status, current_story])
    search_btn.click(get_similar_parts, inputs=[search_query, num_results], outputs=[search_results])
    generate_btn.click(generate_next_part, outputs=[generated_part])
    summarize_btn.click(summarize_story, outputs=[summary])
    # Show the shared story as soon as the page opens. Otherwise the
    # "Current Story" box stays empty for every new visitor (or after a
    # refresh) until that visitor submits something.
    demo.load(get_current_story, outputs=[current_story])

    # Footer with project links.
    gr.HTML(
        """
<div style="text-align: center; margin-top: 1rem; padding-top: 1rem; border-top: 1px solid #ccc;">
<p style="margin: 0; color: #666; font-size: 0.8em;">
Powered by <a href="https://github.com/pixeltable/pixeltable" target="_blank" style="color: #F25022; text-decoration: none;">Pixeltable</a>
| <a href="https://github.com/pixeltable/pixeltable" target="_blank" style="color: #666; text-decoration: none;">GitHub</a>
</p>
</div>
"""
    )
# Launch the Gradio app only when run as a script, not when imported.
# (A stray "|" left over from the web-viewer copy made this line a SyntaxError.)
if __name__ == "__main__":
    demo.launch()