mouadenna commited on
Commit
d755dc2
1 Parent(s): 0bea081

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -20
app.py CHANGED
@@ -1,16 +1,16 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqGeneration
3
- import torch
4
 
5
- # Initialize model and tokenizer
6
- tokenizer = AutoTokenizer.from_pretrained("plguillou/t5-base-fr-sum-cnndm")
7
- model = AutoModelForSeq2SeqGeneration.from_pretrained("plguillou/t5-base-fr-sum-cnndm")
8
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
9
- model = model.to(device)
 
10
 
11
  def generate_summary(text: str, min_length: int = 100, max_length: int = 256) -> str:
12
  """
13
- Generate a summary of the input text using the T5 model
14
 
15
  Args:
16
  text (str): Input text to summarize
@@ -20,15 +20,9 @@ def generate_summary(text: str, min_length: int = 100, max_length: int = 256) ->
20
  Returns:
21
  str: Generated summary
22
  """
23
- # Tokenize the input text
24
- inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
25
- input_ids = inputs.input_ids.to(device)
26
- attention_mask = inputs.attention_mask.to(device)
27
-
28
- # Generate summary
29
- output = model.generate(
30
- input_ids,
31
- attention_mask=attention_mask,
32
  max_length=max_length,
33
  min_length=min_length,
34
  num_beams=4,
@@ -40,9 +34,8 @@ def generate_summary(text: str, min_length: int = 100, max_length: int = 256) ->
40
  repetition_penalty=1.2
41
  )
42
 
43
- # Decode and return the summary
44
- summary = tokenizer.decode(output[0], skip_special_tokens=True)
45
- return summary
46
 
47
  # Create the Gradio interface
48
  with gr.Blocks(title="French Text Summarizer") as demo:
 
1
  import gradio as gr
2
+ from transformers import pipeline
 
3
 
4
+ # Initialize the summarization pipeline
5
+ summarizer = pipeline(
6
+ "summarization",
7
+ model="plguillou/t5-base-fr-sum-cnndm",
8
+ device="cuda" if gr.device=="cuda" else "cpu"
9
+ )
10
 
11
  def generate_summary(text: str, min_length: int = 100, max_length: int = 256) -> str:
12
  """
13
+ Generate a summary of the input text using the pipeline
14
 
15
  Args:
16
  text (str): Input text to summarize
 
20
  Returns:
21
  str: Generated summary
22
  """
23
+ # Generate summary using the pipeline
24
+ summary = summarizer(
25
+ text,
 
 
 
 
 
 
26
  max_length=max_length,
27
  min_length=min_length,
28
  num_beams=4,
 
34
  repetition_penalty=1.2
35
  )
36
 
37
+ # Return the generated summary text
38
+ return summary[0]['summary_text']
 
39
 
40
  # Create the Gradio interface
41
  with gr.Blocks(title="French Text Summarizer") as demo: