terapyon commited on
Commit
1e444f6
·
1 Parent(s): ed3c145

dev/streamlit-cache (#13)

Browse files

- added streamlit cache (60434a827f1979f49eea9913f3bfc0c6f58a2957)

Files changed (2) hide show
  1. app.py +34 -22
  2. requirements.txt +1 -0
app.py CHANGED
@@ -20,26 +20,37 @@ from qdrant_client import QdrantClient
20
  from config import DB_CONFIG, DB_E5_CONFIG
21
 
22
 
23
- E5_MODEL_NAME = "intfloat/multilingual-e5-large"
24
- E5_MODEL_KWARGS = {"device": "cuda:0" if torch.cuda.is_available() else "cpu"}
25
- E5_ENCODE_KWARGS = {"normalize_embeddings": False}
26
- E5_EMBEDDINGS = HuggingFaceEmbeddings(
27
- model_name=E5_MODEL_NAME,
28
- model_kwargs=E5_MODEL_KWARGS,
29
- encode_kwargs=E5_ENCODE_KWARGS,
30
- )
31
-
32
- if False and torch.cuda.is_available(): # TODO: for local debug
33
- RINNA_MODEL_NAME = "rinna/bilingual-gpt-neox-4b-instruction-ppo"
34
- RINNA_TOKENIZER = AutoTokenizer.from_pretrained(RINNA_MODEL_NAME, use_fast=False)
35
- RINNA_MODEL = AutoModelForCausalLM.from_pretrained(
36
- RINNA_MODEL_NAME,
37
- load_in_8bit=True,
38
- torch_dtype=torch.float16,
39
- device_map="auto",
40
  )
41
- else:
42
- RINNA_MODEL = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
 
45
  def _get_config_and_embeddings(collection_name: str | None) -> tuple:
@@ -54,7 +65,8 @@ def _get_config_and_embeddings(collection_name: str | None) -> tuple:
54
  return db_config, embeddings
55
 
56
 
57
- def _get_rinna_llm(temperature: float):
 
58
  if RINNA_MODEL is not None:
59
  pipe = pipeline(
60
  "text-generation",
@@ -95,7 +107,7 @@ def get_retrieval_qa(
95
  model_name: str | None,
96
  temperature: float,
97
  option: str | None,
98
- ) -> RetrievalQA:
99
  db_config, embeddings = _get_config_and_embeddings(collection_name)
100
  db_url, db_api_key, db_collection_name = db_config
101
  client = QdrantClient(url=db_url, api_key=db_api_key)
@@ -125,7 +137,7 @@ def get_retrieval_qa(
125
  return result
126
 
127
 
128
- def get_related_url(metadata):
129
  urls = set()
130
  for m in metadata:
131
  # p = m['source']
 
20
  from config import DB_CONFIG, DB_E5_CONFIG
21
 
22
 
23
+ @st.cache_resource
24
+ def load_e5_embeddings():
25
+ model_name = "intfloat/multilingual-e5-large"
26
+ model_kwargs = {"device": "cuda:0" if torch.cuda.is_available() else "cpu"}
27
+ encode_kwargs = {"normalize_embeddings": False}
28
+ embeddings = HuggingFaceEmbeddings(
29
+ model_name=model_name,
30
+ model_kwargs=model_kwargs,
31
+ encode_kwargs=encode_kwargs,
 
 
 
 
 
 
 
 
32
  )
33
+ return embeddings
34
+
35
+
36
+ @st.cache_resource
37
+ def load_rinna_model():
38
+ if torch.cuda.is_available():
39
+ model_name = "rinna/bilingual-gpt-neox-4b-instruction-ppo"
40
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
41
+ model = AutoModelForCausalLM.from_pretrained(
42
+ model_name,
43
+ load_in_8bit=True,
44
+ torch_dtype=torch.float16,
45
+ device_map="auto",
46
+ )
47
+ return tokenizer, model
48
+ else:
49
+ return None, None
50
+
51
+
52
+ E5_EMBEDDINGS = load_e5_embeddings()
53
+ RINNA_TOKENIZER, RINNA_MODEL = load_rinna_model()
54
 
55
 
56
  def _get_config_and_embeddings(collection_name: str | None) -> tuple:
 
65
  return db_config, embeddings
66
 
67
 
68
+ @st.cache_resource
69
+ def _get_rinna_llm(temperature: float) -> HuggingFacePipeline | None:
70
  if RINNA_MODEL is not None:
71
  pipe = pipeline(
72
  "text-generation",
 
107
  model_name: str | None,
108
  temperature: float,
109
  option: str | None,
110
+ ):
111
  db_config, embeddings = _get_config_and_embeddings(collection_name)
112
  db_url, db_api_key, db_collection_name = db_config
113
  client = QdrantClient(url=db_url, api_key=db_api_key)
 
137
  return result
138
 
139
 
140
+ def get_related_url(metadata) -> Iterable[str]:
141
  urls = set()
142
  for m in metadata:
143
  # p = m['source']
requirements.txt CHANGED
@@ -10,3 +10,4 @@ accelerate
10
  bitsandbytes
11
  scipy
12
  sentence_transformers
 
 
10
  bitsandbytes
11
  scipy
12
  sentence_transformers
13
+ streamlit