Chan-Y commited on
Commit
f1691d8
·
verified ·
1 Parent(s): 162dd8b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -28
app.py CHANGED
@@ -1,40 +1,58 @@
 
 
 
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  from langchain.prompts import PromptTemplate
4
  from langchain.chains.summarize import load_summarize_chain
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
6
  from langchain_core.documents import Document
7
  from pathlib import Path
 
8
 
9
- # Load the Mistral model from Hugging Face
10
- model_name = "mistralai/Mistral-7B-Instruct-v0.3"
11
- tokenizer = AutoTokenizer.from_pretrained(model_name)
12
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
13
-
14
- # Define the text splitter and summarize chain
15
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
16
 
17
- # Define the summarization function
18
  def summarize(file, n_words):
19
  # Read the content of the uploaded file
20
  file_path = file.name
21
  with open(file_path, 'r', encoding='utf-8') as f:
22
  file_content = f.read()
23
-
24
- # Split the content into chunks
25
- chunks = text_splitter.create_documents([file_content])
26
-
27
- # Summarize each chunk and concatenate the results
28
- summaries = []
29
- for chunk in chunks:
30
- inputs = tokenizer(chunk.text, return_tensors="pt", max_length=512, truncation=True)
31
- summary_ids = model.generate(inputs["input_ids"], max_length=n_words, num_beams=4, early_stopping=True)
32
- summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
33
- summaries.append(summary)
34
-
35
- return " ".join(summaries)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- # Define the download summary function
38
  def download_summary(output_text):
39
  if output_text:
40
  file_path = Path('summary.txt')
@@ -43,7 +61,6 @@ def download_summary(output_text):
43
  return file_path
44
  else:
45
  return None
46
-
47
  def create_download_file(summary_text):
48
  file_path = download_summary(summary_text)
49
  return str(file_path) if file_path else None
@@ -54,21 +71,25 @@ with gr.Blocks() as demo:
54
 
55
  with gr.Row():
56
  with gr.Column():
57
- n_words = gr.Slider(minimum=50, maximum=500, step=50, label="Number of words")
58
  file = gr.File(label="Submit a file")
59
 
60
  with gr.Column():
61
- output_text = gr.Textbox(label="Summary will be printed here", lines=20)
62
 
63
  submit_button = gr.Button("Summarize")
64
  submit_button.click(summarize, inputs=[file, n_words], outputs=output_text)
65
 
 
 
 
 
 
66
  download_button = gr.Button("Download Summary")
67
  download_button.click(
68
  fn=create_download_file,
69
  inputs=[output_text],
70
  outputs=gr.File()
71
  )
72
-
73
  # Run the Gradio app
74
- demo.launch(share=True)
 
1
+ import warnings
2
+ warnings.simplefilter(action='ignore', category=FutureWarning)
3
+
4
  import gradio as gr
 
5
  from langchain.prompts import PromptTemplate
6
  from langchain.chains.summarize import load_summarize_chain
7
  from langchain.text_splitter import RecursiveCharacterTextSplitter
8
+ from langchain_community.document_loaders import DirectoryLoader
9
  from langchain_core.documents import Document
10
  from pathlib import Path
11
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
12
 
13
+ llm = HuggingFaceEndpoint(
14
+ repo_id="mistralai/Mistral-7B-Instruct-v0.3",
15
+ task="text-generation",
16
+ max_new_tokens=1025,
17
+ do_sample=False,
18
+ )
19
+ llm_engine_hf = ChatHuggingFace(llm=llm)
20
 
 
21
  def summarize(file, n_words):
22
  # Read the content of the uploaded file
23
  file_path = file.name
24
  with open(file_path, 'r', encoding='utf-8') as f:
25
  file_content = f.read()
26
+ document = Document(file_content)
27
+ # Generate the summary
28
+ text = document.page_content
29
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=3000, chunk_overlap=200)
30
+ chunks = text_splitter.create_documents([text])
31
+ n_words = n_words
32
+ template = ''' [INST]
33
+ Your task is to summarize a long text into a concise summary of a specific number of words.
34
+
35
+ The summary you generate must be EXACTLY {N_WORDS} words long.
36
+
37
+ Before writing your final summary, first break down the key points of the text in a <scratchpad>. Identify the most important information that should be included in a summary of the specified length.
38
+
39
+ Then, write a summary that captures the core ideas and key details of the text. Start with an introductory sentence and then concisely summarize the main points in a logical order. Make sure to stay within the {{N_WORDS}} word limit.
40
+
41
+ Here is the long text to summarize:
42
+ Text:
43
+ {TEXT}
44
+
45
+
46
+ [/INST]
47
+ '''
48
+ prompt = PromptTemplate(
49
+ template=template,
50
+ input_variables=['TEXT', "N_WORDS"]
51
+ )
52
+ formatted_prompt = prompt.format(TEXT=text, N_WORDS=n_words)
53
+ output_summary = llm_engine_hf.invoke(formatted_prompt)
54
+ return output_summary.content
55
 
 
56
  def download_summary(output_text):
57
  if output_text:
58
  file_path = Path('summary.txt')
 
61
  return file_path
62
  else:
63
  return None
 
64
  def create_download_file(summary_text):
65
  file_path = download_summary(summary_text)
66
  return str(file_path) if file_path else None
 
71
 
72
  with gr.Row():
73
  with gr.Column():
74
+ n_words = gr.Slider(minimum=50, maximum=500, step=50, label="Number of words (approximately)")
75
  file = gr.File(label="Submit a file")
76
 
77
  with gr.Column():
78
+ output_text = gr.Textbox(label="Summary", lines=20)
79
 
80
  submit_button = gr.Button("Summarize")
81
  submit_button.click(summarize, inputs=[file, n_words], outputs=output_text)
82
 
83
+ def generate_file():
84
+ summary_text = output_text
85
+ file_path = download_summary(summary_text)
86
+ return file_path
87
+
88
  download_button = gr.Button("Download Summary")
89
  download_button.click(
90
  fn=create_download_file,
91
  inputs=[output_text],
92
  outputs=gr.File()
93
  )
 
94
  # Run the Gradio app
95
+ demo.launch(share=True)