Spaces:
Runtime error
Runtime error
Andyrasika
commited on
Commit
•
7f1450b
1
Parent(s):
98de4ff
app.py
Browse files
app.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import pandas as pd
|
3 |
+
import re
|
4 |
+
from pyspark.sql import SparkSession, Window
|
5 |
+
import pyspark.sql.functions as F
|
6 |
+
from llama_cpp import Llama
|
7 |
+
from loguru import logger # Import the logger from loguru
|
8 |
+
|
9 |
+
# Create the models directory
|
10 |
+
!mkdir -p ./models
|
11 |
+
|
12 |
+
# Download the Llama model files
|
13 |
+
!wget -O ./models/llama-2-7b-chat.ggmlv3.q8_0.bin https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q8_0.bin
|
14 |
+
!wget -O ./models/llama-2-7b-chat.ggmlv3.q2_K.bin https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q2_K.bin
|
15 |
+
|
16 |
+
# download "War and Peace" from Project Gutenberg
|
17 |
+
!mkdir -p ./data
|
18 |
+
!curl "https://gutenberg.org/cache/epub/2600/pg2600.txt" -o ./data/war_and_peace.txt
|
19 |
+
|
20 |
+
|
21 |
+
# Define the Llama models
|
22 |
+
MODEL_Q8_0 = Llama(model_path="./models/llama-2-7b-chat.ggmlv3.q8_0.bin", n_ctx=8192, n_batch=512)
|
23 |
+
MODEL_Q2_K = Llama(model_path="./models/llama-2-7b-chat.ggmlv3.q2_K.bin", n_ctx=8192, n_batch=512)
|
24 |
+
|
25 |
+
# Function to read the text file and create Spark DataFrame
|
26 |
+
def create_spark_dataframe(text):
|
27 |
+
# Get list of chapter strings
|
28 |
+
chapter_list = [x for x in re.split('CHAPTER .+', text) if len(x) > 100]
|
29 |
+
|
30 |
+
# Create Spark DataFrame
|
31 |
+
spark = SparkSession.builder.appName("Counting word occurrences from a book, under a microscope.").config("spark.driver.memory", "4g").getOrCreate()
|
32 |
+
spark.sparkContext.setLogLevel("WARN")
|
33 |
+
df = spark.createDataFrame(pd.DataFrame({'text': chapter_list, 'chapter': range(1, len(chapter_list) + 1)}))
|
34 |
+
|
35 |
+
return df
|
36 |
+
|
37 |
+
# Function to summarize a chapter using the selected model
|
38 |
+
def llama2_summarize(chapter_text, model_version):
|
39 |
+
# Choose the model based on the model_version parameter
|
40 |
+
if model_version == "q8_0":
|
41 |
+
llm = MODEL_Q8_0
|
42 |
+
elif model_version == "q2_K":
|
43 |
+
llm = MODEL_Q2_K
|
44 |
+
else:
|
45 |
+
return "Error: Invalid model_version."
|
46 |
+
|
47 |
+
# Template for this model version
|
48 |
+
template = """
|
49 |
+
[INST] <<SYS>>
|
50 |
+
You are a helpful, respectful and honest assistant.
|
51 |
+
Always answer as helpfully as possible, while being safe.
|
52 |
+
Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.
|
53 |
+
Please ensure that your responses are socially unbiased and positive in nature.
|
54 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct.
|
55 |
+
If you don't know the answer to a question, please don't share false information.
|
56 |
+
<</SYS>>
|
57 |
+
{INSERT_PROMPT_HERE} [/INST]
|
58 |
+
"""
|
59 |
+
|
60 |
+
# Create prompt
|
61 |
+
prompt = 'Summarize the following novel chapter in a single sentence (less than 100 words): ' + chapter_text
|
62 |
+
prompt = template.replace('INSERT_PROMPT_HERE', prompt)
|
63 |
+
|
64 |
+
# Log the input chapter text and model_version
|
65 |
+
logger.info(f"Input chapter text: {chapter_text}")
|
66 |
+
logger.info(f"Selected model version: {model_version}")
|
67 |
+
|
68 |
+
# Generate summary using the selected model
|
69 |
+
output = llm(prompt, max_tokens=-1, echo=False, temperature=0.2, top_p=0.1)
|
70 |
+
summary = output['choices'][0]['text']
|
71 |
+
|
72 |
+
# Log the generated summary
|
73 |
+
logger.info(f"Generated summary: {summary}")
|
74 |
+
|
75 |
+
return summary
|
76 |
+
|
77 |
+
# Read the "War and Peace" text file and create Spark DataFrame
|
78 |
+
with open('/content/data/war_and_peace.txt', 'r') as file:
|
79 |
+
text = file.read()
|
80 |
+
df_chapters = create_spark_dataframe(text)
|
81 |
+
|
82 |
+
# Create summaries via Spark
|
83 |
+
summaries = (df_chapters
|
84 |
+
.limit(1)
|
85 |
+
.groupby('chapter')
|
86 |
+
.applyInPandas(llama2_summarize, schema='summary string, chapter int')
|
87 |
+
.show(vertical=True, truncate=False)
|
88 |
+
)
|
89 |
+
|
90 |
+
# Prompt for the file
|
91 |
+
file_path = gr.inputs.File(label="Upload 'War and Peace' text file")
|
92 |
+
|
93 |
+
# Choose the model version
|
94 |
+
model_version = gr.inputs.Radio(["q8_0", "q2_K"], label="Choose Model Version")
|
95 |
+
|
96 |
+
# Define the Gradio interface
|
97 |
+
iface = gr.Interface(
|
98 |
+
fn=llama2_summarize,
|
99 |
+
inputs=[file_path, model_version],
|
100 |
+
outputs="text", # Summary text
|
101 |
+
live=False,
|
102 |
+
capture_session=True,
|
103 |
+
title="Llama2 Chapter Summarizer",
|
104 |
+
description="Upload a text file of the novel 'War and Peace', and choose the model version ('q8_0' or 'q2_K') to get a summarized sentence for each chapter.",
|
105 |
+
)
|
106 |
+
|
107 |
+
iface.launch();
|