File size: 5,899 Bytes
dcd2001
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82934e3
dcd2001
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import math

import gradio as gr


def _place_model_chatbots(chat_kwargs, nrows, max_rows=4):
    """Create one ``gr.Chatbot`` per entry of *chat_kwargs*, grouped into ``gr.Row``s.

    The entries are split, in order, into ``min(max(nrows, 1), max_rows)``
    contiguous rows; row ``r`` holds indices ``i`` with
    ``r * n / rows <= i < (r + 1) * n / rows`` (``n`` = number of entries).
    At least one ``gr.Row`` is always opened, even when *chat_kwargs* is empty.

    Args:
        chat_kwargs: List of keyword dicts, one per chatbot to create.
        nrows: Requested number of rows; values below 1 are treated as 1.
        max_rows: Upper bound on the number of rows actually created.

    Returns:
        List of the created chatbots, in the same order as *chat_kwargs*.
    """
    num = len(chat_kwargs)
    num_rows = min(max(nrows, 1), max_rows)
    outputs = []
    for row in range(num_rows):
        # ceil(row * num / num_rows) computed exactly with integer arithmetic.
        start = -(-row * num // num_rows)
        stop = -(-(row + 1) * num // num_rows)
        with gr.Row():
            for one_kwargs in chat_kwargs[start:stop]:
                outputs.append(gr.Chatbot(**one_kwargs))
    return outputs


def make_chatbots(output_label0, output_label0_model2, **kwargs):
    """Build the Gradio chatbot panes.

    When the model-lock grid is enabled, one ``gr.Chatbot`` is created per entry
    of ``kwargs['model_states']`` and laid out in up to four ``gr.Row``s.
    Afterwards, two more chatbots labelled *output_label0* and
    *output_label0_model2* are always created in a final row.

    Args:
        output_label0: Label for the first extra chatbot. It is visible only
            when ``kwargs['model_lock']`` is falsy.
        output_label0_model2: Label for the second extra chatbot, which is
            always created hidden.
        **kwargs: UI settings. Keys read here:
            ``model_states`` (sequence of dicts with ``"base_model"`` and, when
            the ``DEBUG_MODEL_LOCK`` env var is set, ``"inference_server"``);
            ``gradio_size`` (picks the per-model ``min_width``);
            ``model_lock`` (visibility of the per-model panes);
            ``height`` (falsy means 400);
            ``model_lock_columns`` (``-1`` = one column per model,
            ``None`` = 3, ``0`` = no per-model panes).

    Returns:
        Tuple ``(text_output, text_output2, text_outputs)``, where
        ``text_outputs`` is the list of per-model chatbots in model order.
    """
    model_states = kwargs['model_states']
    height = kwargs['height'] or 400

    chat_kwargs = []
    for model_state_lock in model_states:
        if os.environ.get('DEBUG_MODEL_LOCK'):
            model_name = model_state_lock["base_model"] + " : " + model_state_lock["inference_server"]
        else:
            model_name = model_state_lock["base_model"]
        output_label = f'h2oGPT [{model_name}]'
        min_width = 250 if kwargs['gradio_size'] in ['small', 'large', 'medium'] else 160
        chat_kwargs.append(dict(label=output_label, visible=kwargs['model_lock'], elem_classes='chatsmall',
                                height=height, min_width=min_width))

    ncols = kwargs['model_lock_columns']
    if ncols == -1:
        # One column per model.
        ncols = len(model_states)
    if ncols is None:
        ncols = 3

    if ncols == 0:
        # Not using model_lock. This is also reached via -1 with no models.
        # Checked before computing rows so we never divide by zero.
        text_outputs = []
    else:
        nrows = math.ceil(len(model_states) / ncols)
        text_outputs = _place_model_chatbots(chat_kwargs, nrows)

    with gr.Row():
        text_output = gr.Chatbot(label=output_label0, visible=not kwargs['model_lock'], height=height)
        # The second pane is deliberately always hidden (originally `False and ...`).
        text_output2 = gr.Chatbot(label=output_label0_model2, visible=False, height=height)
    return text_output, text_output2, text_outputs


def make_prompt_form(kwargs, LangChainMode):
    """Build the prompt input area: a text box plus Submit and Stop buttons.

    Args:
        kwargs: UI settings dict. Reads ``langchain_mode`` (compared with
            ``LangChainMode.DISABLED.value``) and ``input_lines`` (text box
            height; also chooses which key-binding hint is shown).
        LangChainMode: Enum-like class providing ``DISABLED.value``.

    Returns:
        Tuple ``(instruction, submit, stop_btn)`` of the created Gradio components.
    """
    # Mention the empty-submission summarization shortcut only when LangChain is active.
    hint_suffix = (".  For summarization, empty submission uses first top_k_docs documents."
                   if kwargs['langchain_mode'] != LangChainMode.DISABLED.value
                   else "")

    n_lines = kwargs['input_lines']
    # Multi-line boxes submit with Shift-Enter; single-line boxes submit with Enter.
    template = ("Shift-Enter to Submit, Enter for more lines%s" if n_lines > 1
                else "Enter to Submit, Shift-Enter for more lines%s")
    placeholder_text = template % hint_suffix

    # NOTE: elem_id='prompt-form-area' on this outer row is currently disabled.
    with gr.Row():
        with gr.Column(scale=50):
            textbox_options = dict(
                lines=n_lines,
                label='Ask anything',
                placeholder=placeholder_text,
                info=None,
                elem_id='prompt-form',
                container=True,
            )
            instruction = gr.Textbox(**textbox_options)
        with gr.Row():
            submit = gr.Button(value='Submit', variant='primary', scale=0, size='sm')
            stop_btn = gr.Button(value="Stop", variant='secondary', scale=0, size='sm')

    return instruction, submit, stop_btn