sofiia19 commited on
Commit
b967ca3
1 Parent(s): 7d0112b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +150 -0
app.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from litellm import completion
3
+ import os
4
+
5
+ os.environ['GROQ_API_KEY'] = "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"
6
+ response = completion(
7
+ model="groq/llama3-8b-8192",
8
+ messages=[
9
+ {"role": "user", "content": "hello from litellm"}
10
+ ],
11
+ )
12
+ from datasets import load_dataset
13
+
14
+ dataset = load_dataset("hugginglearners/russia-ukraine-conflict-articles")
15
+
16
+
17
+ docs = [item['articles'] for item in dataset['train'].select(range(10))]
18
+ def chunk_document(doc: str, doc_id: int, desired_chunk_size: int = 100, max_chunk_size: int = 3000):
19
+ chunk = ''
20
+ chunk_number = 0
21
+ for line in doc.splitlines():
22
+ chunk += line + '\n'
23
+ if len(chunk) >= desired_chunk_size:
24
+ yield (doc_id, chunk_number, chunk[:max_chunk_size])
25
+ chunk = ''
26
+ chunk_number += 1
27
+ if chunk:
28
+ yield (doc_id, chunk_number, chunk)
29
+
30
+ def chunk_documents(docs: List[str], desired_chunk_size: int = 100, max_chunk_size: int = 3000):
31
+ chunks = []
32
+ for doc_id, doc in enumerate(docs):
33
+ chunks.extend(chunk_document(doc, doc_id, desired_chunk_size, max_chunk_size))
34
+ return chunks
35
+
36
+ from typing import List
37
+ import numpy as np
38
+ from rank_bm25 import BM25Okapi
39
+ from sentence_transformers import SentenceTransformer
40
+ import torch
41
+ class Retriever:
42
+ def __init__(self, docs: List[str]):
43
+
44
+ self.chunks = chunk_documents(docs)
45
+ self.docs = [chunk[2] for chunk in self.chunks]
46
+ tokenized_docs = [doc.lower().split(" ") for doc in self.docs]
47
+ self.bm25 = BM25Okapi(tokenized_docs)
48
+ self.sbert = SentenceTransformer('sentence-transformers/all-distilroberta-v1')
49
+ self.doc_embeddings = self.sbert.encode(self.docs)
50
+
51
+ def get_docs(self, query, method="bm25", n=3):
52
+ if method == "bm25":
53
+ scores = self._get_bm25_scores(query)
54
+ elif method == "sbert":
55
+ scores = self._get_semantic_scores(query)
56
+ elif method == "hybrid":
57
+ bm25_scores = self._get_bm25_scores(query)
58
+ semantic_scores = self._get_semantic_scores(query)
59
+ scores = 0.3 * bm25_scores + 0.7 * semantic_scores
60
+ else:
61
+ raise ValueError("Invalid method. Choose 'bm25', 'sbert', or 'hybrid'.")
62
+
63
+ sorted_indices = np.argsort(scores)[::-1]
64
+ # Повертаємо перші n документів із інформацією про джерело
65
+ return [(self.chunks[i][0], self.chunks[i][1], self.docs[i]) for i in sorted_indices[:n]]
66
+
67
+ def _get_bm25_scores(self, query):
68
+ tokenized_query = query.lower().split(" ")
69
+ return self.bm25.get_scores(tokenized_query)
70
+
71
+ def _get_semantic_scores(self, query):
72
+ query_embedding = self.sbert.encode(query)
73
+ scores = torch.cosine_similarity(
74
+ torch.tensor(query_embedding).unsqueeze(0),
75
+ torch.tensor(self.doc_embeddings),
76
+ dim=1
77
+ )
78
+ return scores.numpy()
79
+ class QuestionAnsweringBot:
80
+ PROMPT = '''\
81
+ You are a helpful assistant that can answer questions.
82
+
83
+ Rules:
84
+ -Reply with the answer only and nothing but the answer.
85
+ -Say 'I don't know(((' if you don't know the answer.
86
+ -Use the provided context.
87
+ '''
88
+
89
+ def __init__(self, docs):
90
+ self.retriever = Retriever(docs)
91
+
92
+ def answer_question(self, question: str, method: str = "bm25") -> str:
93
+ context_with_indices = self.retriever.get_docs(question, method=method)
94
+ if not context_with_indices:
95
+ return "I don't know((("
96
+
97
+ # контекст для моделі
98
+ context = "\n".join([f"Doc {doc_id}, Chunk {chunk_id}: {text}" for doc_id, chunk_id, text in context_with_indices])
99
+
100
+ messages = [
101
+ {"role": "system", "content": self.PROMPT},
102
+ {"role": "user", "content": f"Context: {context}\nQuestion: {question}"}
103
+ ]
104
+
105
+ try:
106
+
107
+ completionn = completion(
108
+ model="groq/llama3-8b-8192",
109
+ messages=messages,
110
+ )
111
+ # Відповідь
112
+ answer = completionn['choices'][0]['message']['content']
113
+
114
+ # джерела
115
+ sources = [f"Doc {doc_id}: Chunk {chunk_id}; " for doc_id, chunk_id, _ in context_with_indices]
116
+ return f"{answer} [{', '.join(sources)}]"
117
+ except Exception as e:
118
+ return f"Error: {str(e)}"
119
+
120
+
121
+ # question = "Tell about war"
122
+ docs = docs
123
+ # bot = QuestionAnsweringBot(docs)
124
+ # answer = bot.answer_question(question)
125
+
126
+ # print(f'Q: {question}')
127
+ # print(f'A: {answer}')
128
+ import gradio as gr
129
+
130
+ def answer_question_with_method(query, method):
131
+ bot = QuestionAnsweringBot(docs)
132
+ return bot.answer_question(query, method=method)
133
+
134
+
135
+ # Створення інтерфейсу
136
+ demo = gr.Interface(
137
+ fn=answer_question_with_method,
138
+ inputs=[
139
+ gr.Textbox(label="Your Question"),
140
+ gr.Dropdown(
141
+ choices=["bm25", "sbert", "hybrid"],
142
+ value="hybrid",
143
+ label="Select Retrieval Method"
144
+ )
145
+ ],
146
+ outputs="text"
147
+ )
148
+
149
+ demo.launch()
150
+