JaiSurya committed on
Commit
431117f
1 Parent(s): 66d373d

Upload 3 files


Basic functionalities

Files changed (3)
  1. app.py +19 -57
  2. embeddings.csv +0 -0
  3. rag.py +97 -0
app.py CHANGED
@@ -1,58 +1,20 @@
- import gradio as gr
- from huggingface_hub import InferenceClient
- import spaces
-
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
- @spaces.GPU
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
-
-
- if __name__ == "__main__":
+ import gradio as gr
+ from huggingface_hub import InferenceClient
+ import spaces
+ from rag import RAG
+
+ r = RAG()
+
+ @spaces.GPU
+ def respond(text, history):
+     return r.query(text)
+
+ demo = gr.ChatInterface(
+     respond,
+     title="LAW LM",
+     description="Ask legal questions",
+     chatbot=gr.Chatbot(placeholder="Type your text here...")
+ )
+
+ if __name__ == "__main__":
      demo.launch()
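Note on the new handler: gr.ChatInterface calls respond(text, history) with the latest user message and the running chat history, and displays whatever string the function returns. The committed handler ignores history, so each question is answered independently by RAG.query. A minimal stand-alone sketch of the same wiring, with a hypothetical stub in place of RAG (not part of this commit):

import gradio as gr

def respond(text, history):
    # history holds prior (user, assistant) turns supplied by ChatInterface;
    # like the committed handler, this stub does not use it.
    return f"echo: {text}"  # the real app returns RAG().query(text)

demo = gr.ChatInterface(
    respond,
    title="LAW LM (stub)",
    description="Ask legal questions",
)

if __name__ == "__main__":
    demo.launch()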
embeddings.csv ADDED
The diff for this file is too large to render. See raw diff
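Since the diff cannot be rendered, here is a hedged sketch of how a compatible embeddings.csv could be produced. The column names ("sentence_chunk", "embeddings") and the embedding model ("all-mpnet-base-v2") are taken from rag.py below; the example chunks and output path are assumptions. Note also that rag.py reads "data/embeddings.csv" while this commit uploads embeddings.csv at the repository root.

# Hypothetical generator for an embeddings.csv that rag.py's loader can parse.
# rag.py splits each "embeddings" cell on "[" / "]" and then on commas, so
# storing each vector as a plain Python list string round-trips cleanly.
import pandas as pd
from sentence_transformers import SentenceTransformer

chunks = [
    "Example legal text chunk 1 ...",  # placeholder sentence chunks
    "Example legal text chunk 2 ...",
]

model = SentenceTransformer("all-mpnet-base-v2")
vectors = model.encode(chunks, convert_to_tensor=True)

df = pd.DataFrame({
    "sentence_chunk": chunks,
    "embeddings": [str(v.tolist()) for v in vectors],
})
df.to_csv("embeddings.csv", index=False)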
 
rag.py ADDED
@@ -0,0 +1,97 @@
+ import torch
+ import numpy as np
+ import pandas as pd
+ from sentence_transformers import SentenceTransformer, util
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+
+ class RAG:
+
+     def __init__(self):
+         self.model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         local_dir = "llm_models/"
+
+         self.embedding_model_name = "all-mpnet-base-v2"
+         self.embeddings_filename = "data/embeddings.csv"
+
+         self.data_pd = pd.read_csv(self.embeddings_filename)
+         self.data_dict = pd.read_csv(self.embeddings_filename).to_dict(orient='records')
+
+         self.data_embeddings = self.get_embeddings()
+
+         # Embedding model
+         self.embedding_model = SentenceTransformer(model_name_or_path=self.embedding_model_name, device=self.device)
+         # Tokenizer
+         self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=self.model_id,
+                                                        cache_dir=local_dir)
+         # LLM
+         self.llm_model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=self.model_id,
+                                                               cache_dir=local_dir).to(self.device)
+
+     def get_embeddings(self) -> list:
+         """Returns the embeddings from the csv file"""
+         data_embeddings = []
+
+         for tensor_str in self.data_pd["embeddings"]:
+             values_str = tensor_str.split("[")[1].split("]")[0]
+             values_list = [float(val) for val in values_str.split(",")]
+             tensor_result = torch.tensor(values_list)
+             data_embeddings.append(tensor_result)
+
+         data_embeddings = torch.stack(data_embeddings).to(self.device)
+         return data_embeddings
+
+
+     def retrieve_relevant_resource(self, user_query: str, k=5):
+         """Function to retrieve relevant resource"""
+         query_embedding = self.embedding_model.encode(user_query, convert_to_tensor=True).to(self.device)
+         dot_score = util.dot_score(a=query_embedding, b=self.data_embeddings)[0]
+         score, idx = torch.topk(dot_score, k=k)
+         return score, idx
+
+     def prompt_formatter(self, query: str, context_items: list[dict]) -> str:
+         """
+         Augments query with text-based context from context_items.
+         """
+         # Join context items into one dotted paragraph
+         context = "- " + "\n- ".join([item["sentence_chunk"] for item in context_items])
+
+         base_prompt = """You are a friendly lawyer chatbot who always responds in the style of a judge
+ Based on the following context items, please answer the query.
+ \nNow use the following context items to answer the user query:
+ {context}
+ \nRelevant passages: <extract relevant passages from the context here>"""
+
+         # Update base prompt with context items and query
+         base_prompt = base_prompt.format(context=context)
+
+         # Create prompt template for instruction-tuned model
+         dialogue_template = [
+             {
+                 "role": "system",
+                 "content": base_prompt,
+             },
+             {
+                 "role": "user",
+                 "content": query,
+             },
+         ]
+
+         # Apply the chat template
+         prompt = self.tokenizer.apply_chat_template(conversation=dialogue_template,
+                                                     tokenize=False,
+                                                     add_generation_prompt=True)
+         return prompt
+
+     def query(self, user_text: str):
+         scores, indices = self.retrieve_relevant_resource(user_text)
+         context_items = [self.data_dict[i] for i in indices]
+         prompt = self.prompt_formatter(query=user_text, context_items=context_items)
+         input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.device)
+         outputs = self.llm_model.generate(**input_ids, max_new_tokens=512)
+         output_text = self.tokenizer.decode(outputs[0])
+         output_text = output_text.split("<|assistant|>")
+         output_text = output_text[1].split("</s>")[0]
+
+         return output_text
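A quick local sanity check of the class above, mirroring how app.py's respond() uses it. The example question is made up, and it assumes the embeddings CSV exists at the path rag.py expects and that the TinyLlama weights can be downloaded:

# Hypothetical smoke test for rag.py.
from rag import RAG

r = RAG()

# Retrieval step alone: top-5 dot-product scores and row indices into the CSV.
scores, indices = r.retrieve_relevant_resource("What is the punishment for theft?")
print(scores, indices)

# Full retrieve -> prompt -> generate round trip.
print(r.query("What is the punishment for theft?"))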