Spaces:
Sleeping
Sleeping
File size: 5,452 Bytes
761be79 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
# https://platform.openai.com/examples/default-emoji-translation
# https://zhuanlan.zhihu.com/p/672725319
import os
import numpy as np
from openai import OpenAI
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
import json
import random
def _to_abs_path(fn, dir):
if not os.path.isabs(fn):
fn = os.path.join(dir, fn)
return fn
def _to_cache_path(dir):
cache_dir = os.path.join(dir, 'cache')
os.makedirs(cache_dir, exist_ok=True)
return cache_dir
def _read_txt(fn):
with open(fn, 'r', encoding='utf-8') as fp:
text = fp.read()
return text
class GPTHelper:
    """Helper around the OpenAI chat API plus a FAISS index of drawing keywords.

    Responsibilities:
      * extract sketch keywords for a topic (``query_keywords``),
      * map free-form keywords to known vocabulary entries via FAISS
        similarity search (``query_in_faiss_db``),
      * ask the model for an image composition / image prompt using
        prompt templates loaded from disk.

    ``config`` is expected to be a dict-like object supporting both
    ``.get(key, default)`` and attribute access (e.g. an EasyDict /
    OmegaConf node) with: yaml_dir, dat_dir, db_path, composition_from,
    image_from — TODO confirm against the caller.
    """

    def __init__(self, config) -> None:
        self.setup_config(config)

    def log(self, msg):
        """Minimal logging hook; previously referenced but never defined,
        which made query_composition raise AttributeError at runtime."""
        print(msg)

    def setup_config(self, config):
        """Build the OpenAI client/embeddings, load the FAISS DB and prompt templates.

        API credentials come from ``config`` with environment variables
        ``openai_api_key`` / ``openai_api_base`` as fallback.
        """
        self.config = config
        api_key = config.get('openai_api_key', os.getenv('openai_api_key'))
        api_base = config.get('openai_api_base', os.getenv('openai_api_base'))
        self.client = OpenAI(api_key=api_key, base_url=api_base)
        self.embeddings = OpenAIEmbeddings(
            openai_api_key=api_key,
            openai_api_base=api_base,
        )
        self.prepare_faiss()
        # Prompt templates are resolved relative to the YAML config directory.
        self.prompt_composition = _read_txt(
            _to_abs_path(self.config.composition_from, config.yaml_dir))
        self.prompt_image = _read_txt(
            _to_abs_path(self.config.image_from, config.yaml_dir))

    def prepare_faiss(self):
        """Load the keyword vocabulary and build (or load) the FAISS vector DB."""
        json_fn = os.path.join(self.config.dat_dir, 'DB1404.json')
        with open(json_fn, 'r', encoding='utf-8') as fp:
            cc = json.load(fp)
        self.donbda_dict = cc
        self.donbda_texts = list(cc.values())
        if os.path.exists(self.config.db_path):
            # allow_dangerous_deserialization trusts the on-disk pickle;
            # only safe because the index is built locally by this class.
            self.faiss_db = FAISS.load_local(
                self.config.db_path, self.embeddings,
                allow_dangerous_deserialization=True)
        else:
            # BUG FIX: originally passed self.config.embeddings, but the
            # embeddings object is constructed on self in setup_config.
            self.faiss_db = FAISS.from_texts(self.donbda_texts, self.embeddings)
            self.faiss_db.save_local(self.config.db_path)

    def query_keywords(self, image_topic):
        """Ask the model for sketch keywords describing ``image_topic``.

        Returns a list of keyword strings.
        """
        response = self.client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "system",
                    "content": "你会得到一句话,请根据这句话给出几个创作简笔画的关键词,只回答关键词,不要回复其他内容,关键词用;隔开"
                },
                {
                    "role": "user",
                    "content": image_topic,
                }
            ],
            temperature=0.8,
            max_tokens=64,
            top_p=1
        )
        content = response.choices[0].message.content
        # The model may answer with fullwidth ';' despite the prompt;
        # normalize to ASCII ';' before splitting.
        return content.replace(';', ';').split(';')

    def query_in_faiss_db(self, keyword):
        """Find vocabulary entries similar to ``keyword`` in the FAISS DB.

        Returns a list of ``{'idx': int, 'word': str}`` dicts; when a word
        occurs multiple times in the vocabulary one occurrence is picked
        at random.
        """
        query_results = self.faiss_db.similarity_search(keyword)
        texts = np.array(self.donbda_texts)
        words = []
        for rlt in query_results:
            w = rlt.page_content
            if w not in self.donbda_texts:
                continue
            idxs = np.where(texts == w)[0]
            idx = idxs[random.randint(0, len(idxs) - 1)]
            words.append({
                'idx': int(idx),
                'word': w,
            })
        return words

    def query_composition(self, image_topic, keywords_to_query, canvas_width,
                          canvas_height, num_words=0):
        """Ask the model for an image composition on a canvas grid.

        The prompt template placeholders %width%/%height%/%num_cols%/%num_rows%
        are filled in from the canvas size (180px grid cells); %usage% selects
        between "use each keyword once" (num_words == 0) and "use keywords
        repeatedly until num_words items appear".
        """
        prompt = self.prompt_composition
        prompt = prompt.replace('%width%', f'{canvas_width}')
        prompt = prompt.replace('%height%', f'{canvas_height}')
        prompt = prompt.replace('%num_cols%', f'{int(np.ceil(canvas_width / 180))}')
        prompt = prompt.replace('%num_rows%', f'{int(np.ceil(canvas_height / 180))}')
        if num_words > 0:
            prompt = prompt.replace('%usage%', f'使用这些关键词,每个关键词可使用多次,总共应该出现 {num_words} 个,回答时按照从远到近的顺序。')
        else:
            prompt = prompt.replace('%usage%', '仅使用这些关键词,且每个关键词使用一次,回答时按照从远到近的顺序。')
        self.log(prompt)
        response = self.client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "system",
                    "content": prompt,
                },
                {
                    "role": "user",
                    "content": f'主题是:"{image_topic}",关键词是:"{keywords_to_query}"',
                }
            ],
            temperature=0.8,
            max_tokens=4096,
            top_p=1
        )
        return response.choices[0].message.content

    def query_image_prompt(self, image_topic):
        """Ask the model for an image-generation prompt for ``image_topic``."""
        response = self.client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "system",
                    "content": self.prompt_image,
                },
                {
                    "role": "user",
                    "content": f'主题是:"{image_topic}"',
                }
            ],
            temperature=0.8,
            max_tokens=4096,
            top_p=1
        )
        return response.choices[0].message.content
|