import os
import spaces
import torch
import gradio as gr


# ZeroGPU: outside a @spaces.GPU-decorated function no GPU is attached yet,
# so .cuda() is a no-op and the tensor stays on the CPU.
zero = torch.Tensor([0]).cuda()
print(zero.device)  # <-- 'cpu' 🤔


# The model is loaded lazily inside the @spaces.GPU function, where a GPU is attached.
model = None

@spaces.GPU
def greet(prompts, separator):
    # print(zero.device) # <-- 'cuda:0' 🤗
    # Deferred imports: keep vLLM and the Hub helpers out of module import time
    # so CUDA-dependent setup happens inside the GPU-decorated function.
    from vllm import SamplingParams, LLM
    from transformers.utils import move_cache
    from huggingface_hub import snapshot_download

    global model

    if model is None:
        # Download the weights on the first call and cache the loaded model globally.
        LLM_MODEL_ID = "DoctorSlimm/trim-music-31"
        # LLM_MODEL_ID = "mistral-community/Mistral-7B-v0.2"
        # LLM_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"
        os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'  # faster downloads (requires hf_transfer)
        fp = snapshot_download(LLM_MODEL_ID, token=os.getenv('HF_TOKEN'), revision='main')
        move_cache()
        model = LLM(fp)
    
    sampling_params = SamplingParams(
        temperature=0.01,
        ignore_eos=False,
        max_tokens=512 * 2,
    )

    # If the separator appears in the input, split it into multiple prompts.
    multi_prompt = False
    separator = separator.strip()
    if separator in prompts:
        multi_prompt = True
        prompts = prompts.split(separator)
    else:
        prompts = [prompts]
    for idx, pt in enumerate(prompts):
        print()
        print(f'[{idx}]:')
        print(pt)
        
    # Run batched generation and collect the generated text for each prompt.
    model_outputs = model.generate(prompts, sampling_params)
    generations = []
    for output in model_outputs:
        for completion in output.outputs:
            generations.append(completion.text)
    if multi_prompt:
        return separator.join(generations)
    return generations[0]


## Make predictions via the API (Gradio Python client) ##
# https://www.gradio.app/guides/getting-started-with-the-python-client#connecting-a-general-gradio-app
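# A minimal client-side sketch (not executed here), assuming this app is deployed as a
# Space whose id matches the model repo "DoctorSlimm/trim-music-31" (hypothetical;
# replace with the actual Space id) and that gr.Interface exposes the default
# "/predict" endpoint:
#
#   from gradio_client import Client
#
#   client = Client("DoctorSlimm/trim-music-31")
#   result = client.predict(
#       "hello sir!<SEP>bonjour madame...",  # prompts
#       "<SEP>",                             # separator
#       api_name="/predict",
#   )
#   print(result)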

demo = gr.Interface(
    fn=greet,
    inputs=[
        gr.Text(
            value='hello sir!<SEP>bonjour madame...',
            placeholder='hello sir!<SEP>bonjour madame...',
            label='List of prompts, separated by the separator'
        ),
        gr.Text(
            value='<SEP>',
            placeholder='<SEP>',
            label='Separator between prompts'
        )],
    outputs=gr.Text()
)
demo.launch(share=True)