import torch
import pandas as pd
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForCausalLM


class RAG:

    def __init__(self):
        self.model_id = "microsoft/Phi-3-mini-128k-instruct"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.embedding_model_name = "all-mpnet-base-v2"
        self.embeddings_filename = "embeddings.csv"

        # Load the pre-computed chunks and their embeddings once, keeping both a
        # DataFrame (for parsing the embedding strings) and a list of dicts (for lookup).
        self.data_pd = pd.read_csv(self.embeddings_filename)
        self.data_dict = self.data_pd.to_dict(orient="records")

        self.data_embeddings = self.get_embeddings()

        # Embedding model
        self.embedding_model = SentenceTransformer(
            model_name_or_path=self.embedding_model_name, device=self.device
        )
        # Tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=self.model_id)
        # LLM
        self.llm_model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path=self.model_id, trust_remote_code=True
        ).to(self.device)

    def get_embeddings(self) -> torch.Tensor:
        """Parses the embeddings stored as strings in the csv file and returns them
        as a single (num_chunks, embedding_dim) tensor on the target device."""
        data_embeddings = []

        for tensor_str in self.data_pd["embeddings"]:
            # Extract the comma-separated values between the square brackets
            values_str = tensor_str.split("[")[1].split("]")[0]
            values_list = [float(val) for val in values_str.split(",")]
            data_embeddings.append(torch.tensor(values_list))

        return torch.stack(data_embeddings).to(self.device)
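
    # NOTE: the parser above assumes "embeddings.csv" has a "sentence_chunk" column and an
    # "embeddings" column whose values contain the vector inside square brackets, e.g.
    # "[0.1, 0.2, ...]" or "tensor([0.1, 0.2, ...])". A minimal sketch of how such a file
    # could be built (the chunk texts here are purely illustrative, not part of this repo):
    #
    #     chunks = ["First passage of the source document...", "Second passage..."]
    #     model = SentenceTransformer("all-mpnet-base-v2")
    #     rows = [{"sentence_chunk": c, "embeddings": str(model.encode(c).tolist())} for c in chunks]
    #     pd.DataFrame(rows).to_csv("embeddings.csv", index=False)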


    def retrieve_relevant_resource(self, user_query: str, k: int = 5):
        """Embeds the query and returns the dot-product scores and indices of the top-k chunks."""
        query_embedding = self.embedding_model.encode(user_query, convert_to_tensor=True).to(self.device)
        dot_scores = util.dot_score(a=query_embedding, b=self.data_embeddings)[0]
        scores, indices = torch.topk(dot_scores, k=k)
        return scores, indices

    def prompt_formatter(self, query: str, context_items: list[dict]) -> str:
        """
        Augments the user query with text-based context from context_items and
        applies the model's chat template.
        """
        # Join the retrieved chunks into one bulleted string
        context = "- " + "\n- ".join([item["sentence_chunk"] for item in context_items])

        base_prompt = """You are a friendly lawyer chatbot who always responds in the style of a judge.
        Based on the following context items, please answer the query.
        \nNow use the following context items to answer the user query:
        {context}
        \nRelevant passages: <extract relevant passages from the context here>"""

        # Insert the retrieved context into the system prompt; the query goes in the user turn
        base_prompt = base_prompt.format(context=context)

        # Create prompt template for instruction-tuned model
        dialogue_template = [
            {
                "role": "system",
                "content": base_prompt,
            },
            {
                "role": "user",
                "content": query,
            },
        ]

        # Apply the chat template
        prompt = self.tokenizer.apply_chat_template(conversation=dialogue_template,
                                                    tokenize=False,
                                                    add_generation_prompt=True)
        return prompt

    def query(self, user_text: str):
        """Retrieves the most relevant chunks, builds the augmented prompt and generates an answer."""
        scores, indices = self.retrieve_relevant_resource(user_text)
        context_items = [self.data_dict[int(i)] for i in indices]
        prompt = self.prompt_formatter(query=user_text, context_items=context_items)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        outputs = self.llm_model.generate(**inputs, max_new_tokens=512)
        output_text = self.tokenizer.decode(outputs[0])
        # Keep only the assistant's reply, dropping the prompt and trailing special tokens
        output_text = output_text.split("<|assistant|>")[1].split("</s>")[0]

        return output_text
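

# Minimal usage sketch (assumes "embeddings.csv" exists alongside this file and that the
# machine can load microsoft/Phi-3-mini-128k-instruct; the example question is illustrative):
if __name__ == "__main__":
    rag = RAG()
    answer = rag.query("What does the contract say about termination?")
    print(answer)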