# Spaces:
# Runtime error
# Runtime error
import gradio as gr | |
import openai | |
import requests | |
import os | |
from dotenv import load_dotenv | |
import io | |
import sys | |
import json | |
import PIL | |
import time | |
from stability_sdk import client | |
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation | |
import markdown2 | |
# --- UI text shown by the Gradio interface ----------------------------------
title: str = "najimino AI recipe generator"
inputs_label: str = "どんな料理か教えてくれれば,新しいレシピを考えます"
outputs_label: str = "najimino AIが返信をします"
visual_outputs_label: str = "料理のイメージ"
description: str = "\n- ※入出力の文字数は最大1000文字程度までを目安に入力してください。回答に50秒くらいかかります.\n"
article: str = "\n"

# --- Credentials and API clients ---------------------------------------------
# Load variables from a .env file into the environment, then read the keys.
load_dotenv()
openai.api_key = os.getenv('OPENAI_API_KEY')
os.environ['STABILITY_HOST'] = 'grpc.stability.ai:443'
stability_api = client.StabilityInference(
    key=os.getenv('STABILITY_KEY'),
    verbose=True,
    # Engines used previously: stable-diffusion-512-v2-1,
    # stable-diffusion-xl-beta-v2-2-2, stable-diffusion-xl-1024-v0-9.
    # Others listed as available: stable-diffusion-v1, stable-diffusion-v1-5,
    # stable-diffusion-512-v2-0, stable-diffusion-768-v2-0,
    # stable-diffusion-768-v2-1, stable-inpainting-v1-0,
    # stable-inpainting-512-v2-0.
    engine="stable-diffusion-xl-1024-v1-0",
)

# Chat model used for recipe generation.
# Alternatives tried: "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0613".
MODEL = "gpt-4"
# Process-wide memo used when a caller does not supply its own cache dict.
_FILETEXT_CACHE = {}


def get_filetext(filename, cache=None):
    """Return the text of *filename*, memoizing the result.

    The first read of a file is stored in *cache* (a dict keyed by the
    filename exactly as passed). Later calls with the same key return the
    stored text without touching the disk, so changes made to the file after
    the first read are not seen.

    Args:
        filename: Path of the file to read; it is decoded as UTF-8.
        cache: Optional dict to memoize into. When omitted, a module-level
            dict shared by all calls is used.

    Returns:
        The file contents as a ``str``.

    Raises:
        ValueError: If the file does not exist.
    """
    if cache is None:
        cache = _FILETEXT_CACHE
    if filename in cache:
        # Served from the cache; the file is not re-read.
        return cache[filename]
    try:
        # Explicit UTF-8: prompt files may contain non-ASCII text and the
        # platform default encoding is not guaranteed to be UTF-8.
        with open(filename, "r", encoding="utf-8") as f:
            text = f.read()
    except FileNotFoundError as err:
        raise ValueError(f"ファイル '{filename}' が見つかりませんでした") from err
    cache[filename] = text
    return text
def get_functions_from_schema(filename):
    """Load the function-calling definitions from a JSON schema file.

    The file is read through ``get_filetext`` (so its text is cached after the
    first call), parsed as JSON, and the value under the top-level
    ``"functions"`` key is returned, or ``None`` when that key is absent.
    """
    return json.loads(get_filetext(filename)).get("functions")
class StabilityAI:
    """Thin wrapper around the module-level ``stability_api`` client."""

    @classmethod
    def generate_image(cls, visualize_prompt):
        """Generate an image for *visualize_prompt* with ``stability_api``.

        Declared as a classmethod because callers invoke it on the class
        itself (``StabilityAI.generate_image(prompt)``).

        Args:
            visualize_prompt: Text prompt sent to the image engine.

        Returns:
            The first image artifact in the response, decoded with Pillow, or
            ``None`` when the response contains no image artifact.
        """
        # ``import PIL`` alone does not guarantee the ``PIL.Image`` submodule
        # is loaded, so import it explicitly before use.
        from PIL import Image

        print("visualize_prompt:" + visualize_prompt)
        answers = stability_api.generate(prompt=visualize_prompt)
        for resp in answers:
            for artifact in resp.artifacts:
                if artifact.finish_reason == generation.FILTER:
                    # Filtered artifact: only logged, processing continues.
                    print("NSFW")
                if artifact.type == generation.ARTIFACT_IMAGE:
                    return Image.open(io.BytesIO(artifact.binary))
        return None
class OpenAI:
    """Helpers that call the OpenAI Chat Completions API."""

    @classmethod
    def chat_completion(cls, prompt, start_with=""):
        """Ask the model to fill the recipe template for *prompt*.

        Sends ``constraints.md`` and ``template.md`` as system messages, a
        canned assistant reply, *prompt* as the user turn, and a trailing
        assistant turn seeded with *start_with*. The request is a raw HTTP
        POST to the ``/v1/chat/completions`` endpoint.

        Args:
            prompt: The user's message.
            start_with: Text placed in the trailing assistant message.

        Returns:
            The stripped text content of the first choice.

        Raises:
            requests.HTTPError: If the API responds with an error status.
        """
        constraints = get_filetext(filename="constraints.md")
        template = get_filetext(filename="template.md")
        data = {
            "model": MODEL,
            "messages": [
                {"role": "system", "content": constraints},
                {"role": "system", "content": template},
                {"role": "assistant", "content": "Sure!"},
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": start_with},
            ],
        }
        # Measure how long text generation takes.
        start = time.time()
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {openai.api_key}",
            },
            json=data,
        )
        print("gpt generation time: " + str(time.time() - start))
        result = response.json()
        print(result)
        # Surface API failures (bad key, rate limit, ...) as an HTTPError
        # instead of a KeyError on "choices" below.
        response.raise_for_status()
        content = result["choices"][0]["message"]["content"].strip()
        # The visual prompt follows this heading; fall back to a neutral prompt
        # instead of raising IndexError when the model omitted the section.
        parts = content.split("### Prompt for Visual Expression\n\n")
        visualize_prompt = parts[1] if len(parts) > 1 else "vacant dish"
        print("visualize_prompt:" + visualize_prompt)
        return content

    @classmethod
    def chat_completion_with_function(cls, prompt, messages, functions):
        """Call ChatCompletion, forcing a ``format_recipe`` function call.

        Declared as a classmethod because callers invoke it on the class
        itself with keyword arguments.

        Args:
            prompt: The user's text; only logged here (the request content
                comes from *messages*).
            messages: Chat messages passed to the API unchanged.
            functions: Function definitions passed to the API unchanged.

        Returns:
            The ``message`` of the first choice in the response.
        """
        print("prompt:" + prompt)
        # Measure how long text generation takes.
        start = time.time()
        response = openai.ChatCompletion.create(
            model=MODEL,
            messages=messages,
            functions=functions,
            function_call={"name": "format_recipe"},
        )
        print("gpt generation time: " + str(time.time() - start))
        message = response.choices[0].message
        # ensure_ascii=False keeps non-ASCII (e.g. Japanese) text readable.
        print(
            "chat completion message: "
            + json.dumps(message, indent=2, ensure_ascii=False)
        )
        return message
class NajiminoAI:
    """Turns a free-text dish request into a recipe image plus HTML."""

    def __init__(self, user_message):
        # The user's free-text request (ingredients, a dish idea, ...).
        self.user_message = user_message

    def generate_recipe_prompt(self):
        """Return a prompt asking the model to fill ``template.md`` for the request."""
        template = get_filetext(filename="template.md")
        prompt = f"""
{self.user_message}
---
上記を元に、下記テンプレートを埋めてください。
---
{template}
"""
        return prompt

    def format_recipe(self, lang, title, description, ingredients, instruction, comment_feelings_taste, explanation_to_blind_person, prompt_for_visual_expression):
        """Fill the ``str.format`` placeholders of ``template.md`` with recipe fields.

        The filled text is printed for debugging and returned; the caller
        renders it as Markdown.
        """
        template = get_filetext(filename="template.md")
        debug_message = template.format(
            lang=lang,
            title=title,
            description=description,
            ingredients=ingredients,
            instruction=instruction,
            comment_feelings_taste=comment_feelings_taste,
            explanation_to_blind_person=explanation_to_blind_person,
            prompt_for_visual_expression=prompt_for_visual_expression,
        )
        print("debug_message: " + debug_message)
        return debug_message

    @classmethod
    def generate(cls, user_message):
        """Build a recipe for *user_message* and return ``[image, html]``.

        Declared as a classmethod because it is called on the class itself
        (e.g. as the Gradio callback).
        """
        return cls(user_message).generate_recipe()

    def generate_recipe(self):
        """Generate the recipe via a forced function call, plus its image.

        Returns:
            ``[image, html]``. Both are ``None`` when the model's reply has no
            ``function_call``. Otherwise *image* is whatever
            ``StabilityAI.generate_image`` returned, and *html* is the filled
            template rendered from Markdown inside a scrollable ``<div>``.
        """
        user_message = self.user_message
        constraints = get_filetext(filename="constraints.md")
        messages = [
            {"role": "system", "content": constraints},
            {"role": "user", "content": user_message},
        ]
        functions = get_functions_from_schema("schema.json")
        message = OpenAI.chat_completion_with_function(
            prompt=user_message, messages=messages, functions=functions
        )
        image = None
        html = None
        function_call = message.get("function_call")
        if function_call:
            args = json.loads(function_call["arguments"])
            # English visual prompt from the model (empty if missing) extended
            # with fixed photographic style keywords for the image model.
            prompt_for_visual_expression = (
                (args.get("prompt_for_visual_expression_in_en") or "")
                + " delicious looking extremely detailed photo f1.2 (50mm|85mm) award winner depth of field bokeh perfect lighting "
            )
            print("prompt_for_visual_expression: " + prompt_for_visual_expression)
            # Measure how long image generation takes.
            start = time.time()
            image = StabilityAI.generate_image(prompt_for_visual_expression)
            print("image generation time: " + str(time.time() - start))
            function_response = self.format_recipe(
                lang=args.get("lang"),
                title=args.get("title"),
                description=args.get("description"),
                ingredients=args.get("ingredients"),
                instruction=args.get("instruction"),
                comment_feelings_taste=args.get("comment_feelings_taste"),
                explanation_to_blind_person=args.get("explanation_to_blind_person"),
                prompt_for_visual_expression=prompt_for_visual_expression,
            )
            # NOTE(review): model output is converted without markdown2's
            # safe_mode, so any raw HTML in it is passed through unescaped.
            html = (
                "<div style='max-width:100%; overflow:auto'>"
                + markdown2.markdown(function_response)
                + "</div>"
            )
        return [image, html]
def main():
    """Build and launch the Gradio web UI for the recipe generator.

    The text box input is passed to ``NajiminoAI.generate``; its two return
    values fill an image output and an HTML output. Example inputs are
    handled natively by ``gr.Interface``.
    """
    iface = gr.Interface(
        fn=NajiminoAI.generate,
        examples=[
            ["ラー麺 スイカ かき氷 八ツ橋"],
            ["お好み焼き 鯖"],
            ["茹でたアスパラガスに合う季節のソース"],
        ],
        inputs=gr.Textbox(label=inputs_label),
        outputs=[
            gr.Image(label="Visual Expression"),
            "html",
        ],
        title=title,
        description=description,
        article=article,
    )
    iface.launch()
if __name__ == '__main__':
    # Command-line modes: "generate" runs one recipe generation,
    # "generate_image" renders one sample image to image.png, and no
    # argument (or anything else) launches the Gradio UI.
    from PIL import Image

    function = sys.argv[1] if len(sys.argv) > 1 else ''
    if function == 'generate':
        NajiminoAI.generate("グルテンフリーの香ばしいサバのお好み焼き")
    elif function == 'generate_image':
        image = StabilityAI.generate_image("Imagine a delicious gluten-free okonomiyaki with mackerel. The okonomiyaki is crispy on the outside and chewy on the inside. It is topped with savory sauce and creamy mayonnaise, creating a mouthwatering visual. The dish is garnished with finely chopped green onions and red pickled ginger, adding a pop of color. The mackerel fillets are beautifully grilled and placed on top of the okonomiyaki, adding a touch of elegance. The dish is served on a traditional Japanese plate, completing the visual presentation.")
        # str() is required: concatenating a PIL image to a str raises TypeError.
        print("image: " + str(image))
        # Save only when an image actually came back (it may be None).
        if isinstance(image, Image.Image):
            image.save("image.png")
    else:
        main()