Spaces:
Paused
Paused
try to change embedding model
Browse files
app.py
CHANGED
@@ -1,6 +1,9 @@
|
|
|
|
1 |
import gradio as gr
|
2 |
from langchain.chains import RetrievalQA
|
3 |
-
from langchain.embeddings import OpenAIEmbeddings
|
|
|
|
|
4 |
from langchain.llms import OpenAI
|
5 |
from langchain.chat_models import ChatOpenAI
|
6 |
from langchain.vectorstores import Qdrant
|
@@ -16,7 +19,16 @@ PERSIST_DIR_NAME = "nvdajp-book"
|
|
16 |
|
17 |
|
18 |
def get_retrieval_qa(model_name: str | None, temperature: int, option: str | None) -> RetrievalQA:
|
19 |
-
embeddings = OpenAIEmbeddings()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
db_url, db_api_key, db_collection_name = DB_CONFIG
|
21 |
client = QdrantClient(url=db_url, api_key=db_api_key)
|
22 |
db = Qdrant(client=client, collection_name=db_collection_name, embeddings=embeddings)
|
@@ -36,7 +48,7 @@ def get_retrieval_qa(model_name: str | None, temperature: int, option: str | Non
|
|
36 |
"filter": {"category": option},
|
37 |
}
|
38 |
)
|
39 |
-
|
40 |
llm=ChatOpenAI(
|
41 |
model=model,
|
42 |
temperature=temperature
|
@@ -45,6 +57,7 @@ def get_retrieval_qa(model_name: str | None, temperature: int, option: str | Non
|
|
45 |
retriever=retriever,
|
46 |
return_source_documents=True,
|
47 |
)
|
|
|
48 |
|
49 |
|
50 |
def get_related_url(metadata):
|
|
|
1 |
+
from time import time
|
2 |
import gradio as gr
|
3 |
from langchain.chains import RetrievalQA
|
4 |
+
# from langchain.embeddings import OpenAIEmbeddings
|
5 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
6 |
+
from langchain.embeddings import GPT4AllEmbeddings
|
7 |
from langchain.llms import OpenAI
|
8 |
from langchain.chat_models import ChatOpenAI
|
9 |
from langchain.vectorstores import Qdrant
|
|
|
19 |
|
20 |
|
21 |
def get_retrieval_qa(model_name: str | None, temperature: int, option: str | None) -> RetrievalQA:
|
22 |
+
# embeddings = OpenAIEmbeddings()
|
23 |
+
model_name = "sentence-transformers/all-mpnet-base-v2"
|
24 |
+
model_kwargs = {'device': 'cpu'}
|
25 |
+
encode_kwargs = {'normalize_embeddings': False}
|
26 |
+
embeddings = HuggingFaceEmbeddings(
|
27 |
+
model_name=model_name,
|
28 |
+
model_kwargs=model_kwargs,
|
29 |
+
encode_kwargs=encode_kwargs,
|
30 |
+
)
|
31 |
+
# embeddings = GPT4AllEmbeddings()
|
32 |
db_url, db_api_key, db_collection_name = DB_CONFIG
|
33 |
client = QdrantClient(url=db_url, api_key=db_api_key)
|
34 |
db = Qdrant(client=client, collection_name=db_collection_name, embeddings=embeddings)
|
|
|
48 |
"filter": {"category": option},
|
49 |
}
|
50 |
)
|
51 |
+
result = RetrievalQA.from_chain_type(
|
52 |
llm=ChatOpenAI(
|
53 |
model=model,
|
54 |
temperature=temperature
|
|
|
57 |
retriever=retriever,
|
58 |
return_source_documents=True,
|
59 |
)
|
60 |
+
return result
|
61 |
|
62 |
|
63 |
def get_related_url(metadata):
|
config.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import os
|
2 |
|
3 |
|
4 |
-
SAAS =
|
5 |
|
6 |
|
7 |
def get_db_config():
|
|
|
1 |
import os
|
2 |
|
3 |
|
4 |
+
SAAS = False
|
5 |
|
6 |
|
7 |
def get_db_config():
|
store.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
from langchain.document_loaders import ReadTheDocsLoader
|
2 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
3 |
-
from langchain.embeddings import OpenAIEmbeddings
|
|
|
|
|
4 |
from langchain.vectorstores import Qdrant
|
5 |
# from qdrant_client import QdrantClient
|
6 |
from nvda_ug_loader import NVDAUserGuideLoader
|
@@ -35,7 +37,16 @@ def get_text_chunk(docs):
|
|
35 |
|
36 |
|
37 |
def store(texts):
|
38 |
-
embeddings = OpenAIEmbeddings()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
db_url, db_api_key, db_collection_name = DB_CONFIG
|
40 |
# client = QdrantClient(url=db_url, api_key=db_api_key, prefer_grpc=True)
|
41 |
_ = Qdrant.from_documents(
|
|
|
1 |
from langchain.document_loaders import ReadTheDocsLoader
|
2 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
3 |
+
# from langchain.embeddings import OpenAIEmbeddings
|
4 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
5 |
+
from langchain.embeddings import GPT4AllEmbeddings
|
6 |
from langchain.vectorstores import Qdrant
|
7 |
# from qdrant_client import QdrantClient
|
8 |
from nvda_ug_loader import NVDAUserGuideLoader
|
|
|
37 |
|
38 |
|
39 |
def store(texts):
|
40 |
+
# embeddings = OpenAIEmbeddings()
|
41 |
+
model_name = "sentence-transformers/all-mpnet-base-v2"
|
42 |
+
model_kwargs = {'device': 'cuda'}
|
43 |
+
encode_kwargs = {'normalize_embeddings': False}
|
44 |
+
embeddings = HuggingFaceEmbeddings(
|
45 |
+
model_name=model_name,
|
46 |
+
model_kwargs=model_kwargs,
|
47 |
+
encode_kwargs=encode_kwargs,
|
48 |
+
)
|
49 |
+
# embeddings = GPT4AllEmbeddings()
|
50 |
db_url, db_api_key, db_collection_name = DB_CONFIG
|
51 |
# client = QdrantClient(url=db_url, api_key=db_api_key, prefer_grpc=True)
|
52 |
_ = Qdrant.from_documents(
|