kayrakan committed
Commit 1da5cdb
1 Parent(s): c25e3ee

new version

Files changed (3)
  1. .gitignore +1 -0
  2. __pycache__/app.cpython-312.pyc +0 -0
  3. app.py +83 -47
.gitignore ADDED
@@ -0,0 +1 @@
+/chroma_db/*
__pycache__/app.cpython-312.pyc CHANGED
Binary files a/__pycache__/app.cpython-312.pyc and b/__pycache__/app.cpython-312.pyc differ
 
app.py CHANGED
@@ -1,68 +1,104 @@
+import asyncio
+import logging
 from fastapi import FastAPI, HTTPException, UploadFile, File, Form
-from pydantic import BaseModel
 from sentence_transformers import SentenceTransformer
 import chromadb
-import os
 import json
+from typing import List
+from functools import lru_cache
 
 app = FastAPI()
 
-# Load the model
-model_name = "all-MiniLM-L6-v2"  # A popular model for sentence embeddings
+# Load the multilingual model
+model_name = "paraphrase-multilingual-mpnet-base-v2"  # This model supports 50+ languages
 model = SentenceTransformer(model_name)
 
-# Initialize ChromaDB
-chroma_client = chromadb.Client()
-collection_name = "json_lines"
+# Initialize persistent ChromaDB
+chroma_client = chromadb.PersistentClient(path="./chroma_db")
+collection_name = "transcriptions"
 
+# Setup logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 
-def embed_text(text):
-    return model.encode(text).tolist()  # Generate embeddings
+@lru_cache(maxsize=1000)
+def embed_text(text: str):
+    return model.encode(text).tolist()
 
+async def process_batch(batch, collection, start_index):
+    embeddings = []
+    metadatas = []
+    ids = []
+
+    for i, item in enumerate(batch):
+        text = item['text']
+        embedding = embed_text(text)
+        embeddings.append(embedding)
+        metadatas.append({
+            "text": text,
+            "duration": item.get("duration"),
+            "offset": item.get("offset"),
+            "lang": item.get("lang")
+        })
+        # Create a unique ID using the start_index and the current item index
+        unique_id = f"{start_index + i}_{item.get('offset')}"
+        ids.append(unique_id)
+
+    collection.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
 
 @app.post("/generate")
-async def generate_text(sentence: str = Form(...), file: UploadFile = File(...)):
-    contents = await file.read()
-    lines = json.loads(contents)
-
-    # Check if the collection exists before attempting to delete it
-    try:
-        chroma_client.delete_collection(collection_name)
-    except ValueError as e:
-        if "does not exist" in str(e):
-            pass  # Ignore the error if the collection does not exist
-
-    # Recreate the collection
-    collection = chroma_client.get_or_create_collection(collection_name)
-
-    # Process each line and store the embeddings in ChromaDB
-    for i, line in enumerate(lines):
-        text = line['text']  # Adjust this according to your JSON structure
-        embedding = embed_text(text)
-        metadata = {
-            "id": i,
-            "text": text,
-            "duration": line.get("duration"),
-            "lang": line.get("lang"),
-            "offset": line.get("offset")
-        }
-        collection.add(embeddings=[embedding], metadatas=[metadata], ids=[str(i)])
-
-    # Embed the query sentence
-    query_embedding = embed_text(sentence)
-
-    # Perform search in ChromaDB
-    results = collection.query(query_embeddings=[query_embedding], n_results=5)  # Adjust n_results as needed
-
-    # Extract relevant lines from results
-    relevant_lines = results["metadatas"][0]
-
-    # Clear the collection after finding relevant lines
-    chroma_client.delete_collection(collection_name)
-
-    return {"relevant_lines": relevant_lines}
+async def generate_text(sentence: str = Form(...), file: UploadFile = File(...), delete_after: bool = True):
+    try:
+        contents = await file.read()
+        transcription = json.loads(contents)
+
+        # Get or create the collection
+        try:
+            collection = chroma_client.get_collection(collection_name)
+            # Clear existing data
+            collection.delete(where={})
+        except ValueError:
+            collection = chroma_client.create_collection(collection_name)
+
+        # Process in batches
+        batch_size = 100
+        tasks = []
+        for i in range(0, len(transcription), batch_size):
+            batch = transcription[i:i + batch_size]
+            task = asyncio.create_task(process_batch(batch, collection, i))
+            tasks.append(task)
+        await asyncio.gather(*tasks)
+
+        # Embed the query sentence
+        query_embedding = embed_text(sentence)
+
+        # Perform search in ChromaDB
+        results = collection.query(query_embeddings=[query_embedding], n_results=5)
+
+        # Extract relevant lines from results
+        relevant_lines = results["metadatas"][0]
+
+        logger.info(f"Query results: {relevant_lines}")
+
+        return {"relevant_lines": relevant_lines}
+
+    except Exception as e:
+        logger.error(f"Error during processing: {e}")
+        raise HTTPException(status_code=500, detail="Internal Server Error")
+    finally:
+        if delete_after:
+            # Clear the collection after finding relevant lines or if an error occurred
+            try:
+                chroma_client.delete_collection(collection_name)
+            except ValueError:
+                pass  # Collection might have already been deleted or doesn't exist
 
 @app.get("/")
 def greet_json():
     return {"Hello": "World!"}
+
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=8000)
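
Note for trying the new endpoint: the sketch below is a hypothetical client, not part of the commit. It assumes the app is running locally via python app.py (port 8000, per the uvicorn.run call above) and that the requests package is installed; the segment keys (text, duration, offset, lang) mirror what process_batch reads. Because delete_after is a plain bool parameter rather than a Form(...) field, FastAPI exposes it as a query parameter.

# sketch_client.py -- hypothetical; exercises POST /generate on a local instance
import json
import requests

# A small transcript in the shape process_batch expects: a JSON array of segments.
# Some chromadb releases raise when n_results (5 here) exceeds the collection
# size, so the sample provides five segments.
segments = [
    {"text": "Welcome to the show.", "duration": 2.0, "offset": 0.0, "lang": "en"},
    {"text": "Today we cover sentence embeddings.", "duration": 3.0, "offset": 2.0, "lang": "en"},
    {"text": "ChromaDB stores the vectors.", "duration": 2.5, "offset": 5.0, "lang": "en"},
    {"text": "FastAPI serves the search endpoint.", "duration": 2.5, "offset": 7.5, "lang": "en"},
    {"text": "Thanks for listening.", "duration": 1.5, "offset": 10.0, "lang": "en"},
]

response = requests.post(
    "http://localhost:8000/generate",
    params={"delete_after": "false"},  # plain bool argument -> query string
    data={"sentence": "What is this episode about?"},  # Form(...) field
    files={"file": ("transcript.json", json.dumps(segments), "application/json")},
)
print(response.json())  # {"relevant_lines": [...metadata of the top 5 matches...]}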
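
Two behaviors in the new version are worth verifying against the chromadb release in use. First, collection.delete(where={}) may be rejected: recent chromadb versions require a non-empty filter for delete, in which case the "clear existing data" branch raises instead of clearing. A version-agnostic sketch, reusing the chroma_client and collection_name defined in app.py, is to drop and recreate the collection:

# Sketch: clear the collection by dropping and recreating it, instead of
# delete(where={}), which some chromadb releases reject as an empty filter.
try:
    chroma_client.delete_collection(collection_name)
except ValueError:
    pass  # the collection did not exist yet
collection = chroma_client.get_or_create_collection(collection_name)

Second, process_batch awaits nothing: model.encode is synchronous, so wrapping the calls in asyncio.create_task and gathering them still runs the batches one after another and blocks the event loop while they run. If overlap is actually wanted, the blocking work has to leave the loop, for example via the default thread pool. A sketch of that pattern, where encode_and_store is a hypothetical synchronous stand-in for the per-batch embed-and-add work (whether SentenceTransformer.encode tolerates concurrent calls should be checked for the chosen model):

import asyncio

def encode_and_store(batch, collection, start_index):
    ...  # hypothetical stand-in: the synchronous embed + collection.add work

async def run_batches(transcription, collection, batch_size=100):
    # asyncio.to_thread (Python 3.9+) runs each blocking call in a worker
    # thread, so the event loop stays free to serve other requests.
    await asyncio.gather(*(
        asyncio.to_thread(encode_and_store, transcription[i:i + batch_size], collection, i)
        for i in range(0, len(transcription), batch_size)
    ))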