DongbaDreamer / model /dongda_gpt_helper.py
initialneil's picture
- update space
761be79
raw
history blame contribute delete
No virus
5.45 kB
# https://platform.openai.com/examples/default-emoji-translation
# https://zhuanlan.zhihu.com/p/672725319
import os
import numpy as np
from openai import OpenAI
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
import json
import random
def _to_abs_path(fn, dir):
if not os.path.isabs(fn):
fn = os.path.join(dir, fn)
return fn
def _to_cache_path(dir):
cache_dir = os.path.join(dir, 'cache')
os.makedirs(cache_dir, exist_ok=True)
return cache_dir
def _read_txt(fn):
with open(fn, 'r', encoding='utf-8') as fp:
text = fp.read()
return text
class GPTHelper:
    """OpenAI chat helper plus a FAISS similarity index over a Dongba dictionary.

    ``config`` must support both ``config.get(key, default)`` and attribute
    access; the attributes read here are ``yaml_dir``, ``composition_from``,
    ``image_from``, ``dat_dir`` and ``db_path`` (presumably an EasyDict /
    OmegaConf-style object loaded from YAML -- confirm against the caller).
    """

    # Chat model used by every completion request; override on a subclass or
    # instance to switch models.
    chat_model = "gpt-4o"

    def __init__(self, config) -> None:
        self.setup_config(config)

    def setup_config(self, config):
        """(Re)build the OpenAI clients, the FAISS index and the prompt templates.

        The API key and base URL are taken from ``config`` and fall back to the
        lower-case ``openai_api_key`` / ``openai_api_base`` environment variables.
        """
        self.config = config
        api_key = config.get('openai_api_key', os.getenv('openai_api_key'))
        api_base = config.get('openai_api_base', os.getenv('openai_api_base'))
        self.client = OpenAI(api_key=api_key, base_url=api_base)
        self.embeddings = OpenAIEmbeddings(openai_api_key=api_key, openai_api_base=api_base)
        self.prepare_faiss()
        # Prompt templates; relative paths are resolved against config.yaml_dir.
        self.prompt_composition = _read_txt(_to_abs_path(config.composition_from, config.yaml_dir))
        self.prompt_image = _read_txt(_to_abs_path(config.image_from, config.yaml_dir))

    def prepare_faiss(self):
        """Load ``<dat_dir>/DB1404.json`` and open (or build and cache) the FAISS index.

        Sets ``donbda_dict`` (the parsed JSON object), ``donbda_texts`` (its
        values in insertion order) and ``faiss_db``. If ``config.db_path`` does
        not exist, the index is built from ``donbda_texts`` and saved there.
        """
        json_fn = os.path.join(self.config.dat_dir, 'DB1404.json')
        with open(json_fn, 'r', encoding='utf-8') as fp:
            self.donbda_dict = json.load(fp)
        self.donbda_texts = list(self.donbda_dict.values())
        if os.path.exists(self.config.db_path):
            # SECURITY: load_local unpickles the stored docstore. This is only
            # acceptable because the index is produced by save_local below from
            # our own data -- never point db_path at an untrusted file.
            self.faiss_db = FAISS.load_local(
                self.config.db_path, self.embeddings, allow_dangerous_deserialization=True
            )
        else:
            # Use the same embeddings object that load_local and queries use
            # (previously read self.config.embeddings, which the config does not
            # provide).
            self.faiss_db = FAISS.from_texts(self.donbda_texts, self.embeddings)
            self.faiss_db.save_local(self.config.db_path)

    def log(self, msg):
        """Emit a diagnostic message; this default prints it to stdout.

        ``query_composition`` calls ``self.log``. Subclasses can override this to
        route messages elsewhere.
        """
        print(msg)

    def _chat(self, system_prompt, user_content, max_tokens):
        """Send one system message and one user message; return the reply text.

        All queries share the same sampling settings (temperature=0.8, top_p=1).
        The returned content can be ``None`` (the OpenAI SDK types it as optional).
        """
        response = self.client.chat.completions.create(
            model=self.chat_model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_content},
            ],
            temperature=0.8,
            max_tokens=max_tokens,
            top_p=1,
        )
        return response.choices[0].message.content

    @staticmethod
    def _split_keywords(text):
        """Split a ';'-separated reply into stripped, non-empty keywords.

        The full-width '；' is accepted as a separator too, because models
        replying in Chinese often use it.
        """
        if not text:
            return []
        normalized = text.replace('；', ';')
        return [kw.strip() for kw in normalized.split(';') if kw.strip()]

    def query_keywords(self, image_topic):
        """Ask the model for sketch-drawing keywords describing *image_topic*.

        Returns a list of keyword strings (possibly empty).
        """
        # System prompt (zh): "You will receive a sentence. Based on it, give a
        # few keywords for creating a simple line drawing. Reply with keywords
        # only, nothing else, separated by ;".
        content = self._chat(
            "你会得到一句话,请根据这句话给出几个创作简笔画的关键词,只回答关键词,不要回复其他内容,关键词用;隔开",
            image_topic,
            max_tokens=64,
        )
        return self._split_keywords(content)

    def query_in_faiss_db(self, keyword):
        """Return dictionary entries whose text matches a FAISS hit for *keyword*.

        Each result is ``{'idx': <position in donbda_texts>, 'word': <text>}``.
        When the same text occurs more than once in ``donbda_texts``, one of its
        positions is chosen at random. Hits not present in ``donbda_texts`` are
        skipped.
        """
        words = []
        for doc in self.faiss_db.similarity_search(keyword):
            text = doc.page_content
            # Single pass over the texts instead of list membership plus
            # rebuilding a numpy array for every hit.
            positions = [i for i, t in enumerate(self.donbda_texts) if t == text]
            if positions:
                words.append({
                    'idx': random.choice(positions),
                    'word': text,
                })
        return words

    def query_composition(self, image_topic, keywords_to_query, canvas_width, canvas_height,
                          num_words=0, cell_size=180):
        """Ask the model to lay out *keywords_to_query* on a canvas for *image_topic*.

        Fills the ``%width%``, ``%height%``, ``%num_cols%``, ``%num_rows%`` and
        ``%usage%`` placeholders of the composition template. The grid has
        ``ceil(canvas_width / cell_size)`` columns and
        ``ceil(canvas_height / cell_size)`` rows; *cell_size* is in the same
        units as the canvas dimensions (default 180).

        If ``num_words > 0``, keywords may repeat and ``num_words`` items are
        requested in total; otherwise each keyword is used exactly once.
        Returns the raw reply text.
        """
        num_cols = int(np.ceil(canvas_width / cell_size))
        num_rows = int(np.ceil(canvas_height / cell_size))
        if num_words > 0:
            # "Use these keywords, each may be used multiple times, {num_words}
            # items in total; answer ordered from far to near."
            usage = f'使用这些关键词,每个关键词可使用多次,总共应该出现 {num_words} 个,回答时按照从远到近的顺序。'
        else:
            # "Use only these keywords, each exactly once; answer ordered from far to near."
            usage = '仅使用这些关键词,且每个关键词使用一次,回答时按照从远到近的顺序。'
        replacements = (
            ('%width%', f'{canvas_width}'),
            ('%height%', f'{canvas_height}'),
            ('%num_cols%', f'{num_cols}'),
            ('%num_rows%', f'{num_rows}'),
            ('%usage%', usage),
        )
        prompt = self.prompt_composition
        for placeholder, value in replacements:
            prompt = prompt.replace(placeholder, value)
        self.log(prompt)
        # User message (zh): 'The topic is: "...", the keywords are: "..."'.
        return self._chat(
            prompt,
            f'主题是:"{image_topic}",关键词是:"{keywords_to_query}"',
            max_tokens=4096,
        )

    def query_image_prompt(self, image_topic):
        """Ask the model for an image prompt for *image_topic*; return the reply text."""
        # User message (zh): 'The topic is: "..."'.
        return self._chat(self.prompt_image, f'主题是:"{image_topic}"', max_tokens=4096)