habulaj commited on
Commit
7b3ef9f
·
verified ·
1 Parent(s): 1516a70

Update routers/textclas.py

Browse files
Files changed (1) hide show
  1. routers/textclas.py +22 -27
routers/textclas.py CHANGED
@@ -1,17 +1,8 @@
1
  from fastapi import APIRouter, Query, HTTPException
2
- from sentence_transformers import SentenceTransformer
3
- import pickle
4
- from sklearn.metrics.pairwise import cosine_similarity
5
 
6
- # Carrega o modelo de embeddings
7
- model = SentenceTransformer('all-MiniLM-L6-v2')
8
-
9
- # Carrega os embeddings e palavras-chave pré-calculados
10
- with open('keywords_embeddings.pkl', 'rb') as f:
11
- keywords_embeddings = pickle.load(f)
12
-
13
- with open('keywords_list.pkl', 'rb') as f:
14
- keyword_categories = pickle.load(f)
15
 
16
  router = APIRouter()
17
 
@@ -21,27 +12,31 @@ def extract_keywords(
21
  num_keywords: int = Query(5, description="Número de palavras-chave a serem retornadas", ge=1, le=20)
22
  ):
23
  """
24
- Extrai palavras-chave relevantes de um texto com base em similaridade semântica.
25
  """
26
  try:
27
- # Gera o embedding do texto
28
- text_embedding = model.encode([text])
29
-
30
- # Calcula a similaridade entre o texto e as palavras-chave
31
- similarities = cosine_similarity(text_embedding, keywords_embeddings)
32
-
33
- # Ordena as palavras-chave com base na similaridade
34
- sorted_indices = similarities[0].argsort()[::-1]
35
-
36
- # Retorna as palavras-chave com maior similaridade
37
- top_keywords = [keyword_categories[i] for i in sorted_indices[:num_keywords]]
38
-
39
  return {
40
  "text": text,
41
  "num_keywords": num_keywords,
42
- "keywords": top_keywords
43
  }
44
-
 
 
 
 
 
45
  except Exception as e:
46
  raise HTTPException(
47
  status_code=500,
 
1
  from fastapi import APIRouter, Query, HTTPException
2
+ from keybert import KeyBERT
 
 
3
 
4
+ # Inicializa o modelo KeyBERT com DistilBERT (mais rápido que BERT completo)
5
+ kw_model = KeyBERT(model='distilbert-base-nli-mean-tokens')
 
 
 
 
 
 
 
6
 
7
  router = APIRouter()
8
 
 
12
  num_keywords: int = Query(5, description="Número de palavras-chave a serem retornadas", ge=1, le=20)
13
  ):
14
  """
15
+ Extrai palavras-chave relevantes de um texto.
16
  """
17
  try:
18
+ # Extrai palavras-chave
19
+ keywords = kw_model.extract_keywords(
20
+ text,
21
+ keyphrase_ngram_range=(1, 2),
22
+ stop_words='english',
23
+ top_n=num_keywords
24
+ )
25
+
26
+ # Formata o retorno
27
+ keyword_list = [kw[0] for kw in keywords]
28
+
 
29
  return {
30
  "text": text,
31
  "num_keywords": num_keywords,
32
+ "keywords": keyword_list
33
  }
34
+
35
+ except ValueError as ve:
36
+ raise HTTPException(
37
+ status_code=400,
38
+ detail=f"Invalid input: {str(ve)}"
39
+ )
40
  except Exception as e:
41
  raise HTTPException(
42
  status_code=500,