MNGames committed on
Commit 486c2a9
1 Parent(s): 916031b

Update app.py

Files changed (1)
  1. app.py +10 -17
app.py CHANGED
@@ -1,27 +1,20 @@
  import gradio as gr
- from transformers import T5ForConditionalGeneration, T5Tokenizer
+ from transformers import BartForConditionalGeneration, BartTokenizer
  
- # Load the T5 model and tokenizer for question generation
- model_name = "valhalla/t5-small-qg-prepend"
- tokenizer = T5Tokenizer.from_pretrained(model_name)
- model = T5ForConditionalGeneration.from_pretrained(model_name)
+ # Load the BART model and tokenizer
+ model_name = "facebook/bart-large-cnn"
+ tokenizer = BartTokenizer.from_pretrained(model_name)
+ model = BartForConditionalGeneration.from_pretrained(model_name)
  
  def generate_questions(email_text):
-     # Prepend "generate questions: " to the input text
-     input_text = "generate questions: " + email_text
-     input_ids = tokenizer.encode(input_text, return_tensors="pt")
+     # Preprocess the email text for the BART model
+     inputs = tokenizer(email_text, return_tensors="pt", max_length=1024, truncation=True)
  
      # Generate questions
-     outputs = model.generate(
-         input_ids=input_ids,
-         max_length=512,
-         num_beams=4,
-         early_stopping=True
-     )
+     summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=50, early_stopping=True)
+     summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
  
-     # Decode the generated text
-     questions = tokenizer.decode(outputs[0], skip_special_tokens=True)
-     return questions
+     return summary
  
  # Create a Gradio interface
  iface = gr.Interface(
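Note that facebook/bart-large-cnn is a summarization checkpoint, so after this change generate_questions returns a summary of the email rather than generated questions, as the new variable name summary already suggests. For context, below is a minimal sketch of how the full updated app.py might look. The diff truncates at iface = gr.Interface(, so the interface arguments and the launch() call are illustrative assumptions, not part of this commit:

import gradio as gr
from transformers import BartForConditionalGeneration, BartTokenizer

# Load the BART model and tokenizer (a summarization checkpoint)
model_name = "facebook/bart-large-cnn"
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)

def generate_questions(email_text):
    # Preprocess the email text for the BART model (truncate to its 1024-token limit)
    inputs = tokenizer(email_text, return_tensors="pt", max_length=1024, truncation=True)

    # Generate with beam search; the output is a summary capped at 50 tokens
    summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=50, early_stopping=True)
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    return summary

# Create a Gradio interface. The arguments below are assumed,
# since the diff cuts off at "iface = gr.Interface(".
iface = gr.Interface(
    fn=generate_questions,
    inputs=gr.Textbox(lines=10, label="Email text"),
    outputs=gr.Textbox(label="Summary"),
)

if __name__ == "__main__":
    iface.launch()

Running the script serves the app locally; pasting an email into the textbox returns the decoded BART summary.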