"""Gradio demo that serves code completions from a small causal language model.

Loads the tokenizer and model once at import time, builds a Blocks UI with a
prompt box, generation sliders and clickable examples, then launches the app.
"""

import random

import gradio as gr
import torch
from gradio.themes.utils import sizes
from transformers import AutoModelForCausalLM, AutoTokenizer

import utils
from constants import END_OF_TEXT
from settings import DEFAULT_PORT

MODEL_ID = "BEE-spoke-data/smol_llama-101M-GQA-python"

# Lower bound on generated tokens; clamped per request so it never exceeds
# the user's "Max new tokens" setting.
MIN_NEW_TOKENS = 8

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    use_fast=False,
)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = END_OF_TEXT

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
)
model = torch.compile(model, mode="reduce-overhead")

# UI things
_styles = utils.get_file_as_string("styles.css")

# Loads ./README.md file & splits it into sections
readme_file_content = utils.get_file_as_string("README.md", path="./")
(
    manifest,
    description,
    disclaimer,
    base_model_info,
    formats,
) = utils.get_sections(readme_file_content, "---", up_to=5)

theme = gr.themes.Soft(
    primary_hue="yellow",
    secondary_hue="orange",
    neutral_hue="slate",
    radius_size=sizes.radius_sm,
    font=[
        gr.themes.GoogleFont("IBM Plex Sans", [400, 600]),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
    text_size=sizes.text_lg,
)


def run_inference(
    prompt: str,
    temperature: float,
    max_new_tokens: int,
    top_p: float,
    repetition_penalty: float,
) -> str:
    """Generate a completion for ``prompt`` and return the decoded text.

    Args:
        prompt: Source text to continue.
        temperature: Sampling temperature. Values ``<= 0`` switch to
            deterministic beam search instead of sampling.
        max_new_tokens: Upper bound on generated tokens. Values ``<= 0``
            return ``prompt`` unchanged without calling the model.
        top_p: Nucleus-sampling threshold; only used when sampling.
        repetition_penalty: Forwarded to ``model.generate``.

    Returns:
        The first decoded sequence produced by ``model.generate``, with
        special tokens stripped.
    """
    # Slider values may arrive as floats; generate() expects an int here.
    max_new_tokens = int(max_new_tokens)
    if max_new_tokens <= 0:
        # The slider allows 0, but generate() cannot be asked for zero new
        # tokens (NOTE(review): recent transformers releases raise when the
        # prompt already fills max_length - confirm for the pinned version).
        return prompt

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        # Keep the length constraints feasible for small budgets.
        "min_new_tokens": min(MIN_NEW_TOKENS, max_new_tokens),
        "renormalize_logits": True,
        "no_repeat_ngram_size": 6,
        "repetition_penalty": repetition_penalty,
        "num_beams": 3,
        "early_stopping": True,
    }

    temperature = float(temperature)
    if temperature > 0:
        generation_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
        )
    else:
        # The temperature slider allows 0.0, which sampling rejects (it must
        # be strictly positive); treat it as "no randomness".
        generation_kwargs["do_sample"] = False

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, **generation_kwargs)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]


# Gradio interface wrapper for inference
def gradio_interface(
    prompt: str,
    temperature: float,
    max_new_tokens: int,
    top_p: float,
    repetition_penalty: float,
):
    """Forward the UI component values to ``run_inference`` unchanged."""
    return run_inference(prompt, temperature, max_new_tokens, top_p, repetition_penalty)


# NOTE(review): indentation inside these prompt strings is a single space per
# line in the file as reviewed; it looks whitespace-collapsed - confirm the
# intended prompts before changing them.
_EXAMPLE_PROMPTS = [
    "def add_numbers(a, b):\n return",
    "class Car:\n def __init__(self, make, model):\n self.make = make\n self.model = model\n\n def display_car(self):",
    "import pandas as pd\ndata = {'Name': ['Tom', 'Nick', 'John'], 'Age': [20, 21, 19]}\ndf = pd.DataFrame(data).convert_dtypes()\n# eda",
    "def factorial(n):\n if n == 0:\n return 1\n else:",
    'def fibonacci(n):\n if n <= 0:\n raise ValueError("Incorrect input")\n elif n == 1:\n return 0\n elif n == 2:\n return 1\n else:',
    "import matplotlib.pyplot as plt\nimport numpy as np\nx = np.linspace(0, 10, 100)\n# simple plot",
    "def reverse_string(s:str) -> str:\n return",
    "def is_palindrome(word:str) -> bool:\n return",
    "def bubble_sort(lst: list):\n n = len(lst)\n for i in range(n):\n for j in range(0, n-i-1):",
    "def binary_search(arr, low, high, x):\n if high >= low:\n mid = (high + low) // 2\n if arr[mid] == x:\n return mid\n elif arr[mid] > x:",
]

# Every example uses the same settings, in component order:
# temperature, max_new_tokens, top_p, repetition_penalty.
_EXAMPLE_SETTINGS = [0.2, 192, 0.9, 1.2]

examples = [[prompt, *_EXAMPLE_SETTINGS] for prompt in _EXAMPLE_PROMPTS]

# Define the Gradio Blocks interface
# NOTE(review): container nesting below was reconstructed from a
# whitespace-collapsed file - confirm the layout renders as intended.
with gr.Blocks(theme=theme, analytics_enabled=False, css=_styles) as demo:
    with gr.Column():
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                instruction = gr.Textbox(
                    # Start with a random example prompt pre-filled.
                    value=random.choice(_EXAMPLE_PROMPTS),
                    placeholder="Enter your code here",
                    label="Code",
                    elem_id="q-input",
                )
                submit = gr.Button("Generate", variant="primary")
                output = gr.Code(elem_id="q-output", language="python", lines=10)
                with gr.Row():
                    with gr.Column():
                        with gr.Accordion("Advanced settings", open=False):
                            with gr.Row():
                                column_1, column_2 = gr.Column(), gr.Column()
                                with column_1:
                                    temperature = gr.Slider(
                                        label="Temperature",
                                        value=0.2,
                                        minimum=0.0,
                                        maximum=1.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values produce more diverse outputs",
                                    )
                                    max_new_tokens = gr.Slider(
                                        label="Max new tokens",
                                        value=128,
                                        minimum=0,
                                        maximum=512,
                                        step=64,
                                        interactive=True,
                                        info="Number of tokens to generate",
                                    )
                                with column_2:
                                    top_p = gr.Slider(
                                        label="Top-p (nucleus sampling)",
                                        value=0.90,
                                        minimum=0.0,
                                        maximum=1,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values sample more low-probability tokens",
                                    )
                                    repetition_penalty = gr.Slider(
                                        label="Repetition penalty",
                                        value=1.1,
                                        minimum=1.0,
                                        maximum=2.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Penalize repeated tokens",
                                    )
                    with gr.Column():
                        # Display-only: not passed to the inference callback.
                        version = gr.Dropdown(
                            [
                                "smol_llama-101M-GQA-python",
                            ],
                            value="smol_llama-101M-GQA-python",
                            label="Version",
                            info="",
                        )
                gr.Markdown(disclaimer)
                gr.Examples(
                    examples=examples,
                    # One component per example column, matching the
                    # five-argument signature of gradio_interface.
                    inputs=[
                        instruction,
                        temperature,
                        max_new_tokens,
                        top_p,
                        repetition_penalty,
                    ],
                    cache_examples=False,
                    fn=gradio_interface,
                    outputs=[output],
                )
                gr.Markdown(base_model_info)
                gr.Markdown(formats)

    submit.click(
        gradio_interface,
        inputs=[
            instruction,
            temperature,
            max_new_tokens,
            top_p,
            repetition_penalty,
        ],
        outputs=[output],
        # preprocess=False,
        max_batch_size=2,
        show_progress=True,
    )

demo.queue(max_size=10).launch(
    debug=True,
    server_port=DEFAULT_PORT,
    max_threads=utils.get_workers(),
)