loayshabet committed on
Commit
2cc6057
·
verified ·
1 Parent(s): 9dbc598

Update app.py

Files changed (1)
  1. app.py +46 -122
app.py CHANGED
@@ -1,5 +1,5 @@
  import gradio as gr
- from transformers import pipeline
  import feedparser
  from datetime import datetime, timedelta
  import pytz
@@ -8,133 +8,57 @@ import hashlib
  import threading
  import logging

  # Set up logging
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

- # Global settings
- SUMMARIZER_MODELS = {
-     "Default (facebook/bart-large-cnn)": "facebook/bart-large-cnn",
-     "Free Model (distilbart-cnn-6-6)": "sshleifer/distilbart-cnn-6-6"
- }
- CACHE_SIZE = 500
- RSS_FETCH_INTERVAL = timedelta(hours=8)
- ARTICLE_LIMIT = 5
-
- # Updated categories and news sources
- CATEGORIES = ["Technology", "Business", "Science", "World News", "Sports", "Health"]
- NEWS_SOURCES = {
-     "Technology": {
-         "TheNewYorkTimes": "https://rss.nytimes.com/services/xml/rss/nyt/Technology.xml",
-         "reutersagency": "https://www.reutersagency.com/feed/?best-topics=tech&post_type=best",
-         "alarabiya arabic": "https://www.alarabiya.net/feed/rss2/ar/technology.xml",
-     },
-     "Business": {
-         "TheNewYorkTimes": "https://rss.nytimes.com/services/xml/rss/nyt/Business.xml",
-         "reutersagency": "https://www.reutersagency.com/feed/?best-topics=business-finance&post_type=best",
-         "alwatanvoice arabic": "https://feeds.alwatanvoice.com/ar/business.xml",
-     },
-     "Science": {
-         "TheNewYorkTimes": "https://rss.nytimes.com/services/xml/rss/nyt/Science.xml"
-     },
-     "World News": {
-         "TheNewYorkTimes": "https://rss.nytimes.com/services/xml/rss/nyt/World.xml",
-         "BBC": "http://feeds.bbci.co.uk/news/world/rss.xml",
-         "CNN": "http://rss.cnn.com/rss/edition_world.rss",
-         "reutersagency": "https://www.reutersagency.com/feed/?taxonomy=best-regions&post_type=best",
-         "france24 arabic": "https://www.france24.com/ar/rss",
-         "aljazera arabic": "https://www.aljazeera.net/aljazeerarss/a7c186be-1baa-4bd4-9d80-a84db769f779/73d0e1b4-532f-45ef-b135-bfdff8b8cab9",
-     },
-     "Sports": {
-         "TheNewYorkTimes": "https://rss.nytimes.com/services/xml/rss/nyt/Sports.xml",
-         "reutersagency": "https://www.reutersagency.com/feed/?best-topics=sports&post_type=best",
-         "france24 arabic": "https://www.france24.com/ar/%D8%B1%D9%8A%D8%A7%D8%B6%D8%A9/rss",
-     },
-     "Health": {
-         "TheNewYorkTimes": "https://rss.nytimes.com/services/xml/rss/nyt/Health.xml",
-         "politico": "http://rss.politico.com/healthcare.xml",
-         "reutersagency": "https://www.reutersagency.com/feed/?best-topics=health&post_type=best"
-     },
- }
-
- class NewsCache:
-     def __init__(self, size):
-         self.cache = {}
-         self.size = size
-         self.lock = threading.Lock()
-
-     def get(self, key):
-         with self.lock:
-             return self.cache.get(key)

-     def set(self, key, value):
-         with self.lock:
-             if len(self.cache) >= self.size:
-                 oldest_key = next(iter(self.cache))
-                 del self.cache[oldest_key]
-             self.cache[key] = value
-
- cache = NewsCache(CACHE_SIZE)
-
- def fetch_rss_news(tech_sources, business_sources, science_sources, world_sources, sports_sources, health_sources):
-     articles = []
-     cutoff_time = datetime.now(pytz.UTC) - RSS_FETCH_INTERVAL
-
-     # Create a mapping of selected sources
-     category_sources = {
-         "Technology": tech_sources if tech_sources else [],
-         "Business": business_sources if business_sources else [],
-         "Science": science_sources if science_sources else [],
-         "World News": world_sources if world_sources else [],
-         "Sports": sports_sources if sports_sources else [],
-         "Health": health_sources if health_sources else []
-     }

-     logger.info(f"Selected sources: {category_sources}")
-
-     for category, sources in category_sources.items():
-         if not sources:  # Skip if no sources selected for this category
-             continue
-
-         logger.info(f"Processing category: {category} with sources: {sources}")
-
-         for source in sources:
-             if source in NEWS_SOURCES[category]:
-                 url = NEWS_SOURCES[category][source]
-                 try:
-                     logger.info(f"Fetching from URL: {url}")
-                     feed = feedparser.parse(url)
-
-                     if hasattr(feed, 'status') and feed.status != 200:
-                         logger.warning(f"Failed to fetch feed from {url}. Status: {feed.status}")
-                         continue
-
-                     for entry in feed.entries:
-                         try:
-                             published = datetime(*entry.published_parsed[:6], tzinfo=pytz.UTC)
-                             if published > cutoff_time:
-                                 articles.append({
-                                     "title": entry.title,
-                                     "description": BeautifulSoup(entry.description, "html.parser").get_text(),
-                                     "link": entry.link,
-                                     "category": category,
-                                     "source": source,
-                                     "published": published
-                                 })
-                         except (AttributeError, TypeError) as e:
-                             logger.error(f"Error processing entry: {str(e)}")
-                             continue
-
-                 except Exception as e:
-                     logger.error(f"Error fetching feed from {url}: {str(e)}")
-                     continue
-
-     logger.info(f"Total articles fetched: {len(articles)}")
-     articles = sorted(articles, key=lambda x: x["published"], reverse=True)[:ARTICLE_LIMIT]
-     return articles
-
- def summarize_text(text, model_name):
      try:
          summarizer = pipeline("summarization", model=model_name, device=-1)
          content_hash = hashlib.md5(text.encode()).hexdigest()
          cached_summary = cache.get(content_hash)
@@ -156,7 +80,7 @@ def summarize_articles(articles, model_name):
      summaries = []
      for article in articles:
          content = article["description"]
-         summary = summarize_text(content, model_name)
          summaries.append(f"""
  📰 {article['title']}
  - 📍 Category: {article['category']}
 
  import gradio as gr
+ from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
  import feedparser
  from datetime import datetime, timedelta
  import pytz
 
  import threading
  import logging

+ # Add this to your imports
+ from transformers import MarianMTModel, MarianTokenizer
+
  # Set up logging
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

+ # Add translation model configuration
+ TRANSLATION_MODEL = "Helsinki-NLP/opus-mt-ar-en"

+ class Translator:
+     def __init__(self):
+         self.model = None
+         self.tokenizer = None

+     def load_model(self):
+         if self.model is None:
+             try:
+                 self.tokenizer = MarianTokenizer.from_pretrained(TRANSLATION_MODEL)
+                 self.model = MarianMTModel.from_pretrained(TRANSLATION_MODEL)
+                 logger.info("Translation model loaded successfully")
+             except Exception as e:
+                 logger.error(f"Error loading translation model: {str(e)}")
+                 raise
+
+     def translate(self, text):
+         try:
+             self.load_model()
+             inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
+             translated = self.model.generate(**inputs)
+             return self.tokenizer.decode(translated[0], skip_special_tokens=True)
+         except Exception as e:
+             logger.error(f"Translation error: {str(e)}")
+             return text
+
+ # Initialize translator
+ translator = Translator()
+
+ # Rest of your existing configurations...
+ [Your existing SUMMARIZER_MODELS, CACHE_SIZE, RSS_FETCH_INTERVAL, ARTICLE_LIMIT, CATEGORIES, and NEWS_SOURCES definitions]
+
+ def is_arabic_source(source_name):
+     return any(arabic_indicator in source_name.lower() for arabic_indicator in ['arabic', 'alarabiya', 'aljazeera', 'alwatanvoice'])
+
+ def summarize_text(text, model_name, source):
      try:
+         # Translate if it's an Arabic source
+         if is_arabic_source(source):
+             logger.info("Translating Arabic content before summarization")
+             text = translator.translate(text)
+
          summarizer = pipeline("summarization", model=model_name, device=-1)
          content_hash = hashlib.md5(text.encode()).hexdigest()
          cached_summary = cache.get(content_hash)
 
      summaries = []
      for article in articles:
          content = article["description"]
+         summary = summarize_text(content, model_name, article['source'])
          summaries.append(f"""
  📰 {article['title']}
  - 📍 Category: {article['category']}