Spaces:
Paused
Paused
dev/streamlit-cache (#13)
Browse files- added streamlit cache (60434a827f1979f49eea9913f3bfc0c6f58a2957)
- app.py +34 -22
- requirements.txt +1 -0
app.py
CHANGED
@@ -20,26 +20,37 @@ from qdrant_client import QdrantClient
|
|
20 |
from config import DB_CONFIG, DB_E5_CONFIG
|
21 |
|
22 |
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
if False and torch.cuda.is_available(): # TODO: for local debug
|
33 |
-
RINNA_MODEL_NAME = "rinna/bilingual-gpt-neox-4b-instruction-ppo"
|
34 |
-
RINNA_TOKENIZER = AutoTokenizer.from_pretrained(RINNA_MODEL_NAME, use_fast=False)
|
35 |
-
RINNA_MODEL = AutoModelForCausalLM.from_pretrained(
|
36 |
-
RINNA_MODEL_NAME,
|
37 |
-
load_in_8bit=True,
|
38 |
-
torch_dtype=torch.float16,
|
39 |
-
device_map="auto",
|
40 |
)
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
|
45 |
def _get_config_and_embeddings(collection_name: str | None) -> tuple:
|
@@ -54,7 +65,8 @@ def _get_config_and_embeddings(collection_name: str | None) -> tuple:
|
|
54 |
return db_config, embeddings
|
55 |
|
56 |
|
57 |
-
|
|
|
58 |
if RINNA_MODEL is not None:
|
59 |
pipe = pipeline(
|
60 |
"text-generation",
|
@@ -95,7 +107,7 @@ def get_retrieval_qa(
|
|
95 |
model_name: str | None,
|
96 |
temperature: float,
|
97 |
option: str | None,
|
98 |
-
)
|
99 |
db_config, embeddings = _get_config_and_embeddings(collection_name)
|
100 |
db_url, db_api_key, db_collection_name = db_config
|
101 |
client = QdrantClient(url=db_url, api_key=db_api_key)
|
@@ -125,7 +137,7 @@ def get_retrieval_qa(
|
|
125 |
return result
|
126 |
|
127 |
|
128 |
-
def get_related_url(metadata):
|
129 |
urls = set()
|
130 |
for m in metadata:
|
131 |
# p = m['source']
|
|
|
20 |
from config import DB_CONFIG, DB_E5_CONFIG
|
21 |
|
22 |
|
23 |
+
@st.cache_resource
|
24 |
+
def load_e5_embeddings():
|
25 |
+
model_name = "intfloat/multilingual-e5-large"
|
26 |
+
model_kwargs = {"device": "cuda:0" if torch.cuda.is_available() else "cpu"}
|
27 |
+
encode_kwargs = {"normalize_embeddings": False}
|
28 |
+
embeddings = HuggingFaceEmbeddings(
|
29 |
+
model_name=model_name,
|
30 |
+
model_kwargs=model_kwargs,
|
31 |
+
encode_kwargs=encode_kwargs,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
)
|
33 |
+
return embeddings
|
34 |
+
|
35 |
+
|
36 |
+
@st.cache_resource
|
37 |
+
def load_rinna_model():
|
38 |
+
if torch.cuda.is_available():
|
39 |
+
model_name = "rinna/bilingual-gpt-neox-4b-instruction-ppo"
|
40 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
|
41 |
+
model = AutoModelForCausalLM.from_pretrained(
|
42 |
+
model_name,
|
43 |
+
load_in_8bit=True,
|
44 |
+
torch_dtype=torch.float16,
|
45 |
+
device_map="auto",
|
46 |
+
)
|
47 |
+
return tokenizer, model
|
48 |
+
else:
|
49 |
+
return None, None
|
50 |
+
|
51 |
+
|
52 |
+
E5_EMBEDDINGS = load_e5_embeddings()
|
53 |
+
RINNA_TOKENIZER, RINNA_MODEL = load_rinna_model()
|
54 |
|
55 |
|
56 |
def _get_config_and_embeddings(collection_name: str | None) -> tuple:
|
|
|
65 |
return db_config, embeddings
|
66 |
|
67 |
|
68 |
+
@st.cache_resource
|
69 |
+
def _get_rinna_llm(temperature: float) -> HuggingFacePipeline | None:
|
70 |
if RINNA_MODEL is not None:
|
71 |
pipe = pipeline(
|
72 |
"text-generation",
|
|
|
107 |
model_name: str | None,
|
108 |
temperature: float,
|
109 |
option: str | None,
|
110 |
+
):
|
111 |
db_config, embeddings = _get_config_and_embeddings(collection_name)
|
112 |
db_url, db_api_key, db_collection_name = db_config
|
113 |
client = QdrantClient(url=db_url, api_key=db_api_key)
|
|
|
137 |
return result
|
138 |
|
139 |
|
140 |
+
def get_related_url(metadata) -> Iterable[str]:
|
141 |
urls = set()
|
142 |
for m in metadata:
|
143 |
# p = m['source']
|
requirements.txt
CHANGED
@@ -10,3 +10,4 @@ accelerate
|
|
10 |
bitsandbytes
|
11 |
scipy
|
12 |
sentence_transformers
|
|
|
|
10 |
bitsandbytes
|
11 |
scipy
|
12 |
sentence_transformers
|
13 |
+
streamlit
|