Spaces:
Build error
Build error
File size: 4,118 Bytes
a1ca2de |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import os
import toml
from pathlib import Path
import google.generativeai as palm_api
from pingpong import PingPong
from pingpong.pingpong import PPManager
from pingpong.pingpong import PromptFmt
from pingpong.pingpong import UIFmt
from pingpong.gradio import GradioChatUIFmt
from .utils import set_palm_api_key
# Configure the PaLM API key once, at import time, via the project helper.
set_palm_api_key()

# PaLM prompt templates. The path is relative, so it is resolved against the
# current working directory rather than this module's location.
palm_prompts = toml.load(Path(".", "assets", "palm_prompts.toml"))
class PaLMChatPromptFmt(PromptFmt):
    """Render one PingPong exchange as PaLM chat message dicts.

    Every exchange yields a user message (``"author": "USER"``). A model
    message (``"author": "AI"``) follows only when the pong is present and
    not blank.
    """

    @classmethod
    def ctx(cls, context):
        # No context message is emitted for PaLM chat.
        return None

    @classmethod
    def prompt(cls, pingpong, truncate_size):
        """Return the list of message dicts for ``pingpong``.

        Args:
            pingpong: object exposing ``ping`` and ``pong``. ``pong`` may be
                ``None``.
            truncate_size: slice bound applied as ``[:truncate_size]`` to the
                ping and, when kept, the pong. ``None`` leaves them untouched.

        Returns:
            ``[user_msg]`` when the pong is ``None`` or whitespace-only,
            otherwise ``[user_msg, ai_msg]``.
        """
        messages = [{"author": "USER", "content": pingpong.ping[:truncate_size]}]

        answer = pingpong.pong
        # The blank check runs on the untruncated pong, so a pong that
        # truncates to "" is still emitted.
        if answer is not None and answer.strip() != "":
            messages.append({"author": "AI", "content": answer[:truncate_size]})

        return messages
class PaLMChatPPManager(PPManager):
    """PPManager that renders its exchanges as one flat PaLM chat message list."""

    def build_prompts(self, from_idx: int = 0, to_idx: int = -1, fmt: PromptFmt = PaLMChatPromptFmt, truncate_size: int = None):
        """Concatenate ``fmt.prompt`` output for a range of ``self.pingpongs``.

        Args:
            from_idx: index of the first exchange to include.
            to_idx: exclusive end index. ``-1`` (the default) or any value at or
                past the end selects through the last exchange. Note that this
                differs from plain slicing, where ``-1`` would drop the last one.
            fmt: formatter class whose ``prompt(pingpong, truncate_size=...)``
                returns a list of messages for one exchange.
            truncate_size: forwarded unchanged to ``fmt.prompt``.

        Returns:
            A single flat list holding, in order, the messages produced for
            every selected exchange.
        """
        if to_idx == -1 or to_idx >= len(self.pingpongs):
            to_idx = len(self.pingpongs)

        # Each fmt.prompt call returns a list; flatten them in exchange order.
        return [
            message
            for pingpong in self.pingpongs[from_idx:to_idx]
            for message in fmt.prompt(pingpong, truncate_size=truncate_size)
        ]
class GradioPaLMChatPPManager(PaLMChatPPManager):
    """PaLM chat manager that can also render its exchanges for a Gradio chat UI."""

    def build_uis(self, from_idx: int = 0, to_idx: int = -1, fmt: UIFmt = GradioChatUIFmt):
        """Return ``fmt.ui(pingpong)`` for each exchange in the selected range.

        ``to_idx`` is exclusive. ``-1`` (the default) or any value at or past
        the end selects through the last exchange.
        """
        total = len(self.pingpongs)
        end = total if (to_idx == -1 or to_idx >= total) else to_idx
        return [fmt.ui(pingpong) for pingpong in self.pingpongs[from_idx:end]]
def _default_gen_parameters(mode, use_filter):
    """Build the default PaLM request keyword arguments used by ``gen_text``.

    Args:
        mode: ``"chat"`` selects ``models/chat-bison-001``. Any other value
            selects ``models/text-bison-001``.
        use_filter: when false, every safety threshold is set to 4 instead of
            its per-category default.

    Returns:
        A dict of keyword arguments for the PaLM call. The chat variant has an
        empty ``context`` and no ``max_output_tokens`` or ``safety_settings``.
        The text variant carries both.
    """
    temperature = 1.0
    top_k = 40
    top_p = 0.95
    max_output_tokens = 1024

    if mode == "chat":
        # The chat request built here never includes max_output_tokens or
        # safety_settings, so those are only assembled for the text request.
        return {
            'model': 'models/chat-bison-001',
            'candidate_count': 1,
            'context': "",
            'temperature': temperature,
            'top_k': top_k,
            'top_p': top_p,
        }

    # Default safety settings: raw threshold codes per harm category.
    # NOTE(review): 4 is presumably the "block none" level of the PaLM
    # HarmBlockThreshold enum; confirm before relying on it.
    default_thresholds = (
        ("HARM_CATEGORY_DEROGATORY", 1),
        ("HARM_CATEGORY_TOXICITY", 1),
        ("HARM_CATEGORY_VIOLENCE", 2),
        ("HARM_CATEGORY_SEXUAL", 2),
        ("HARM_CATEGORY_MEDICAL", 2),
        ("HARM_CATEGORY_DANGEROUS", 2),
    )
    safety_settings = [
        {"category": category, "threshold": threshold if use_filter else 4}
        for category, threshold in default_thresholds
    ]

    return {
        'model': 'models/text-bison-001',
        'candidate_count': 1,
        'temperature': temperature,
        'top_k': top_k,
        'top_p': top_p,
        'max_output_tokens': max_output_tokens,
        'safety_settings': safety_settings,
    }


async def gen_text(
    prompt,
    mode="chat",  # chat or text
    parameters=None,
    use_filter=True
):
    """Send ``prompt`` to PaLM and return the raw response plus its text.

    Args:
        prompt: passed as ``messages`` to ``palm_api.chat_async`` when
            ``mode == "chat"``, otherwise as ``prompt`` to
            ``palm_api.generate_text``.
        mode: ``"chat"`` for the chat endpoint; anything else uses text
            generation.
        parameters: keyword arguments for the API call. When ``None``, mode
            defaults are built. A caller-supplied dict is used as-is, so
            ``use_filter`` does not alter its safety settings.
        use_filter: relaxes the default safety thresholds when false, and
            enables the blocked-response check below when true.

    Returns:
        ``(response, response_txt)``. ``response_txt`` is the fixed message
        ``"your request is blocked for some reasons"`` when ``use_filter`` is
        true and the first filter's ``reason`` equals 2. Otherwise it is
        ``response.last`` (chat) or ``response.result`` (text).
    """
    import asyncio
    import functools

    if parameters is None:
        parameters = _default_gen_parameters(mode, use_filter)

    if mode == "chat":
        response = await palm_api.chat_async(**parameters, messages=prompt)
    else:
        # generate_text is a synchronous network call. Run it in the default
        # executor so it does not block the event loop running this coroutine.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            functools.partial(palm_api.generate_text, **parameters, prompt=prompt),
        )

    # NOTE(review): 2 is a raw filter-reason code; confirm which BlockedReason
    # value it corresponds to in the PaLM API.
    if use_filter and len(response.filters) > 0 and \
            response.filters[0]['reason'] == 2:
        response_txt = "your request is blocked for some reasons"
    elif mode == "chat":
        response_txt = response.last
    else:
        response_txt = response.result

    return response, response_txt