nishantgaurav23 committed
Commit a88fc03
Parent: 71f7801

Update app.py

Files changed (1): app.py +28 -1
app.py CHANGED
@@ -14,6 +14,11 @@ import sys
 from llama_cpp import Llama
 from tqdm import tqdm
 
+# At the top of your script
+os.environ['LLAMA_CPP_THREADS'] = '4'
+os.environ['LLAMA_CPP_BATCH_SIZE'] = '512'
+os.environ['LLAMA_CPP_MODEL_PATH'] = os.path.join("models", "mistral-7b-v0.1.Q4_K_M.gguf")
+
 # Set page config first
 st.set_page_config(
     page_title="The Sport Chatbot",
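A note on this hunk: as far as I can tell, llama-cpp-python does not read LLAMA_CPP_THREADS, LLAMA_CPP_BATCH_SIZE, or LLAMA_CPP_MODEL_PATH on its own; the values only matter where the script passes them to the Llama constructor explicitly, and get_llama_model below hardcodes n_threads=4 and n_batch=512 independently. A minimal sketch of a hypothetical helper (not part of this commit) that reads the variables back so the two spots cannot drift apart:

import os

def llama_kwargs_from_env() -> dict:
    # Hypothetical helper: llama-cpp-python does not pick these env vars up
    # by itself, so convert them back into Llama(...) keyword arguments.
    return {
        "model_path": os.environ["LLAMA_CPP_MODEL_PATH"],
        "n_threads": int(os.environ.get("LLAMA_CPP_THREADS", "4")),
        "n_batch": int(os.environ.get("LLAMA_CPP_BATCH_SIZE", "512")),
    }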
@@ -27,7 +32,28 @@ logging.basicConfig(
     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
     handlers=[logging.StreamHandler(sys.stdout)]
 )
+# Add this at the top level of your script, after imports
+@st.cache_resource
+def get_llama_model():
+    model_path = os.path.join("models", "mistral-7b-v0.1.Q4_K_M.gguf")
+    os.makedirs(os.path.dirname(model_path), exist_ok=True)
+
+    if not os.path.exists(model_path):
+        st.info("Downloading model... This may take a while.")
+        direct_url = "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_K_M.gguf"
+        download_file_with_progress(direct_url, model_path)
 
+    llm_config = {
+        "model_path": model_path,
+        "n_ctx": 2048,
+        "n_threads": 4,
+        "n_batch": 512,
+        "n_gpu_layers": 0,
+        "verbose": False,
+        "use_mlock": True
+    }
+
+    return Llama(**llm_config)
 def download_file_with_progress(url: str, filename: str):
     """Download a file with progress bar using requests"""
     response = requests.get(url, stream=True)
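@st.cache_resource makes Streamlit run get_llama_model once per server process and hand the same Llama object back on every rerun and session, so the multi-gigabyte GGUF file is loaded (and, with use_mlock=True, pinned in RAM) a single time; n_gpu_layers=0 keeps inference on the CPU. The hunk shows only the signature and first line of download_file_with_progress; a sketch of what a streaming tqdm downloader along these lines typically looks like, not necessarily the exact body in app.py:

import requests
from tqdm import tqdm

def download_file_with_progress(url: str, filename: str):
    """Download a file with progress bar using requests"""
    # Stream the response and update a byte-level progress bar per chunk.
    response = requests.get(url, stream=True)
    response.raise_for_status()
    total = int(response.headers.get("content-length", 0))
    with open(filename, "wb") as f, tqdm(total=total, unit="B", unit_scale=True) as bar:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)
            bar.update(len(chunk))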
@@ -156,7 +182,8 @@ class RAGPipeline:
         self.retriever = SentenceTransformerRetriever()
         self.documents = []
         self.device = torch.device("cpu")
-        self.llm = load_llama_model()
+        # Use the cached model directly
+        self.llm = get_llama_model()
 
     def preprocess_query(self, query: str) -> str:
         """Clean and prepare the query"""
 