Marroco93 committed on
Commit
bcee5ff
1 Parent(s): 10d17a3

no message

Browse files
Files changed (1) hide show
  1. main.py +16 -54
main.py CHANGED
@@ -5,7 +5,7 @@ from fastapi.responses import JSONResponse
5
  from pydantic import BaseModel
6
  from huggingface_hub import InferenceClient
7
  import uvicorn
8
- from typing import Generator
9
  import json # Asegúrate de que esta línea esté al principio del archivo
10
  import nltk
11
  import os
@@ -81,68 +81,30 @@ async def generate_text(item: Item):
81
  return StreamingResponse(generate_stream(item), media_type="application/x-ndjson")
82
 
83
 
84
- # Load spaCy model
85
- nlp = spacy.load("en_core_web_sm")
86
 
87
- class TextRequest(BaseModel):
88
- text: str
89
-
90
- def preprocess_text(text: str) -> str:
91
- # Normalize whitespace and strip punctuation
92
- text = re.sub(r'\s+', ' ', text.strip())
93
- text = re.sub(r'[^\w\s]', '', text)
94
- return text
95
-
96
- def embed_text(text: str) -> np.ndarray:
97
- # Load the JinaAI/jina-embeddings-v2-base-en model
98
- model_name = "JinaAI/jina-embeddings-v2-base-en"
99
- tokenizer = AutoTokenizer.from_pretrained(model_name)
100
- model = AutoModel.from_pretrained(model_name)
101
-
102
- inputs = tokenizer(text, return_tensors='pt')
103
- embeddings = model(**inputs).pooler_output.numpy()
104
-
105
- return embeddings
106
 
107
- def semantic_matching(text, context):
108
- text_embeddings = embed_text(text)
109
- context_embeddings = [embed_text(ctx) for ctx in context]
110
-
111
- # Calculate cosine similarity between text and context embeddings
112
- similarities = np.dot(text_embeddings, context_embeddings.T)
113
-
114
- # Find the most similar sentence in the context
115
- most_similar_idx = np.argmax(similarities)
116
-
117
- return context[most_similar_idx]
118
-
119
- def handle_endpoint(text):
120
- # Define your large context here
121
- context = [
122
- "This is a sample context sentence 1.",
123
- "Another context sentence to provide additional information.",
124
- "This context sentence introduces a new topic.",
125
- "Some additional details about the new topic are provided here.",
126
- "Context sentences can be added or removed as needed.",
127
- "The context should cover a range of topics and provide relevant information.",
128
- "Make sure the context is diverse and representative of the domain.",
129
- ]
130
-
131
- # Perform semantic matching to retrieve the most relevant portion of the context
132
- relevant_context = semantic_matching(text, context)
133
 
134
- return relevant_context
 
 
135
 
 
136
  @app.post("/process_document")
137
  async def process_document(request: TextRequest):
138
  try:
139
- processed_text = preprocess_text(request.text)
140
- embedded_text = embed_text(processed_text)
141
- relevant_context = handle_endpoint(processed_text)
 
 
 
 
142
 
143
  return {
144
- "embedded_text": embedded_text.tolist(),
145
- "relevant_context": relevant_context
146
  }
147
  except Exception as e:
148
  print(f"Error during document processing: {e}")
 
5
  from pydantic import BaseModel
6
  from huggingface_hub import InferenceClient
7
  import uvicorn
8
+ from typing import Generator, List
9
  import json # Asegúrate de que esta línea esté al principio del archivo
10
  import nltk
11
  import os
 
81
  return StreamingResponse(generate_stream(item), media_type="application/x-ndjson")
82
 
83
 
 
 
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
+ # Define request model
87
+ class TextRequest(BaseModel):
88
+ text: List[str] # Expect a list of text segments
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
+ # Load Longformer model and tokenizer
91
+ tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
92
+ model = AutoModelForSequenceClassification.from_pretrained("allenai/longformer-base-4096")
93
 
94
+ # Endpoint to process the document and return embeddings for each segment
95
  @app.post("/process_document")
96
  async def process_document(request: TextRequest):
97
  try:
98
+ embeddings_list = []
99
+ for text_segment in request.text:
100
+ # Process each segment individually
101
+ inputs = tokenizer(text_segment, return_tensors="pt", padding=True, truncation=True, max_length=4096)
102
+ outputs = model(**inputs)
103
+ embeddings = outputs.last_hidden_state.mean(dim=1).detach().numpy()
104
+ embeddings_list.append(embeddings.tolist()) # Store embeddings for each segment
105
 
106
  return {
107
+ "embeddings": embeddings_list
 
108
  }
109
  except Exception as e:
110
  print(f"Error during document processing: {e}")