Futuresony commited on
Commit
40845a0
Β·
verified Β·
1 Parent(s): 49a3d08

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -81
app.py CHANGED
@@ -1,91 +1,25 @@
1
- import gradio as gr
2
- import os
3
  import faiss
4
- import json
5
- import torch
6
  import numpy as np
7
- from huggingface_hub import InferenceClient, hf_hub_download
8
  from sentence_transformers import SentenceTransformer
9
 
10
- # Hugging Face Credentials
11
- HF_REPO = "Futuresony/future_ai_12_10_2024.gguf" # Your model repo
12
- FAISS_REPO = "Futuresony/future_faiss_index" # FAISS repo
13
- HF_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN') # Use your token
14
-
15
- # Load Chat Model
16
- client = InferenceClient(
17
- model=HF_REPO,
18
- token=HF_TOKEN
19
- )
20
-
21
- # Load Sentence Transformer Model for FAISS
22
- embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
23
-
24
- # Load FAISS Index from Hugging Face
25
- FAISS_PATH = hf_hub_download(repo_id=FAISS_REPO, filename="asa_faiss.index", repo_type="model", token=HF_TOKEN)
26
  faiss_index = faiss.read_index(FAISS_PATH)
27
 
28
- # Load FAISS Text Data
29
- TEXT_DATA_PATH = hf_hub_download(repo_id=FAISS_REPO, filename="asa_text_data.npy", repo_type="model", token=HF_TOKEN)
30
- text_data = np.load(TEXT_DATA_PATH, allow_pickle=True)
31
-
32
- def retrieve_faiss_knowledge(user_query, top_k=3):
33
- """Retrieve the most relevant FAISS knowledge based on user input."""
34
- query_embedding = embedder.encode([user_query], convert_to_tensor=True).cpu().numpy()
35
- distances, indices = faiss_index.search(query_embedding, top_k)
36
-
37
- retrieved_texts = []
38
- print("\nπŸ” DEBUG: FAISS Retrieved Indices and Distances")
39
- print(indices, distances) # πŸ”₯ Check if FAISS is retrieving valid results
40
-
41
- for idx in indices[0]: # Extract top_k results
42
- if idx != -1: # Ensure valid index
43
- retrieved_texts.append(text_data[idx]) # βœ… Retrieve actual stored FAISS text!
44
-
45
- return "\n".join(retrieved_texts) if retrieved_texts else "**No relevant FAISS data found.**"
46
-
47
- def format_alpaca_prompt(user_input, system_prompt, history, faiss_knowledge=""):
48
- """Formats input in Alpaca/LLaMA style"""
49
- history_str = "\n".join([f"### Instruction:\n{h[0]}\n### Response:\n{h[1]}" for h in history])
50
- faiss_context = f"\n### Retrieved Knowledge:\n{faiss_knowledge}" if faiss_knowledge else ""
51
-
52
- prompt = f"""{system_prompt}
53
- {history_str}
54
-
55
- ### Instruction:
56
- {user_input}
57
- {faiss_context}
58
-
59
- ### Response:
60
- """
61
- return prompt
62
-
63
- def respond(message, history, system_message, max_tokens, temperature, top_p):
64
- faiss_knowledge = retrieve_faiss_knowledge(message, top_k=3) # βœ… Get FAISS data
65
- formatted_prompt = format_alpaca_prompt(message, system_message, history, faiss_knowledge)
66
 
67
- response = client.text_generation(
68
- formatted_prompt,
69
- max_new_tokens=max_tokens,
70
- temperature=temperature,
71
- top_p=top_p,
72
- )
73
 
74
- cleaned_response = response.split("### Response:")[-1].strip()
75
-
76
- history.append((message, cleaned_response)) # βœ… Update history with new response
77
-
78
- yield cleaned_response # βœ… Return chatbot's answer
79
 
80
- demo = gr.ChatInterface(
81
- respond,
82
- additional_inputs=[
83
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
84
- gr.Slider(minimum=1, maximum=250, value=128, step=1, label="Max new tokens"),
85
- gr.Slider(minimum=0.1, maximum=4.0, value=0.9, step=0.1, label="Temperature"),
86
- gr.Slider(minimum=0.1, maximum=1.0, value=0.99, step=0.01, label="Top-p (nucleus sampling)"),
87
- ],
88
- )
89
 
90
- if __name__ == "__main__":
91
- demo.launch()
 
 
 
1
  import faiss
 
 
2
  import numpy as np
 
3
  from sentence_transformers import SentenceTransformer
4
 
5
+ # πŸ”Ή Load FAISS Index
6
+ FAISS_PATH = "asa_faiss.index"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  faiss_index = faiss.read_index(FAISS_PATH)
8
 
9
+ # πŸ”Ή Load Embedding Model
10
+ embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ # πŸ”Ή Test Query
13
+ test_query = "Where is ASA Microfinance located?"
14
+ query_embedding = embedder.encode([test_query], convert_to_tensor=True).cpu().numpy()
 
 
 
15
 
16
+ # πŸ”Ή Search FAISS
17
+ distances, indices = faiss_index.search(query_embedding, 3) # Retrieve top 3 matches
 
 
 
18
 
19
+ # πŸ”Ή Print Results
20
+ print("πŸ”Ή FAISS Search Results:")
21
+ for idx in indices[0]:
22
+ print(f"Index: {idx}")
 
 
 
 
 
23
 
24
+ print("πŸ”Ή FAISS Distances:")
25
+ print(distances)