|
import spaces |
|
import gradio as gr |
|
import transformers |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig,AwqConfig |
|
import torch |
|
import os |
|
# Hugging Face Hub token, read from the "key" environment variable
# (e.g. a Space secret). May be None when the variable is not set.
key = os.environ.get("key")

from huggingface_hub import login

# Only authenticate when a token is configured: login() with no token falls
# back to prompting for one interactively, which cannot work in a headless app.
if key:
    login(key)


# Load the model weights in 4-bit NF4 with double quantization; compute in bfloat16.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)


model_id = "CohereForAI/c4ai-command-r-v01"

# Loaded once at import time and shared by every request handled by generate_response.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=nf4_config)
|
|
|
@spaces.GPU
def generate_response(user_input, max_new_tokens, temperature):
    """Generate a single-turn Command-R reply to ``user_input``.

    The message is wrapped as one ``"user"`` turn with the tokenizer's chat
    template (with the generation prompt appended), then sampled from the
    module-level ``model``.

    Args:
        user_input: The user's message text.
        max_new_tokens: Upper bound on the number of generated tokens. Cast to
            ``int`` because the Gradio slider may deliver it as a float.
        temperature: Sampling temperature passed to ``model.generate``.

    Returns:
        The decoded reply only (the prompt is not included), with special
        tokens removed and surrounding whitespace stripped.
    """
    messages = [{"role": "user", "content": user_input}]
    input_ids = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    gen_tokens = model.generate(
        input_ids=input_ids,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=float(temperature),
    )

    # Decode only the tokens generated after the prompt. Decoding the full
    # sequence and then stripping a ``user_input`` prefix is unreliable,
    # because the chat template surrounds the message with extra text and the
    # prompt would otherwise leak into the reply.
    prompt_length = input_ids.shape[-1]
    reply_tokens = gen_tokens[0][prompt_length:]
    return tokenizer.decode(reply_tokens, skip_special_tokens=True).strip()
|
|
|
|
|
|
|
# Canned prompts (with matching generation settings) offered in the example dropdown.
examples = [
    {"message": "What is the weather like today?", "max_new_tokens": 250, "temperature": 0.5},
    {"message": "Tell me a joke.", "max_new_tokens": 650, "temperature": 0.7},
    {"message": "Explain the concept of machine learning.", "max_new_tokens": 980, "temperature": 0.4},
]

# Dropdown labels "Example 1", "Example 2", ... in the same order as ``examples``.
example_choices = [f"Example {number}" for number in range(1, len(examples) + 1)]


def load_example(choice):
    """Return the form values for the example labelled ``choice``.

    Args:
        choice: One of ``example_choices``, or ``None`` when the dropdown has
            no selection.

    Returns:
        A 3-tuple ``(message, max_new_tokens, temperature)`` for a known
        choice. For an empty or unknown choice, returns three no-op
        ``gr.update()`` values so the message box and sliders keep their
        current contents instead of raising ``ValueError``.
    """
    if choice not in example_choices:
        # Nothing selected (or an unexpected label): leave the form untouched.
        return gr.update(), gr.update(), gr.update()
    example = examples[example_choices.index(choice)]
    return example["message"], example["max_new_tokens"], example["temperature"]
|
|
|
|
|
# Build the Gradio interface: generation controls, a message box, an output box,
# and a dropdown for loading the canned examples into those controls.
with gr.Blocks() as demo:
    # Generation parameters, grouped in one row.
    with gr.Row():
        max_new_tokens_slider = gr.Slider(minimum=100, maximum=4000, value=980, label="Max New Tokens")
        temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.3, label="Temperature")
    message_box = gr.Textbox(lines=2, label="Your Message")
    generate_button = gr.Button("Try🫡Command-R")
    output_box = gr.Textbox(label="🫡Command-R")

    # On click, call generate_response(message, max_new_tokens, temperature)
    # and show the returned text in output_box.
    generate_button.click(
        fn=generate_response,
        inputs=[message_box, max_new_tokens_slider, temperature_slider],
        outputs=output_box
    )

    # On "Load", pass the selected dropdown label to load_example and write its
    # three return values into the message box and the two sliders.
    example_dropdown = gr.Dropdown(label="🫡Load Example", choices=example_choices)
    example_button = gr.Button("🫡Load")
    example_button.click(
        fn=load_example,
        inputs=example_dropdown,
        outputs=[message_box, max_new_tokens_slider, temperature_slider]
    )

# Start the web server; this runs at module level (there is no __main__ guard).
demo.launch()
|
|