# Importing required libraries
import warnings

warnings.filterwarnings("ignore")

import os
import sys
from typing import Generator

import gradio as gr
from huggingface_hub import hf_hub_download
from llama_cpp import Llama

from logger import logging  # Assuming you have a logger.py
from exception import CustomExceptionHandling  # Assuming you have an exception.py
import spaces  # Comment out for local use

# Download GGUF model files (simplified for the specified models)
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")  # Ensure this token is set


def download_model(repo_id: str, filename: str) -> None:
    try:
        hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_dir="./models",
            token=huggingface_token,
        )
        logging.info(f"Successfully downloaded {filename} from {repo_id}")
    except Exception as e:
        logging.error(f"Error downloading {filename} from {repo_id}: {e}")
        raise  # Re-raise to halt execution if the download fails


# Only download files that don't already exist. This is crucial.
MODEL_FILES = [
    ("DevQuasar/google.gemma-3-1b-pt-GGUF", "google.gemma-3-1b-pt.Q4_K_M.gguf"),
    ("DevQuasar/google.gemma-3-4b-pt-GGUF", "google.gemma-3-4b-pt.Q4_K_M.gguf"),
    ("DevQuasar/google.gemma-3-12b-pt-GGUF", "google.gemma-3-12b-pt.Q4_K_M.gguf"),
    ("DevQuasar/google.gemma-3-27b-pt-GGUF", "google.gemma-3-27b-pt.Q4_K_M.gguf"),
]
for repo_id, filename in MODEL_FILES:
    if not os.path.exists(os.path.join("./models", filename)):
        download_model(repo_id, filename)

# Set the title and description
title = "Gemma 3 Text Generation"
description = """Gemma models for text generation and notebook continuation.
This interface is designed for generating text continuations, not for interactive chat."""

llm = None  # Lazily created Llama instance
llm_model = None  # Name of the currently loaded model file


@spaces.GPU  # Comment this line out for local execution
def generate_text(
    prompt: str,
    model: str,
    max_tokens: int,
    temperature: float,
    n_ctx: int,
    top_p: float,
    top_k: int,
    repeat_penalty: float,
) -> Generator[str, None, None]:
    """
    Generates text based on a prompt, using the specified Gemma model.

    Args:
        prompt (str): The initial text to continue.
        model (str): The model file to use (filename only, without path).
        max_tokens (int): Maximum number of tokens to generate.
        temperature (float): Controls randomness (higher = more random).
        n_ctx (int): Context window size in tokens.
        top_p (float): Nucleus sampling parameter.
        top_k (int): Top-k sampling parameter.
        repeat_penalty (float): Penalty for repeating tokens.

    Yields:
        str: Generated text chunks, streamed as they become available.
    """
    try:
        global llm, llm_model

        model_path = os.path.join("models", model)
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")

        # Load the model only if it differs from the one currently loaded.
        if llm_model != model:
            logging.info(f"Loading model: {model}")
            llm = Llama(
                model_path=model_path,
                flash_attn=True,
                n_gpu_layers=-1,  # -1 offloads all layers to the GPU; reduce for limited VRAM.
                n_ctx=n_ctx,  # Context window size.
                verbose=False,  # Reduce unnecessary verbosity.
            )
            llm_model = model

        # llama_cpp handles streaming natively when stream=True is passed.
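        # Each streamed item mirrors the OpenAI-style completion schema,
        # {"choices": [{"text": ...}], ...}, so the new text is read from
        # token["choices"][0]["text"] below.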
        for token in llm(
            prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repeat_penalty=repeat_penalty,
            stream=True,  # Ensure streaming is on
            stop=["<|im_end|>", "<|endoftext|>", "<|file_separator|>"],  # Verify these stop tokens against the model's tokenizer!
        ):
            text_chunk = token["choices"][0]["text"]
            yield text_chunk
    except Exception as e:
        raise CustomExceptionHandling(e, sys) from e


def clear_history():
    """Clears the text input and output."""
    return "", ""


with gr.Blocks(title=title, theme="Ocean") as demo:  # Use the default theme if "Ocean" isn't available
    gr.Markdown(f"# {title}")
    gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=4):
            model_dropdown = gr.Dropdown(
                choices=[
                    "google.gemma-3-1b-pt.Q4_K_M.gguf",
                    "google.gemma-3-4b-pt.Q4_K_M.gguf",
                    "google.gemma-3-12b-pt.Q4_K_M.gguf",
                    "google.gemma-3-27b-pt.Q4_K_M.gguf",
                    # Add other models as needed (and downloaded above)
                ],
                value="google.gemma-3-1b-pt.Q4_K_M.gguf",  # Default model
                label="Model",
                info="Select the AI model",
            )
            input_textbox = gr.Textbox(
                label="Input Prompt",
                placeholder="Enter text to continue...",
                lines=10,
            )
            submit_button = gr.Button("Generate", variant="primary")
            clear_button = gr.Button("Clear Input")
            output_textbox = gr.Textbox(  # A plain Textbox so streamed chunks can be displayed as they arrive
                label="Generated Text",
                lines=10,  # Extra lines for better display of longer outputs
                interactive=False,  # Output shouldn't be editable
            )

        with gr.Column(scale=1):
            with gr.Accordion("Advanced Parameters", open=False):  # Initially collapsed
                max_tokens_slider = gr.Slider(
                    minimum=32,
                    maximum=8192,
                    value=512,
                    step=1,
                    label="Max Tokens",
                    info="Maximum length of generated text",
                )
                temperature_slider = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=1.0,
                    step=0.05,
                    label="Temperature",
                    info="Controls randomness (higher = more creative)",
                )
                n_ctx_slider = gr.Slider(
                    minimum=128,
                    maximum=8192,
                    value=512,
                    step=128,
                    label="Context Length",
                    info="Size of the model's context window (its 'memory')",
                )
                top_p_slider = gr.Slider(
                    minimum=0.05,
                    maximum=1.0,
                    value=0.95,
                    step=0.05,
                    label="Top-p",
                    info="Nucleus sampling threshold",
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=40,
                    step=1,
                    label="Top-k",
                    info="Limit vocabulary choices to the top K tokens",
                )
                repeat_penalty_slider = gr.Slider(
                    minimum=1.0,
                    maximum=2.0,
                    value=1.1,
                    step=0.05,
                    label="Repetition Penalty",
                    info="Penalize repeated words (higher = less repetition)",
                )

    def streaming_output(prompt, model, max_tokens, temperature, n_ctx, top_p, top_k, repeat_penalty):
        """Accumulates streamed chunks so Gradio displays the growing text."""
        generated_text = ""
        for text_chunk in generate_text(
            prompt, model, max_tokens, temperature, n_ctx, top_p, top_k, repeat_penalty
        ):
            generated_text += text_chunk
            yield generated_text

    submit_button.click(
        streaming_output,
        [
            # The order here must match streaming_output's parameters: prompt first, model second.
            input_textbox,
            model_dropdown,
            max_tokens_slider,
            temperature_slider,
            n_ctx_slider,
            top_p_slider,
            top_k_slider,
            repeat_penalty_slider,
        ],
        output_textbox,
    )
    clear_button.click(clear_history, [], [input_textbox, output_textbox])  # Clear both input and output


if __name__ == "__main__":
    demo.launch(debug=False, share=False)  # share=False keeps the demo local-only.
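# Usage sketch (assumptions: the script is saved as app.py, and the logger.py /
# exception.py modules imported above exist alongside it):
#
#   export HUGGINGFACE_TOKEN=<your token>   # required for the hf_hub_download calls
#   pip install llama-cpp-python gradio huggingface_hub
#   python app.py
#
# For local (non-Spaces) runs, also comment out "import spaces" and the
# "@spaces.GPU" decorator, as noted above.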