import os

# hf_transfer must be enabled before huggingface_hub is first imported
# (gradio imports it), otherwise the flag is read too late to take effect.
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'

import gradio as gr
import spaces
import torch
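
# On a ZeroGPU Space, a GPU is attached only while a @spaces.GPU-decorated
# function runs; CUDA calls made at import time fall back to the CPU, which
# is why the tensor below reports 'cpu' here but 'cuda:0' inside greet().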
# cpu
zero = torch.Tensor([0]).cuda()
print(zero.device)  # <-- 'cpu' 🤔

# gpu
model = None


@spaces.GPU
def greet(prompts, separator):
    # print(zero.device)  # <-- 'cuda:0' 🤗
    from vllm import LLM, SamplingParams
    from transformers.utils import move_cache
    from huggingface_hub import snapshot_download

    # Download and load the model once, on the first request only.
    global model
    if model is None:
        # LLM_MODEL_ID = "DoctorSlimm/trim-music-31"
        # LLM_MODEL_ID = "mistral-community/Mistral-7B-v0.2"
        LLM_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"
        fp = snapshot_download(LLM_MODEL_ID, token=os.getenv('HF_TOKEN'), revision='main')
        move_cache()
        model = LLM(fp)
    # Near-greedy decoding; allow up to 1024 generated tokens per prompt.
    sampling_params = SamplingParams(
        temperature=0.01,
        ignore_eos=False,
        max_tokens=1024,
    )
    # If the separator occurs in the input, treat it as a batch of prompts.
    multi_prompt = False
    separator = separator.strip()
    if separator in prompts:
        multi_prompt = True
        prompts = prompts.split(separator)
    else:
        prompts = [prompts]

    for idx, pt in enumerate(prompts):
        print(f'\n[{idx}]:')
        print(pt)
    # vLLM returns one RequestOutput per prompt; collect every completion.
    model_outputs = model.generate(prompts, sampling_params)
    generations = []
    for output in model_outputs:
        for completion in output.outputs:
            generations.append(completion.text)

    if multi_prompt:
        return generations
    return generations[0]

## make predictions via api ##
# https://www.gradio.app/guides/getting-started-with-the-python-client#connecting-a-general-gradio-app
demo = gr.Interface(
    fn=greet,
    inputs=[
        gr.Text(
            value='hello sir!<SEP>bonjour madame...',
            placeholder='hello sir!<SEP>bonjour madame...',
            label='list of prompts separated by separator',
        ),
        gr.Text(
            value='<SEP>',
            placeholder='<SEP>',
            label='separator for your prompts',
        ),
    ],
    outputs=gr.Text(),
)
demo.launch(share=True)
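
# Per the guide linked above, the running Space can also be queried from
# Python with gradio_client. A minimal sketch, assuming this app is published
# under the hypothetical id 'user/space-name' (substitute your own Space id;
# api_name='/predict' is the default endpoint for a gr.Interface):
#
#     from gradio_client import Client
#
#     client = Client('user/space-name')
#     result = client.predict(
#         'hello sir!<SEP>bonjour madame...',  # prompts
#         '<SEP>',                             # separator
#         api_name='/predict',
#     )
#     print(result)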