Chan-Y committed on
Commit
162dd8b
·
verified ·
1 Parent(s): 0dd7ae7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -1
app.py CHANGED
@@ -1,3 +1,74 @@
1
  import gradio as gr
 
 
 
 
 
 
2
 
3
- gr.load("models/mistralai/Mistral-7B-Instruct-v0.3").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ from langchain.prompts import PromptTemplate
4
+ from langchain.chains.summarize import load_summarize_chain
5
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
6
+ from langchain_core.documents import Document
7
+ from pathlib import Path
8
 
9
# Load the Mistral model from Hugging Face.
# Mistral-7B-Instruct is a decoder-only (causal) language model, so it must be
# loaded with AutoModelForCausalLM: AutoModelForSeq2SeqLM only accepts
# encoder-decoder configurations and refuses the Mistral config at startup.
from transformers import AutoModelForCausalLM

model_name = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Define the text splitter: 1000-character chunks that overlap by 200
# characters so text cut at a chunk boundary still appears whole in one chunk.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
16
+
17
+ # Define the summarization function
18
def _encode_chunk(text, n_words):
    """Tokenize one chunk of text in the form the loaded model expects."""
    if model.config.is_encoder_decoder:
        # Encoder-decoder summarizers take the raw text as encoder input.
        return tokenizer(text, return_tensors="pt", max_length=512, truncation=True)

    # A decoder-only model merely continues its input, so the request has to
    # be spelled out as an instruction.
    instruction = (
        f"Summarize the following text in at most {n_words} words:\n\n{text}"
    )
    if getattr(tokenizer, "chat_template", None):
        prompt = tokenizer.apply_chat_template(
            [{"role": "user", "content": instruction}],
            tokenize=False,
            add_generation_prompt=True,
        )
        # The rendered template already carries the special tokens; adding
        # them again would duplicate the beginning-of-sequence marker.
        # No truncation here: it would cut off the end-of-instruction markup,
        # and the splitter already bounds the chunk size.
        return tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    return tokenizer(instruction, return_tensors="pt")


def summarize(file, n_words):
    """Summarize an uploaded UTF-8 text file chunk by chunk.

    The file content is split with the module-level ``text_splitter``; every
    chunk is run through the module-level ``tokenizer`` / ``model`` pair and
    the per-chunk summaries are joined with single spaces.

    Args:
        file: Value of the file input. Either an object exposing the path of
            the uploaded file as ``.name`` or the path itself; ``None`` when
            nothing was uploaded.
        n_words: Length budget per chunk. It is passed to the model as a limit
            on newly generated *tokens*, so it only approximates a word count.

    Returns:
        The concatenated chunk summaries, or an empty string when no file was
        provided.
    """
    if file is None:
        return ""

    # Read the content of the uploaded file
    file_path = getattr(file, "name", file)
    with open(file_path, "r", encoding="utf-8") as f:
        file_content = f.read()

    # The slider may deliver a float; generation limits must be integers.
    token_budget = int(n_words)
    prompt_is_echoed = not model.config.is_encoder_decoder

    # Summarize each chunk and concatenate the results
    summaries = []
    for chunk in text_splitter.create_documents([file_content]):
        # LangChain documents expose their text as ``page_content``.
        inputs = _encode_chunk(chunk.page_content, token_budget)
        output_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            # Counts generated tokens only, independent of the prompt length.
            max_new_tokens=token_budget,
            num_beams=4,
            early_stopping=True,
        )[0]
        if prompt_is_echoed:
            # Decoder-only output starts with the prompt tokens; drop them.
            output_ids = output_ids[inputs["input_ids"].shape[-1]:]
        summaries.append(
            tokenizer.decode(output_ids, skip_special_tokens=True).strip()
        )

    return " ".join(summaries)
36
+
37
+ # Define the download summary function
38
def download_summary(output_text):
    """Save the summary text to ``summary.txt`` in the working directory.

    Args:
        output_text: Text to store; an empty or missing value writes nothing.

    Returns:
        The ``Path`` of the written file, or ``None`` when there was no text.
    """
    # Nothing to save: leave any existing file untouched.
    if not output_text:
        return None

    target = Path('summary.txt')
    with target.open('w', encoding='utf-8') as handle:
        handle.write(output_text)
    return target
46
+
47
def create_download_file(summary_text):
    """Persist the summary via ``download_summary`` and return its path as text.

    Args:
        summary_text: Summary to save.

    Returns:
        The saved file's path as a ``str``, or ``None`` when
        ``download_summary`` produced no file.
    """
    saved = download_summary(summary_text)
    if not saved:
        return None
    return str(saved)
50
+
51
+ # Create the Gradio interface
52
# NOTE(review): the nesting below is reconstructed from a flattened diff --
# confirm that the buttons belong underneath the two-column row.
with gr.Blocks() as demo:
    gr.Markdown("## Document Summarizer")

    with gr.Row():
        # Inputs: length setting and the document to summarize.
        with gr.Column():
            n_words = gr.Slider(
                label="Number of words",
                minimum=50,
                maximum=500,
                step=50,
            )
            file = gr.File(label="Submit a file")

        # Output: the generated summary.
        with gr.Column():
            output_text = gr.Textbox(
                label="Summary will be printed here",
                lines=20,
            )

    # Run the summarizer on the uploaded file and show the result.
    submit_button = gr.Button("Summarize")
    submit_button.click(
        fn=summarize,
        inputs=[file, n_words],
        outputs=output_text,
    )

    # Offer the text currently shown in the summary box as a file.
    download_button = gr.Button("Download Summary")
    summary_file = gr.File()
    download_button.click(
        fn=create_download_file,
        inputs=[output_text],
        outputs=summary_file,
    )

# Run the Gradio app
demo.launch(share=True)