File size: 4,118 Bytes
a1ca2de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
import toml
from pathlib import Path
import google.generativeai as palm_api

from pingpong import PingPong
from pingpong.pingpong import PPManager
from pingpong.pingpong import PromptFmt
from pingpong.pingpong import UIFmt
from pingpong.gradio import GradioChatUIFmt

from .utils import set_palm_api_key


# Set PaLM API Key.
# NOTE: this call executes at import time, so merely importing this module
# invokes set_palm_api_key() (imported from .utils) as a side effect.
set_palm_api_key()

# Load PaLM Prompt Templates.
# The TOML file is parsed once, at import time, into the module-level
# `palm_prompts` mapping. The path is built from Path('.'), i.e. it is
# resolved against the current working directory rather than this file's
# location, so the process has to be started from a directory that contains
# assets/palm_prompts.toml.
palm_prompts = toml.load(Path('.') / 'assets' / 'palm_prompts.toml')

class PaLMChatPromptFmt(PromptFmt):
    """Prompt formatter that renders one exchange as PaLM chat message dicts."""

    @classmethod
    def ctx(cls, context):
        """No-op context hook: ignores ``context`` and returns ``None``."""
        return None

    @classmethod
    def prompt(cls, pingpong, truncate_size):
        """Turn a single exchange into a list of author/content messages.

        Args:
            pingpong: Object exposing ``ping`` (the user text) and ``pong``
                (the reply text, which may be ``None`` or blank).
            truncate_size: Upper slice bound applied to each text
                (``text[:truncate_size]``); ``None`` keeps the text whole.

        Returns:
            ``[USER message]`` when the reply is ``None`` or only
            whitespace, otherwise ``[USER message, AI message]``.
        """
        user_message = {
            "author": "USER",
            "content": pingpong.ping[:truncate_size],
        }
        reply = pingpong.pong

        # A missing or whitespace-only reply means only the user turn exists.
        has_reply = reply is not None and reply.strip() != ""
        if not has_reply:
            return [user_message]

        ai_message = {
            "author": "AI",
            "content": reply[:truncate_size],
        }
        return [user_message, ai_message]

class PaLMChatPPManager(PPManager):
    """PPManager that flattens stored exchanges into one chat message list."""

    def build_prompts(self, from_idx: int=0, to_idx: int=-1, fmt: PromptFmt=PaLMChatPromptFmt, truncate_size: int=None):
        """Format the exchanges in ``self.pingpongs[from_idx:to_idx]``.

        Args:
            from_idx: Start index of the slice of ``self.pingpongs``.
            to_idx: End index (exclusive). ``-1`` or any value at or past
                the number of stored exchanges means "through the end".
            fmt: Formatter whose ``prompt(pingpong, truncate_size=...)``
                returns the messages for one exchange.
            truncate_size: Forwarded unchanged to ``fmt.prompt``.

        Returns:
            A single flat list with every formatter result concatenated in
            order.
        """
        total = len(self.pingpongs)
        stop = total if (to_idx == -1 or to_idx >= total) else to_idx

        return [
            message
            for exchange in self.pingpongs[from_idx:stop]
            for message in fmt.prompt(exchange, truncate_size=truncate_size)
        ]

class GradioPaLMChatPPManager(PaLMChatPPManager):
    """PaLM chat manager that can also render its exchanges for a UI."""

    def build_uis(self, from_idx: int=0, to_idx: int=-1, fmt: UIFmt=GradioChatUIFmt):
        """Render the exchanges in ``self.pingpongs[from_idx:to_idx]``.

        Args:
            from_idx: Start index of the slice of ``self.pingpongs``.
            to_idx: End index (exclusive). ``-1`` or any value at or past
                the number of stored exchanges means "through the end".
            fmt: Formatter whose ``ui(pingpong)`` renders one exchange.

        Returns:
            A list with one ``fmt.ui(...)`` result per selected exchange.
        """
        count = len(self.pingpongs)
        reaches_end = to_idx == -1 or to_idx >= count
        window = self.pingpongs[from_idx:count if reaches_end else to_idx]

        return list(map(fmt.ui, window))

def _default_palm_parameters(mode, use_filter):
    """Build the default PaLM request parameters for ``mode``.

    Chat mode gets the chat model plus sampling settings; any other mode
    gets the text model plus sampling settings, an output-token cap and
    safety settings.
    """
    sampling = {
        'candidate_count': 1,
        'temperature': 1.0,
        'top_k': 40,
        'top_p': 0.95,
    }

    if mode == "chat":
        return {
            'model': 'models/chat-bison-001',
            'context': "",
            **sampling,
        }

    # default safety settings (category, threshold)
    default_thresholds = (
        ("HARM_CATEGORY_DEROGATORY", 1),
        ("HARM_CATEGORY_TOXICITY", 1),
        ("HARM_CATEGORY_VIOLENCE", 2),
        ("HARM_CATEGORY_SEXUAL", 2),
        ("HARM_CATEGORY_MEDICAL", 2),
        ("HARM_CATEGORY_DANGEROUS", 2),
    )
    # With filtering disabled every category is set to threshold 4
    # (presumably "block none" in the PaLM HarmBlockThreshold enum -- confirm).
    safety_settings = [
        {"category": category, "threshold": threshold if use_filter else 4}
        for category, threshold in default_thresholds
    ]

    return {
        'model': 'models/text-bison-001',
        **sampling,
        'max_output_tokens': 1024,
        'safety_settings': safety_settings,
    }


async def gen_text(
    prompt,
    mode="chat", #chat or text
    parameters=None,
    use_filter=True
):
    """Send ``prompt`` to PaLM and return the raw response plus its text.

    Args:
        prompt: Forwarded as ``messages=`` to ``palm_api.chat_async`` in
            chat mode, or as ``prompt=`` to ``palm_api.generate_text``
            otherwise.
        mode: ``"chat"`` selects the chat API; any other value selects the
            text-generation API.
        parameters: Keyword arguments for the API call. When ``None``,
            defaults for the selected mode are built (model, sampling
            settings and, for text mode, an output-token cap and safety
            settings).
        use_filter: When true, a response whose first filter entry has
            reason ``2`` is replaced by a fixed "blocked" message. When
            ``parameters`` is ``None`` it also decides whether the default
            text-mode safety thresholds are relaxed to ``4``.

    Returns:
        A ``(response, response_txt)`` tuple: the API response object and
        either the blocked message, ``response.last`` (chat mode) or
        ``response.result`` (text mode).
    """
    if parameters is None:
        parameters = _default_palm_parameters(mode, use_filter)

    if mode == "chat":
        response = await palm_api.chat_async(**parameters, messages=prompt)
    else:
        # generate_text is a synchronous call; running it directly inside
        # this coroutine would stall the event loop until the request
        # completes, so hand it to the loop's default executor instead.
        import asyncio
        import functools

        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            functools.partial(
                palm_api.generate_text, **parameters, prompt=prompt
            ),
        )

    # Only the first filter entry is inspected; reason code 2 is treated as
    # "blocked" (meaning per the PaLM BlockedReason enum -- confirm).
    blocked = (
        use_filter
        and len(response.filters) > 0
        and response.filters[0]['reason'] == 2
    )
    if blocked:
        return response, "your request is blocked for some reasons"

    response_txt = response.last if mode == "chat" else response.result
    return response, response_txt