File size: 3,768 Bytes
7e4123a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcef1cd
f53833a
7e4123a
 
 
 
 
 
 
d3f7f98
7e4123a
 
 
d3f7f98
7e4123a
 
 
d3f7f98
7e4123a
 
 
d3f7f98
7e4123a
 
 
 
48af876
7e4123a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import copy
import asyncio
import google.generativeai as genai

from pingpong import PingPong
from pingpong.pingpong import PPManager
from pingpong.pingpong import PromptFmt
from pingpong.pingpong import UIFmt
from pingpong.gradio import GradioChatUIFmt

class GeminiChatPromptFmt(PromptFmt):
    """Render context and ping/pong exchanges as Gemini-style message dicts.

    Every message produced here is a dict with a ``"role"`` key and a
    single-element ``"parts"`` list that holds the text.
    """

    @classmethod
    def ctx(cls, context):
        """Wrap *context* in a ``"system"`` message.

        Returns None when *context* is None or the empty string.
        """
        if context is not None and context != "":
            return {"role": "system", "parts": [context]}
        return None

    @classmethod
    def prompt(cls, pingpong, truncate_size):
        """Return the message list for a single exchange.

        The ping always becomes a ``"user"`` message. The pong becomes a
        following ``"model"`` message only when it is present and non-empty
        after truncation. Both texts are sliced with ``[:truncate_size]``, so
        ``None`` keeps them whole.
        """
        user_text = pingpong.ping[:truncate_size]
        if pingpong.pong is None:
            model_text = ""
        else:
            model_text = pingpong.pong[:truncate_size]

        messages = [{"role": "user", "parts": [user_text]}]
        if model_text != "":
            messages.append({"role": "model", "parts": [model_text]})
        return messages

class GeminiChatPPManager(PPManager):
    """PPManager that renders its history as a list of Gemini chat messages."""

    def build_prompts(self, from_idx: int=0, to_idx: int=-1, fmt: PromptFmt=GeminiChatPromptFmt, truncate_size: int=None):
        """Build the message list for pingpongs ``[from_idx, to_idx)``.

        The context (``self.ctx``) is not emitted as the separate ``"system"``
        dict that ``fmt.ctx`` returns. Instead, its text is folded into the
        first rendered ping as a ``"SYSTEM: ... -----------"`` prefix.
        NOTE(review): presumably this is because the chat history handed to
        ``start_chat`` only accepts user/model turns — confirm.

        Args:
            from_idx: Index of the first pingpong to render.
            to_idx: One past the last pingpong to render. -1, or any value
                past the end, means "through the last pingpong".
            fmt: Formatter whose ``ctx`` and ``prompt`` classmethods produce
                the message dicts.
            truncate_size: Forwarded to ``fmt.prompt``. With
                ``GeminiChatPromptFmt``, this slices each text (``None``
                keeps it whole).

        Returns:
            A flat list of message dicts, in history order.
        """
        if to_idx == -1 or to_idx >= len(self.pingpongs):
            to_idx = len(self.pingpongs)

        # Copy only the slice being rendered. The context prefix below
        # mutates ``ping``, and that change must not leak into stored history.
        pingpongs = copy.deepcopy(self.pingpongs[from_idx:to_idx])

        ctx = fmt.ctx(self.ctx)
        ctx = ctx["parts"][0] if ctx is not None else ""

        results = []
        for idx, pingpong in enumerate(pingpongs):
            # Skip the prefix when there is no context; otherwise the model
            # would receive an empty "SYSTEM:  -----------" header.
            # NOTE: truncation in fmt.prompt happens after this prefix is
            # added, so a small truncate_size can cut into the user's text.
            if idx == 0 and ctx:
                pingpong.ping = f"SYSTEM: {ctx} ----------- \n" + pingpong.ping
            results += fmt.prompt(pingpong, truncate_size=truncate_size)

        return results

class GradioGeminiChatPPManager(GeminiChatPPManager):
    """GeminiChatPPManager that can also render its history through a UI formatter."""

    def build_uis(self, from_idx: int=0, to_idx: int=-1, fmt: UIFmt=GradioChatUIFmt):
        """Return ``fmt.ui(pingpong)`` for each pingpong in ``[from_idx, to_idx)``.

        A ``to_idx`` of -1, or any value past the end, means "through the last
        pingpong".
        """
        total = len(self.pingpongs)
        stop = total if (to_idx == -1 or to_idx >= total) else to_idx
        return [fmt.ui(entry) for entry in self.pingpongs[from_idx:stop]]

def init(api_key):
    """Configure the ``google.generativeai`` client with *api_key*.

    This calls the module-level ``genai.configure``; it is expected to run
    before ``gen_text``.
    """
    genai.configure(api_key=api_key)

def _default_gen_text():
    return {
        "temperature": 0.9,
        "top_p": 0.8,
        "top_k": 32,
        "max_output_tokens": 2048,
    }

def _default_safety_settings():
    return [
        {
            "category": "HARM_CATEGORY_HARASSMENT",
            "threshold": "BLOCK_ONLY_HIGH"
        },
        {
            "category": "HARM_CATEGORY_HATE_SPEECH",
            "threshold": "BLOCK_ONLY_HIGH"
        },
        {
            "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
            "threshold": "BLOCK_ONLY_HIGH"
        },
        {
            "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
            "threshold": "BLOCK_ONLY_HIGH"
        },
    ]

async def _word_generator(sentence):
    for word in sentence.split(" "):
        yield word
        delay = 0.03 + (len(word) * 0.005)
        await asyncio.sleep(delay)  # Simulate a short delay

async def gen_text(
    prompts,
    gen_config=None,
    safety_settings=None,
    stream=True,
    model_name="gemini-1.0-pro",
):
    """Send the last message in *prompts* to Gemini and yield the reply word by word.

    Args:
        prompts: Message dicts shaped like ``{"role": ..., "parts": [text]}``
            (as produced by ``GeminiChatPPManager.build_prompts``). All but
            the last are passed as chat history. Only the first part of the
            last message (``["parts"][0]``) is sent as the new message, so
            that message is presumably expected to be a "user" turn.
        gen_config: Generation settings passed to ``GenerativeModel``.
            ``None`` (the default) uses a fresh ``_default_gen_text()``.
        safety_settings: Safety settings passed to ``GenerativeModel``.
            ``None`` (the default) uses a fresh ``_default_safety_settings()``.
        stream: Forwarded to ``send_message_async``.
        model_name: Model to instantiate. Defaults to ``"gemini-1.0-pro"``.

    Yields:
        Each space-separated piece of every response chunk's text, with a
        trailing space appended.

    Raises:
        IndexError: If *prompts* is empty.
    """
    # Build the defaults per call. Defaults evaluated in the signature are
    # created once at import time, so every call would share (and could
    # mutate) the same dict/list objects.
    if gen_config is None:
        gen_config = _default_gen_text()
    if safety_settings is None:
        safety_settings = _default_safety_settings()

    model = genai.GenerativeModel(
        model_name=model_name,
        generation_config=gen_config,
        safety_settings=safety_settings,
    )

    user_prompt = prompts[-1]
    history = prompts[:-1]
    convo = model.start_chat(history=history)

    resps = await convo.send_message_async(
        user_prompt["parts"][0], stream=stream
    )

    # NOTE(review): the response is consumed with ``async for`` even when
    # stream=False; confirm the non-streaming response supports async iteration.
    async for resp in resps:
        async for word in _word_generator(resp.text):
            yield word + " "