terapyon committed on
Commit
6125df0
·
1 Parent(s): 9022e07

try to change embedding model

Browse files
Files changed (3) hide show
  1. app.py +16 -3
  2. config.py +1 -1
  3. store.py +13 -2
app.py CHANGED
@@ -1,6 +1,9 @@
 
1
  import gradio as gr
2
  from langchain.chains import RetrievalQA
3
- from langchain.embeddings import OpenAIEmbeddings
 
 
4
  from langchain.llms import OpenAI
5
  from langchain.chat_models import ChatOpenAI
6
  from langchain.vectorstores import Qdrant
@@ -16,7 +19,16 @@ PERSIST_DIR_NAME = "nvdajp-book"
16
 
17
 
18
  def get_retrieval_qa(model_name: str | None, temperature: int, option: str | None) -> RetrievalQA:
19
- embeddings = OpenAIEmbeddings()
 
 
 
 
 
 
 
 
 
20
  db_url, db_api_key, db_collection_name = DB_CONFIG
21
  client = QdrantClient(url=db_url, api_key=db_api_key)
22
  db = Qdrant(client=client, collection_name=db_collection_name, embeddings=embeddings)
@@ -36,7 +48,7 @@ def get_retrieval_qa(model_name: str | None, temperature: int, option: str | Non
36
  "filter": {"category": option},
37
  }
38
  )
39
- return RetrievalQA.from_chain_type(
40
  llm=ChatOpenAI(
41
  model=model,
42
  temperature=temperature
@@ -45,6 +57,7 @@ def get_retrieval_qa(model_name: str | None, temperature: int, option: str | Non
45
  retriever=retriever,
46
  return_source_documents=True,
47
  )
 
48
 
49
 
50
  def get_related_url(metadata):
 
1
+ from time import time
2
  import gradio as gr
3
  from langchain.chains import RetrievalQA
4
+ # from langchain.embeddings import OpenAIEmbeddings
5
+ from langchain.embeddings import HuggingFaceEmbeddings
6
+ from langchain.embeddings import GPT4AllEmbeddings
7
  from langchain.llms import OpenAI
8
  from langchain.chat_models import ChatOpenAI
9
  from langchain.vectorstores import Qdrant
 
19
 
20
 
21
  def get_retrieval_qa(model_name: str | None, temperature: int, option: str | None) -> RetrievalQA:
22
+ # embeddings = OpenAIEmbeddings()
23
+ model_name = "sentence-transformers/all-mpnet-base-v2"
24
+ model_kwargs = {'device': 'cpu'}
25
+ encode_kwargs = {'normalize_embeddings': False}
26
+ embeddings = HuggingFaceEmbeddings(
27
+ model_name=model_name,
28
+ model_kwargs=model_kwargs,
29
+ encode_kwargs=encode_kwargs,
30
+ )
31
+ # embeddings = GPT4AllEmbeddings()
32
  db_url, db_api_key, db_collection_name = DB_CONFIG
33
  client = QdrantClient(url=db_url, api_key=db_api_key)
34
  db = Qdrant(client=client, collection_name=db_collection_name, embeddings=embeddings)
 
48
  "filter": {"category": option},
49
  }
50
  )
51
+ result = RetrievalQA.from_chain_type(
52
  llm=ChatOpenAI(
53
  model=model,
54
  temperature=temperature
 
57
  retriever=retriever,
58
  return_source_documents=True,
59
  )
60
+ return result
61
 
62
 
63
  def get_related_url(metadata):
config.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
 
3
 
4
- SAAS = True
5
 
6
 
7
  def get_db_config():
 
1
  import os
2
 
3
 
4
+ SAAS = False
5
 
6
 
7
  def get_db_config():
store.py CHANGED
@@ -1,6 +1,8 @@
1
  from langchain.document_loaders import ReadTheDocsLoader
2
  from langchain.text_splitter import RecursiveCharacterTextSplitter
3
- from langchain.embeddings import OpenAIEmbeddings
 
 
4
  from langchain.vectorstores import Qdrant
5
  # from qdrant_client import QdrantClient
6
  from nvda_ug_loader import NVDAUserGuideLoader
@@ -35,7 +37,16 @@ def get_text_chunk(docs):
35
 
36
 
37
  def store(texts):
38
- embeddings = OpenAIEmbeddings()
 
 
 
 
 
 
 
 
 
39
  db_url, db_api_key, db_collection_name = DB_CONFIG
40
  # client = QdrantClient(url=db_url, api_key=db_api_key, prefer_grpc=True)
41
  _ = Qdrant.from_documents(
 
1
  from langchain.document_loaders import ReadTheDocsLoader
2
  from langchain.text_splitter import RecursiveCharacterTextSplitter
3
+ # from langchain.embeddings import OpenAIEmbeddings
4
+ from langchain.embeddings import HuggingFaceEmbeddings
5
+ from langchain.embeddings import GPT4AllEmbeddings
6
  from langchain.vectorstores import Qdrant
7
  # from qdrant_client import QdrantClient
8
  from nvda_ug_loader import NVDAUserGuideLoader
 
37
 
38
 
39
  def store(texts):
40
+ # embeddings = OpenAIEmbeddings()
41
+ model_name = "sentence-transformers/all-mpnet-base-v2"
42
+ model_kwargs = {'device': 'cuda'}
43
+ encode_kwargs = {'normalize_embeddings': False}
44
+ embeddings = HuggingFaceEmbeddings(
45
+ model_name=model_name,
46
+ model_kwargs=model_kwargs,
47
+ encode_kwargs=encode_kwargs,
48
+ )
49
+ # embeddings = GPT4AllEmbeddings()
50
  db_url, db_api_key, db_collection_name = DB_CONFIG
51
  # client = QdrantClient(url=db_url, api_key=db_api_key, prefer_grpc=True)
52
  _ = Qdrant.from_documents(