import gradio as gr import pandas as pd import re from pyspark.sql import SparkSession, Window import pyspark.sql.functions as F from llama_cpp import Llama from loguru import logger # Import the logger from loguru # Create the models directory !mkdir -p ./models # Download the Llama model files !wget -O ./models/llama-2-7b-chat.ggmlv3.q8_0.bin https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q8_0.bin !wget -O ./models/llama-2-7b-chat.ggmlv3.q2_K.bin https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q2_K.bin # download "War and Peace" from Project Gutenberg !mkdir -p ./data !curl "https://gutenberg.org/cache/epub/2600/pg2600.txt" -o ./data/war_and_peace.txt # Define the Llama models MODEL_Q8_0 = Llama(model_path="./models/llama-2-7b-chat.ggmlv3.q8_0.bin", n_ctx=8192, n_batch=512) MODEL_Q2_K = Llama(model_path="./models/llama-2-7b-chat.ggmlv3.q2_K.bin", n_ctx=8192, n_batch=512) # Function to read the text file and create Spark DataFrame def create_spark_dataframe(text): # Get list of chapter strings chapter_list = [x for x in re.split('CHAPTER .+', text) if len(x) > 100] # Create Spark DataFrame spark = SparkSession.builder.appName("Counting word occurrences from a book, under a microscope.").config("spark.driver.memory", "4g").getOrCreate() spark.sparkContext.setLogLevel("WARN") df = spark.createDataFrame(pd.DataFrame({'text': chapter_list, 'chapter': range(1, len(chapter_list) + 1)})) return df # Function to summarize a chapter using the selected model def llama2_summarize(chapter_text, model_version): # Choose the model based on the model_version parameter if model_version == "q8_0": llm = MODEL_Q8_0 elif model_version == "q2_K": llm = MODEL_Q2_K else: return "Error: Invalid model_version." # Template for this model version template = """ [INST] <> You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <> {INSERT_PROMPT_HERE} [/INST] """ # Create prompt prompt = 'Summarize the following novel chapter in a single sentence (less than 100 words): ' + chapter_text prompt = template.replace('INSERT_PROMPT_HERE', prompt) # Log the input chapter text and model_version logger.info(f"Input chapter text: {chapter_text}") logger.info(f"Selected model version: {model_version}") # Generate summary using the selected model output = llm(prompt, max_tokens=-1, echo=False, temperature=0.2, top_p=0.1) summary = output['choices'][0]['text'] # Log the generated summary logger.info(f"Generated summary: {summary}") return summary # Read the "War and Peace" text file and create Spark DataFrame with open('/content/data/war_and_peace.txt', 'r') as file: text = file.read() df_chapters = create_spark_dataframe(text) # Create summaries via Spark summaries = (df_chapters .limit(1) .groupby('chapter') .applyInPandas(llama2_summarize, schema='summary string, chapter int') .show(vertical=True, truncate=False) ) # Prompt for the file file_path = gr.inputs.File(label="Upload 'War and Peace' text file") # Choose the model version model_version = gr.inputs.Radio(["q8_0", "q2_K"], label="Choose Model Version") # Define the Gradio interface iface = gr.Interface( fn=llama2_summarize, inputs=[file_path, model_version], outputs="text", # Summary text live=False, capture_session=True, title="Llama2 Chapter Summarizer", description="Upload a text file of the novel 'War and Peace', and choose the model version ('q8_0' or 'q2_K') to get a summarized sentence for each chapter.", ) iface.launch();