File size: 5,452 Bytes
761be79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# https://platform.openai.com/examples/default-emoji-translation
# https://zhuanlan.zhihu.com/p/672725319
import os
import numpy as np
from openai import OpenAI
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
import json
import random

def _to_abs_path(fn, dir):
    if not os.path.isabs(fn):
        fn = os.path.join(dir, fn)
    return fn

def _to_cache_path(dir):
    cache_dir = os.path.join(dir, 'cache')
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir

def _read_txt(fn):
    with open(fn, 'r', encoding='utf-8') as fp:
        text = fp.read()
    return text

class GPTHelper:
    """Helper around the OpenAI chat API plus a FAISS vector store.

    Responsibilities (all visible in this file):
      * build an OpenAI client and an OpenAIEmbeddings instance from config,
      * load or build a FAISS index over the texts in DB1404.json,
      * query GPT for sketch keywords, image composition and image prompts.
    """

    def __init__(self, config) -> None:
        self.setup_config(config)

    def log(self, msg):
        """Minimal default logger.

        query_composition() calls self.log(); previously no such method
        existed, so it raised AttributeError unless a `log` attribute was
        injected onto the instance externally.  An instance attribute or a
        subclass override still takes precedence over this default.
        """
        print(msg)

    def setup_config(self, config):
        """Store config, build API clients, prepare the FAISS DB and prompts."""
        self.config = config

        # Config values take precedence; fall back to environment variables.
        self.client = OpenAI(
            api_key=config.get('openai_api_key', os.getenv('openai_api_key')),
            base_url=config.get('openai_api_base', os.getenv('openai_api_base')),
        )

        self.embeddings = OpenAIEmbeddings(
            openai_api_key=config.get('openai_api_key', os.getenv('openai_api_key')),
            openai_api_base=config.get('openai_api_base', os.getenv('openai_api_base')),
        )

        self.prepare_faiss()

        # Prompt templates, loaded from files referenced by the config
        # (paths resolved relative to the config's yaml_dir when relative).
        self.prompt_composition = _read_txt(_to_abs_path(self.config.composition_from, config.yaml_dir))
        self.prompt_image = _read_txt(_to_abs_path(self.config.image_from, config.yaml_dir))

    def prepare_faiss(self):
        """Load the word DB from JSON and load (or build and save) its FAISS index."""
        json_fn = os.path.join(self.config.dat_dir, 'DB1404.json')
        with open(json_fn, 'r', encoding='utf-8') as fp:
            cc = json.load(fp)
        self.donbda_dict = cc
        self.donbda_texts = list(cc.values())

        if os.path.exists(self.config.db_path):
            self.faiss_db = FAISS.load_local(self.config.db_path, self.embeddings, allow_dangerous_deserialization=True)
        else:
            # BUG FIX: the embeddings object lives on self (built in
            # setup_config above), not on the config.  The original passed
            # self.config.embeddings, which is never defined, so building a
            # fresh index crashed.
            self.faiss_db = FAISS.from_texts(self.donbda_texts, self.embeddings)
            self.faiss_db.save_local(self.config.db_path)

    def query_keywords(self, image_topic):
        """Ask GPT for sketch keywords for *image_topic*.

        The (Chinese) system prompt instructs the model to answer only with
        keywords separated by ';'.  Returns the response split on ';' as a
        list of strings.
        """
        response = self.client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "system",
                    "content": "你会得到一句话,请根据这句话给出几个创作简笔画的关键词,只回答关键词,不要回复其他内容,关键词用;隔开"
                },
                {
                    "role": "user",
                    "content": image_topic,
                }
            ],
            temperature=0.8,
            max_tokens=64,
            top_p=1
        )
        return response.choices[0].message.content.split(';')

    def query_in_faiss_db(self, keyword):
        """Similarity-search *keyword*; return a list of {'idx': int, 'word': str}.

        Only hits that exactly match an entry of donbda_texts are kept.  When
        a matched text occurs several times, one of its indices is picked
        uniformly at random.
        """
        query_results = self.faiss_db.similarity_search(keyword)
        words = []
        for rlt in query_results:
            w = rlt.page_content
            if w in self.donbda_texts:
                idxs = np.where(w == np.array(self.donbda_texts))[0]
                idx = int(random.choice(idxs))
                words.append({
                    'idx': idx,
                    'word': w,
                })
        return words

    def query_composition(self, image_topic, keywords_to_query, canvas_width, canvas_height, num_words=0):
        """Ask GPT for an image composition for *image_topic* on a canvas.

        Fills the composition prompt template's %width%/%height%/%num_cols%/
        %num_rows%/%usage% placeholders (cells are 180px; num_words > 0 allows
        keyword reuse up to that total) and returns the raw model reply.
        """
        prompt = self.prompt_composition
        prompt = prompt.replace('%width%', f'{canvas_width}')
        prompt = prompt.replace('%height%', f'{canvas_height}')
        prompt = prompt.replace('%num_cols%', f'{int(np.ceil(canvas_width / 180))}')
        prompt = prompt.replace('%num_rows%', f'{int(np.ceil(canvas_height / 180))}')
        if num_words > 0:
            prompt = prompt.replace('%usage%', f'使用这些关键词,每个关键词可使用多次,总共应该出现 {num_words} 个,回答时按照从远到近的顺序。')
        else:
            prompt = prompt.replace('%usage%', '仅使用这些关键词,且每个关键词使用一次,回答时按照从远到近的顺序。')
        self.log(prompt)

        response = self.client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "system",
                    "content": prompt,
                },
                {
                    "role": "user",
                    "content": f'主题是:"{image_topic}",关键词是:"{keywords_to_query}"',
                }
            ],
            temperature=0.8,
            max_tokens=4096,
            top_p=1
        )
        return response.choices[0].message.content

    def query_image_prompt(self, image_topic):
        """Ask GPT for an image-generation prompt for *image_topic*.

        Uses the prompt_image template as the system message and returns the
        raw model reply.
        """
        response = self.client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "system",
                    "content": self.prompt_image,
                },
                {
                    "role": "user",
                    "content": f'主题是:"{image_topic}"',
                }
            ],
            temperature=0.8,
            max_tokens=4096,
            top_p=1
        )
        return response.choices[0].message.content