import os
import spaces
import torch
import gradio as gr
# On ZeroGPU hardware, no CUDA device is attached outside of @spaces.GPU
# functions, so this tensor quietly stays on the CPU.
zero = torch.Tensor([0]).cuda()
print(zero.device)  # <-- 'cpu' 🤔
# The vLLM engine is created lazily on the first request, inside the
# GPU-backed function below.
model = None
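
# @spaces.GPU asks ZeroGPU for a GPU for the duration of each call, so all
# CUDA-dependent work (imports included) happens inside the decorated function.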
@spaces.GPU
def greet(prompts, separator):
    # print(zero.device) # <-- 'cuda:0' 🤗
    # Heavy, CUDA-touching imports live inside the function so they initialise
    # against the GPU that ZeroGPU has just attached.
    from vllm import LLM, SamplingParams
    from transformers.utils import move_cache
    from huggingface_hub import snapshot_download
    global model
    if model is None:
        LLM_MODEL_ID = "DoctorSlimm/trim-music-31"
        # LLM_MODEL_ID = "mistral-community/Mistral-7B-v0.2"
        # LLM_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"
        os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'  # faster snapshot downloads
        fp = snapshot_download(LLM_MODEL_ID, token=os.getenv('HF_TOKEN'), revision='main')
        move_cache()  # migrate any legacy transformers cache layout
        model = LLM(fp)
    # Near-greedy decoding, capped at 1024 generated tokens per prompt.
    sampling_params = SamplingParams(
        temperature=0.01,
        ignore_eos=False,
        max_tokens=1024,
    )
    # Split the input into several prompts when the separator is present.
    # Guard against an empty separator: '' in prompts is always True and
    # str.split('') raises ValueError.
    multi_prompt = False
    separator = separator.strip()
    if separator and separator in prompts:
        multi_prompt = True
        prompts = prompts.split(separator)
    else:
        prompts = [prompts]
    # Log each prompt for debugging.
    for idx, pt in enumerate(prompts):
        print()
        print(f'[{idx}]:')
        print(pt)
    # Generate for all prompts in a single batch and flatten the results.
    model_outputs = model.generate(prompts, sampling_params)
    generations = []
    for output in model_outputs:
        for completion in output.outputs:
            generations.append(completion.text)
    if multi_prompt:
        return separator.join(generations)
    return generations[0]
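
# Example: greet('hello sir!<SEP>bonjour madame...', '<SEP>') runs both
# prompts through the model and returns the two generations joined by '<SEP>'.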
## make predictions via api ##
# https://www.gradio.app/guides/getting-started-with-the-python-client#connecting-a-general-gradio-app
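#
# A minimal client-side sketch (assumptions: the Space id below and the
# default '/predict' endpoint that gr.Interface exposes; adjust both for
# your deployment):
#
#   from gradio_client import Client
#
#   client = Client("DoctorSlimm/trim-music-31")  # assumed Space id
#   result = client.predict(
#       'hello sir!<SEP>bonjour madame...',  # prompts
#       '<SEP>',                             # separator
#       api_name='/predict',
#   )
#   print(result)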
demo = gr.Interface(
    fn=greet,
    inputs=[
        gr.Text(
            value='hello sir!<SEP>bonjour madame...',
            placeholder='hello sir!<SEP>bonjour madame...',
            label='list of prompts separated by separator',
        ),
        gr.Text(
            value='<SEP>',
            placeholder='<SEP>',
            label='separator for your prompts',
        ),
    ],
    outputs=gr.Text(),
)
demo.launch(share=True)