nishantgaurav23 committed
Commit: a88fc03
1 Parent(s): 71f7801
Update app.py
app.py CHANGED
@@ -14,6 +14,11 @@ import sys
 from llama_cpp import Llama
 from tqdm import tqdm
 
+# At the top of your script
+os.environ['LLAMA_CPP_THREADS'] = '4'
+os.environ['LLAMA_CPP_BATCH_SIZE'] = '512'
+os.environ['LLAMA_CPP_MODEL_PATH'] = os.path.join("models", "mistral-7b-v0.1.Q4_K_M.gguf")
+
 # Set page config first
 st.set_page_config(
     page_title="The Sport Chatbot",
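Note: llama-cpp-python does not appear to read custom LLAMA_CPP_* environment variables on its own; what actually configures the model here are the explicit model_path / n_threads / n_batch arguments passed to Llama() in the next hunk, so these exports can drift out of sync with llm_config. A minimal sketch, assuming you want a single source of truth (llama_kwargs_from_env is a hypothetical helper, not part of app.py):

import os

# Hypothetical helper (not in app.py): fold the LLAMA_CPP_* environment
# variables set above back into keyword arguments for Llama(), so the
# thread/batch/model-path settings are defined in exactly one place.
def llama_kwargs_from_env() -> dict:
    return {
        "model_path": os.environ.get(
            "LLAMA_CPP_MODEL_PATH",
            os.path.join("models", "mistral-7b-v0.1.Q4_K_M.gguf"),
        ),
        "n_threads": int(os.environ.get("LLAMA_CPP_THREADS", "4")),
        "n_batch": int(os.environ.get("LLAMA_CPP_BATCH_SIZE", "512")),
    }

get_llama_model() below could then build its config as {**llama_kwargs_from_env(), "n_ctx": 2048, "n_gpu_layers": 0, "verbose": False, "use_mlock": True} instead of hard-coding the same numbers twice.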
@@ -27,7 +32,28 @@ logging.basicConfig(
     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
     handlers=[logging.StreamHandler(sys.stdout)]
 )
+# Add this at the top level of your script, after imports
+@st.cache_resource
+def get_llama_model():
+    model_path = os.path.join("models", "mistral-7b-v0.1.Q4_K_M.gguf")
+    os.makedirs(os.path.dirname(model_path), exist_ok=True)
+
+    if not os.path.exists(model_path):
+        st.info("Downloading model... This may take a while.")
+        direct_url = "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_K_M.gguf"
+        download_file_with_progress(direct_url, model_path)
 
+    llm_config = {
+        "model_path": model_path,
+        "n_ctx": 2048,
+        "n_threads": 4,
+        "n_batch": 512,
+        "n_gpu_layers": 0,
+        "verbose": False,
+        "use_mlock": True
+    }
+
+    return Llama(**llm_config)
 def download_file_with_progress(url: str, filename: str):
     """Download a file with progress bar using requests"""
     response = requests.get(url, stream=True)
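Note: @st.cache_resource is what keeps the ~4 GB GGUF download and the Llama() construction from repeating on every Streamlit rerun: the decorated function runs once per server process, and every later call returns the same object, shared across reruns and sessions. A minimal sketch of that behaviour (get_expensive_object is a stand-in name, not from app.py):

import streamlit as st

@st.cache_resource
def get_expensive_object():
    # Runs only on the first call in this process; Streamlit stores the
    # return value and replays it for every later call, rerun, and session.
    print("loading...")
    return object()

a = get_expensive_object()
b = get_expensive_object()
assert a is b  # the same cached instance, not a copy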
@@ -156,7 +182,8 @@ class RAGPipeline:
         self.retriever = SentenceTransformerRetriever()
         self.documents = []
         self.device = torch.device("cpu")
-
+        # Use the cached model directly
+        self.llm = get_llama_model()
 
     def preprocess_query(self, query: str) -> str:
         """Clean and prepare the query"""
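Note: with self.llm now populated from the cached loader, the rest of RAGPipeline can call the model directly. The call site is not part of this diff; below is a hypothetical sketch using the llama-cpp-python completion API (Llama.__call__), with every name other than self.llm assumed:

# Hypothetical method (not in this diff): build a prompt from retrieved
# context and complete it with the cached Llama instance.
def generate_answer(self, query: str, context: str) -> str:
    prompt = f"Context:\n{context}\n\nQuestion: {query}\nAnswer:"
    output = self.llm(
        prompt,
        max_tokens=256,        # cap the response length
        temperature=0.7,
        stop=["Question:"],    # stop before the model invents a new question
    )
    return output["choices"][0]["text"].strip()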