Ashhar committed · Commit a94a1dc
1 Parent(s): f97efd9

changes to context window + image prompt
app.py CHANGED
@@ -12,43 +12,51 @@ from gradio_client import Client
 from dotenv import load_dotenv
 load_dotenv()
 
+useGpt4 = os.environ.get("USE_GPT_4") == "1"
+
+if useGpt4:
+    from openai import OpenAI
+    client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
+    MODEL = "gpt-4o-mini"
+    MAX_CONTEXT = 128000
+    tokenizer = AutoTokenizer.from_pretrained("Xenova/gpt-4o")
+else:
+    from groq import Groq
+    client = Groq(
+        api_key=os.environ.get("GROQ_API_KEY"),
+    )
+    MODEL = "llama-3.1-70b-versatile"
+    MAX_CONTEXT = 8000
+    tokenizer = AutoTokenizer.from_pretrained("Xenova/Meta-Llama-3.1-Tokenizer")
 
-from groq import Groq
-client = Groq(
-    api_key=os.environ.get("GROQ_API_KEY"),
-)
 
-MODEL = "llama-3.1-70b-versatile"
 JSON_SEPARATOR = ">>>>"
 
 
-tokenizer = AutoTokenizer.from_pretrained("Xenova/Meta-Llama-3.1-Tokenizer")
-
-
 def countTokens(text):
-
+    text = str(text)
     tokens = tokenizer.encode(text, add_special_tokens=False)
-    # Return the number of tokens
     return len(tokens)
 
 
 SYSTEM_MSG = f"""
 You're an storytelling assistant who guides users through four phases of narrative development, helping them craft compelling personal or professional stories. The story created should be in simple language, yet evoke great emotions.
-Ask one question at a time, give the options in a well formatted manner in different lines
+Ask one question at a time, give the options in a numbered and well formatted manner in different lines
 If your response has number of options to choose from, only then append your final response with this exact keyword "{JSON_SEPARATOR}", and only after this, append with the JSON of options to choose from. The JSON should be of the format:
 {{
 "options": [
 {{ "id": "1", "label": "Option 1"}},
-{{ "id": "2", "label": "Option 2"}}
+{{ "id": "2", "label": "Option 2"}}
 ]
 }}
 Do not write "Choose one of the options below:"
-Keep options to less than 9
+Keep options to less than 9.
+Summarise options chosen so far in each step.
 
 # Tier 1: Story Creation
 You initiate the storytelling process through a series of engaging prompts:
 Story Origin:
-Asks users to choose between personal anecdotes or adapting a well-known story (creating a story database here of well-known
+Asks users to choose between personal anecdotes or adapting a well-known story (creating a story database here of well-known stories to choose from).
 
 Story Use Case:
 Asks users to define the purpose of building a story (e.g., profile story, for social media content).
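For context on the options protocol in SYSTEM_MSG above: the model appends the literal ">>>>" separator plus an options JSON only when it actually offers choices, and the app later splits on that separator (the `[response, jsonStr] = responseParts` unpacking further down). A minimal parsing sketch, with the hypothetical helper name splitOptions:

import json

JSON_SEPARATOR = ">>>>"

def splitOptions(llmOutput: str):
    # Separate the visible reply from the trailing options JSON, if present.
    parts = llmOutput.split(JSON_SEPARATOR)
    if len(parts) == 2:
        reply, jsonStr = parts
        return reply.strip(), json.loads(jsonStr)["options"]
    return llmOutput, []

reply, options = splitOptions(
    'Where should your story start?\n1. A personal anecdote\n2. A well-known story\n'
    '>>>>{"options": [{"id": "1", "label": "Personal anecdote"},'
    ' {"id": "2", "label": "Well-known story"}]}'
)
# options -> [{'id': '1', 'label': 'Personal anecdote'}, {'id': '2', 'label': 'Well-known story'}]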
@@ -146,6 +154,37 @@ def pprint(log: str):
 
 pprint("\n")
 
+st.markdown(
+    """
+    <style>
+    @keyframes blinker {
+        0% {
+            opacity: 1;
+        }
+        50% {
+            opacity: 0.2;
+        }
+        100% {
+            opacity: 1;
+        }
+    }
+
+    .blinking {
+        animation: blinker 3s ease-out infinite;
+    }
+
+    .code {
+        color: green;
+        border-radius: 3px;
+        padding: 2px 4px; /* Padding around the text */
+        font-family: 'Courier New', Courier, monospace; /* Monospace font */
+    }
+
+    </style>
+    """,
+    unsafe_allow_html=True
+)
+
 
 def __isInvalidResponse(response: str):
     # new line followed by small case char
@@ -161,7 +200,7 @@ def __isInvalidResponse(response: str):
         return True
 
     # json response without json separator
-    if ('
+    if ('{\n "options"' in response) and (JSON_SEPARATOR not in response):
        return True
 
 
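Concretely, the new check above flags a reply that embeds the options JSON but omits the separator, for example:

bad = 'Pick one:\n{\n "options": [{"id": "1", "label": "Personal anecdote"}]\n}'
# '{\n "options"' occurs in the text but ">>>>" does not,
# so the new check makes __isInvalidResponse(bad) return True.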
@@ -180,23 +219,60 @@ def __isStringNumber(s: str) -> bool:
         return False
 
 
-def 
-
+def __getImagePromptDetails(prompt: str, response: str):
+    regex = r'[^a-z0-9 \n\.\-]|((the) +)'
+
+    cleanedResponse = re.sub(regex, '', response.lower())
+    pprint(f"{cleanedResponse=}")
+
+    cleanedPrompt = re.sub(regex, '', prompt.lower())
+    pprint(f"{cleanedPrompt=}")
+
     if (
         __matchingKeywordsCount(
             ["adapt", "profile", "social media", "purpose", "use case"],
-            
+            cleanedResponse
         ) > 2
         and not __isStringNumber(prompt)
-        and 
+        and cleanedPrompt in cleanedResponse
+        and "story so far" not in cleanedResponse
     ):
-        return 
+        return (
+            f'''
+            Subject: {prompt}.
+            Style: Fantastical, in a storybook, surreal, bokeh
+            ''',
+            "Painting your character ..."
+        )
+
+    '''
+    Mood: ethereal lighting that emphasizes the fantastical nature of the scene.
+
+    storybook style
+
+    4d model, unreal engine
+
+    Alejandro Bursido
+
+    vintage, nostalgic
+
+    Dreamlike, Mystical, Fantastical, Charming
+    '''
 
     if __matchingKeywordsCount(
-        ["
-        
+        ["tier 2", "tier-2"],
+        cleanedResponse
     ) > 0:
-        
+        possibleStoryEndIdx = [response.find("tier 2"), response.find("tier-2")]
+        storyEndIdx = max(possibleStoryEndIdx)
+        relevantResponse = response[:storyEndIdx]
+        pprint(f"{relevantResponse=}")
+        return (
+            f"photo of a scene from this text: {relevantResponse}",
+            "Imagining your scene (beta) ..."
+        )
+
+    return (None, None)
 
 
 def __resetButtonState():
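As an illustration of the cleaning step in __getImagePromptDetails above (not part of the diff): the regex works on lowercased text, keeps only letters, digits, spaces, newlines, dots and hyphens, and drops occurrences of "the " before keyword matching:

import re

regex = r'[^a-z0-9 \n\.\-]|((the) +)'

print(re.sub(regex, '', "The Hero's Journey (Tier 2)!".lower()))
# -> heros journey tier 2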
@@ -210,6 +286,9 @@ def __setStartMsg(msg):
 if "messages" not in st.session_state:
     st.session_state.messages = []
 
+if "history" not in st.session_state:
+    st.session_state.history = []
+
 if "buttonValue" not in st.session_state:
     __resetButtonState()
 
@@ -217,19 +296,33 @@ if "startMsg" not in st.session_state:
     st.session_state.startMsg = ""
 
 
+def __getChatMessages(prompt: str):
+    st.session_state.history.append({
+        "role": "user",
+        "content": prompt
+    })
+
+    def getContextSize():
+        currContextSize = countTokens(SYSTEM_MSG) + countTokens(st.session_state.history) + 100
+        pprint(f"{currContextSize=}")
+        return currContextSize
+
+    while getContextSize() > MAX_CONTEXT:
+        pprint("Context size exceeded, removing first message")
+        st.session_state.history.pop(0)
+
+    return st.session_state.history
+
+
 def predict(prompt):
-
-
-
-
-    ])
-    historyFormatted.append({"role": "user", "content": prompt })
-    contextSize = countTokens(str(historyFormatted))
-    pprint(f"{contextSize=}")
+    messagesFormatted = [{"role": "system", "content": SYSTEM_MSG}]
+    messagesFormatted.extend(__getChatMessages(prompt))
+    contextSize = countTokens(messagesFormatted)
+    pprint(f"{contextSize=} | {MODEL}")
 
     response = client.chat.completions.create(
-        model=
-        messages=
+        model=MODEL,
+        messages=messagesFormatted,
        temperature=0.8,
        max_tokens=4000,
        stream=True
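In short, the context-window handling added above keeps appending user turns to st.session_state.history and evicts the oldest entries until the system prompt plus history (plus roughly 100 tokens of headroom) fits under MAX_CONTEXT. A standalone sketch of the same trimming loop, with hypothetical names:

def trimToContext(history, systemMsg, maxContext, countTokens):
    # Drop the oldest turns until the running token estimate fits the window.
    while countTokens(systemMsg) + countTokens(history) + 100 > maxContext:
        history.pop(0)
    return history

# e.g. trimToContext(st.session_state.history, SYSTEM_MSG, MAX_CONTEXT, countTokens)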
@@ -245,13 +338,13 @@ def predict(prompt):
 
 def generateImage(prompt: str):
     pprint(f"imagePrompt={prompt}")
-
-    result = 
+    fluxClient = Client("black-forest-labs/FLUX.1-schnell")
+    result = fluxClient.predict(
        prompt=prompt,
        seed=0,
        randomize_seed=True,
-        width=
-        height=
+        width=1024,
+        height=768,
        num_inference_steps=4,
        api_name="/infer"
     )
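For reference, the image call above can be exercised on its own roughly as below, assuming the black-forest-labs/FLUX.1-schnell Space keeps this /infer signature; the app unpacks the result as (imagePath, seed):

from gradio_client import Client

fluxClient = Client("black-forest-labs/FLUX.1-schnell")
imagePath, seed = fluxClient.predict(
    prompt="photo of a scene from a storybook forest at dusk",
    seed=0,
    randomize_seed=True,
    width=1024,
    height=768,
    num_inference_steps=4,
    api_name="/infer",
)
# imagePath is a local path to the generated image; seed is the seed actually used.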
@@ -321,14 +414,26 @@ if prompt := (st.chat_input() or st.session_state["buttonValue"] or st.session_s
         [response, jsonStr] = responseParts
 
     imagePath = None
+    imageContainer = st.empty()
     try:
-        imagePrompt = 
+        (imagePrompt, loaderText) = __getImagePromptDetails(prompt, response)
         if imagePrompt:
-            
+            imgContainer = imageContainer.container()
+            imgContainer.write(
+                f"""
+                <div class='blinking code'>
+                    {loaderText}
+                </div>
+                """,
+                unsafe_allow_html=True
+            )
+            # imgContainer.markdown(f"`{loaderText}`")
+            imgContainer.image(IMAGE_LOADER)
             (imagePath, seed) = generateImage(imagePrompt)
             imageContainer.image(imagePath)
     except Exception as e:
         pprint(e)
+        imageContainer.empty()
 
     if jsonStr:
         try: