"""Utilities for building streamed text-generation requests against the
models registered in global_vars.models."""

import re
from threading import Thread

from transformers import GenerationConfig, TextIteratorStreamer

import global_vars

def contains_image_markdown(string):
    """Return the re.Match for the first markdown image tag in string
    (e.g. ![alt](url)), or None if there is none."""
    regex = re.compile(r'!\[(.*?)\]\((.*?)\)')
    return regex.search(string)

def build_model_inputs(prompt, model_num, return_token_type_ids):
    """Tokenize prompt with the tokenizer of model model_num and move the
    resulting tensors to the GPU (assumes a CUDA device is available)."""
    model_inputs = global_vars.models[model_num]["tokenizer"](
        [prompt],
        return_tensors="pt",
        return_token_type_ids=return_token_type_ids
    ).to("cuda")
    return model_inputs

def build_streamer(
    model_num,
    timeout=20.0,
    skip_prompt=True,
    skip_special_tokens=True
):
    """Create a TextIteratorStreamer bound to the tokenizer of model
    model_num so generated text can be consumed piece by piece."""
    return TextIteratorStreamer(
        global_vars.models[model_num]["tokenizer"],
        timeout=timeout,
        skip_prompt=skip_prompt,
        skip_special_tokens=skip_special_tokens
    )

def build_gen_config(
    temperature, top_p, top_k, repetition_penalty, max_new_tokens,
    num_beams, use_cache, do_sample, eos_token_id, pad_token_id
):
    """Collect the sampling parameters into a plain dict and return it
    alongside the equivalent transformers GenerationConfig."""
    gen_config_raw = {
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "max_new_tokens": max_new_tokens,
        "num_beams": num_beams,
        "use_cache": use_cache,
        "do_sample": do_sample,
        "eos_token_id": eos_token_id,
        "pad_token_id": pad_token_id
    }

    return gen_config_raw, GenerationConfig(**gen_config_raw)

def build_gen_kwargs(
    gen_config,
    model_inputs,
    streamer,
    stopping_criteria
):
    """Merge the tokenized inputs, the streamer, the stopping criteria,
    and the raw generation parameters into one kwargs dict suitable for
    passing to model.generate."""
    gen_kwargs = dict(
        model_inputs,
        streamer=streamer,
        stopping_criteria=stopping_criteria
    )
    gen_kwargs.update(gen_config)
    return gen_kwargs

def start_gen(gen_kwargs, model_num):
    """Run model.generate for model model_num on a background thread so
    the caller can consume the streamer while generation is running."""
    t = Thread(
        target=global_vars.models[model_num]["model"].generate,
        kwargs=gen_kwargs
    )
    t.start()
    return t  # return the thread so callers can join() it if needed

def build(
    prompt, model_num,
    temperature, top_p, top_k, repetition_penalty, max_new_tokens,
    num_beams, use_cache, do_sample, eos_token_id, pad_token_id,
    stopping_criteria=None, return_token_type_ids=True
):
    """Prepare everything needed for one streamed generation: tokenize the
    prompt, create a streamer, and assemble the generate() kwargs. Returns
    (gen_kwargs, streamer); pass gen_kwargs to start_gen."""
    # Only the raw dict is needed here, because the generation parameters
    # are passed to generate() as flat keyword arguments; the
    # GenerationConfig object is discarded.
    gen_config_raw, _ = build_gen_config(
        temperature, top_p, top_k, repetition_penalty, max_new_tokens,
        num_beams, use_cache, do_sample, eos_token_id, pad_token_id
    )

    model_inputs = build_model_inputs(
        prompt, model_num, return_token_type_ids=return_token_type_ids
    )
    streamer = build_streamer(model_num)
    gen_kwargs = build_gen_kwargs(
        gen_config_raw,
        model_inputs,
        streamer,
        stopping_criteria
    )
    return gen_kwargs, streamer
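
# Minimal usage sketch (not part of the original module): it assumes
# global_vars.models[0] has been populated elsewhere with a
# {"model": ..., "tokenizer": ...} entry and that a CUDA device is
# available. The prompt and sampling values below are placeholders.
#
#     tok = global_vars.models[0]["tokenizer"]
#     gen_kwargs, streamer = build(
#         prompt="Hello", model_num=0,
#         temperature=0.7, top_p=0.9, top_k=50, repetition_penalty=1.2,
#         max_new_tokens=256, num_beams=1, use_cache=True, do_sample=True,
#         eos_token_id=tok.eos_token_id, pad_token_id=tok.pad_token_id,
#     )
#     start_gen(gen_kwargs, model_num=0)
#     for chunk in streamer:  # yields decoded text pieces as they arrive
#         print(chunk, end="", flush=True)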