Chain-of-Image / app.py
fxmeng's picture
Update app.py
d25ad55 verified
raw
history blame
3.7 kB
import gradio as gr
import time
import base64
from openai import OpenAI
def wait_on_run(run, client, thread, poll_interval=0.5, timeout=None):
    """Poll an Assistants API run until it leaves every non-terminal status.

    Args:
        run: The run object returned by ``client.beta.threads.runs.create``
            (or a previous ``retrieve``); only ``.status`` and ``.id`` are read.
        client: An ``openai.OpenAI`` client.
        thread: The thread the run belongs to; only ``.id`` is read.
        poll_interval: Seconds to sleep between status polls.
        timeout: Maximum seconds to keep polling, or ``None`` to wait forever
            (the original behaviour).

    Returns:
        The most recently retrieved run object, whose status is terminal
        (e.g. ``completed``, ``failed``, ``cancelled``, ``expired``).

    Raises:
        TimeoutError: If ``timeout`` elapses while the run is still pending.
    """
    # "cancelling" is a transitional state just like "queued"/"in_progress";
    # returning on it would hand callers a run that is not finished yet.
    pending_statuses = ("queued", "in_progress", "cancelling")
    deadline = None if timeout is None else time.monotonic() + timeout
    while run.status in pending_statuses:
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"Run {run.id} still {run.status!r} after {timeout} seconds."
            )
        # Sleep before polling so no pointless sleep follows a terminal status.
        time.sleep(poll_interval)
        run = client.beta.threads.runs.retrieve(
            thread_id=thread.id,
            run_id=run.id,
        )
    return run
def GenerateImageByCode(client, message, code_prompt):
    """Ask a code-interpreter assistant to draw ``message`` and fetch the image.

    A new assistant (instructed with ``code_prompt``) and a new thread are
    created for every call. The user ``message`` is posted, the run is polled
    with ``wait_on_run``, and the run steps are scanned for images produced by
    the code interpreter.

    Args:
        client: An ``openai.OpenAI`` client.
        message: The user's question, sent as the thread's user message.
        code_prompt: Instructions given to the assistant.

    Returns:
        A ``(path, base64_image)`` tuple: the relative path ``<file_id>.png``
        the image bytes were written to, and those same bytes base64-encoded
        as a str.

    Raises:
        RuntimeError: If no code-interpreter step produced an image.
    """
    # NOTE(review): the assistant and thread created here are never deleted,
    # so every call leaves them behind on the account — consider cleanup.
    assistant = client.beta.assistants.create(
        name="Chain of Image",
        instructions=code_prompt,
        model="gpt-4-1106-preview",
        tools=[{"type": "code_interpreter"}],
    )
    thread = client.beta.threads.create()
    client.beta.threads.messages.create(
        thread_id=thread.id,
        role="user",
        content=message,
    )
    run = client.beta.threads.runs.create(
        thread_id=thread.id,
        assistant_id=assistant.id,
    )
    run = wait_on_run(run, client, thread)

    steps = client.beta.threads.runs.steps.list(
        thread_id=thread.id, run_id=run.id, order="asc"
    ).model_dump()["data"]

    image_id = None
    for step in steps:
        # Message-creation steps have no "tool_calls" entry in step_details.
        for tool_call in step["step_details"].get("tool_calls") or []:
            interpreter = tool_call.get("code_interpreter") or {}
            # A tool call can emit logs before its image (or no outputs at
            # all), so every output is checked rather than only the first.
            for output in interpreter.get("outputs") or []:
                image = output.get("image")
                if image and image.get("file_id"):
                    # Steps are listed ascending, so the last image wins.
                    image_id = image["file_id"]

    if not image_id:
        raise RuntimeError(
            f"The code interpreter produced no image (run status: {run.status!r})."
        )

    image_bytes = client.files.with_raw_response.content(image_id).content
    image_path = f"{image_id}.png"
    with open(image_path, "wb") as f:
        f.write(image_bytes)
    base64_image = base64.b64encode(image_bytes).decode("utf-8")
    return image_path, base64_image
def visual_question_answer(client, base64_image, question, vqa_prompt, max_tokens=256, mime_type="image/png"):
    """Answer ``question`` about a base64-encoded image with a vision chat model.

    Args:
        client: An ``openai.OpenAI`` client.
        base64_image: The image bytes, base64-encoded as a str.
        question: The user's question about the image.
        vqa_prompt: System prompt for the model.
        max_tokens: Upper bound on tokens in the answer.
        mime_type: MIME type declared in the data URL. It defaults to PNG
            because the images this app generates are saved as ``.png``.

    Returns:
        The text content of the first completion choice.
    """
    image_url = f"data:{mime_type};base64,{base64_image}"
    response = client.chat.completions.create(
        model="gpt-4-vision-preview",
        messages=[
            {"role": "system", "content": vqa_prompt},
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_url}},
                    {"type": "text", "text": f"Question:\n{question}\nAnswer:\n"},
                ],
            },
        ],
        max_tokens=max_tokens,
    )
    return response.choices[0].message.content
def chain_of_images(message, history, code_prompt, vqa_prompt, api_token, max_tokens):
    """Gradio ChatInterface callback: draw on the first turn, answer afterwards.

    On the first turn (empty ``history``) the question is turned into an
    image via ``GenerateImageByCode``. On later turns, the base64 image from
    the first bot reply is reused to answer the new question with
    ``visual_question_answer``.

    Args:
        message: The current user message.
        history: Chat history supplied by Gradio as ``[user, bot]`` pairs.
        code_prompt: Instructions for the code-interpreter assistant.
        vqa_prompt: System prompt for the visual question answering model.
        api_token: OpenAI API key typed by the user; may be empty.
        max_tokens: Answer length limit from the slider.

    Returns:
        On the first turn, a ``(path, base64_image)`` tuple. Otherwise, the
        answer text.
    """
    # An empty textbox would otherwise be sent as api_key="", which can only
    # fail. None lets the OpenAI client fall back to its OPENAI_API_KEY
    # environment variable.
    client = OpenAI(api_key=api_token or None)
    if not history:
        return GenerateImageByCode(client, message, code_prompt)
    # The first bot reply was the (path, base64_image) tuple. This assumes
    # Gradio hands it back as a (path, alt_text) pair — TODO confirm for the
    # installed Gradio version.
    base64_image = history[0][1][1]
    return visual_question_answer(
        client, base64_image, message, vqa_prompt, max_tokens=int(max_tokens)
    )
def vote(data: gr.LikeData):
    """Print a like/dislike event for a chatbot response to stdout.

    Args:
        data: Gradio like event; ``.liked`` selects the verb, and ``.value``
            is the rated message.
    """
    verdict = "upvoted" if data.liked else "downvoted"
    # Format rather than concatenate: ``data.value`` is presumably not always
    # a str (e.g. this app's image-tuple replies) — TODO confirm for the
    # installed Gradio version. Output is unchanged for str values.
    print(f"You {verdict} this response: {data.value}")
# Chat UI: the extra inputs are passed to chain_of_images in this order after
# (message, history): code prompt, VQA prompt, API key, max tokens.
demo = gr.ChatInterface(
    chain_of_images,
    additional_inputs=[
        gr.Textbox("You are a research drawing assistant. Your primary role is to help visualize questions posed by users. Instead of directly answering questions, you will use code to invoke the most suitable toolkit, transforming these questions into images. This helps users quickly understand the question and find answers through visualization. You should prioritize clarity and effectiveness in your visual representations, ensuring that complex scientific or technical concepts are made accessible and comprehensible through your drawings.", label="Code Interpreter Prompt"),
        gr.Textbox("You are a visual thinking expert. Your primary role is to answer questions about an image posed by users.", label="VQA Prompt"),
        # Mask the secret key instead of echoing it on screen.
        gr.Textbox(label="API Key", type="password"),
        gr.Slider(32, 128, label="Max Tokens"),
    ],
).queue()

if __name__ == "__main__":
    demo.launch()