UnnamedUnknownx1234987789489 committed on
Commit d355eed
1 Parent(s): 14fdc76

Create functions.py

Files changed (1)
  1. functions.py +820 -0
functions.py ADDED
@@ -0,0 +1,820 @@
import os
import uuid
from typing import Annotated

import streamlit as st
from typing_extensions import TypedDict, List
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain.schema import Document
from langgraph.graph import START, END, StateGraph
from langchain.prompts import PromptTemplate
from langchain_groq import ChatGroq
from langchain_community.utilities import GoogleSerperAPIWrapper
from langchain_chroma import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.output_parsers import StrOutputParser
from langchain_community.document_loaders.directory import DirectoryLoader
from langchain.document_loaders import TextLoader
from exa_py import Exa


# LangSmith tracing configuration. API keys (LANGCHAIN_API_KEY, GROQ_API_KEY,
# SERPER_API_KEY, EXA_API_KEY) are expected in the environment; secrets should
# never be hard-coded in source.
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGCHAIN_PROJECT"] = "Lithuanian_law_v2_LT_Kalba_Groq"

exa = Exa(api_key=os.environ["EXA_API_KEY"])


def create_retriever_from_chroma(vectorstore_path="docs/chroma/", search_type='mmr', k=7, chunk_size=300, chunk_overlap=30):
    model_name = "Alibaba-NLP/gte-multilingual-base"
    # gte-multilingual-base ships custom model code, so remote code must be trusted
    model_kwargs = {'device': 'cpu', 'trust_remote_code': True}
    encode_kwargs = {'normalize_embeddings': True}
    embeddings = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs
    )

    # Reuse the persisted vector store if it already exists; otherwise build it
    # from the .txt files in ./data/
    if os.path.exists(vectorstore_path) and os.listdir(vectorstore_path):
        vectorstore = Chroma(persist_directory=vectorstore_path, embedding_function=embeddings)
    else:
        st.write("Vector store doesn't exist and will be created now")
        loader = DirectoryLoader('./data/', glob="./*.txt", loader_cls=TextLoader)
        docs = loader.load()

        text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap,
            separators=["\n\n \n\n", "\n\n\n", "\n\n", r"In \[[0-9]+\]", r"\n+", r"\s+"],
            is_separator_regex=True
        )
        split_docs = text_splitter.split_documents(docs)

        vectorstore = Chroma.from_documents(
            documents=split_docs, embedding=embeddings, persist_directory=vectorstore_path,
        )

    retriever = vectorstore.as_retriever(search_type=search_type, search_kwargs={"k": k})

    return retriever

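
# A minimal convenience sketch (an assumption, not used elsewhere in this
# module): the retriever is expensive to build (embedding model + Chroma),
# so a Streamlit app would normally construct it once per process and cache it.
@st.cache_resource
def get_cached_retriever():
    return create_retriever_from_chroma(vectorstore_path="docs/chroma/", k=7)
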
def handle_userinput(user_question, custom_graph):
    # Add the user's question to the chat history and display it in the UI
    st.session_state.messages.append({"role": "user", "content": user_question})
    st.chat_message("user").write(user_question)

    # Generate a unique thread ID for the graph's state
    config = {"configurable": {"thread_id": str(uuid.uuid4())}}

    try:
        # Invoke the custom graph with the input question
        state_dict = custom_graph.invoke(
            {"question": user_question, "steps": []}, config
        )

        # Show the retrieved context documents in the sidebar
        docs = state_dict["documents"]
        with st.sidebar:
            st.subheader("Dokumentai, kuriuos Birutė gavo kaip kontekstą")
            with st.spinner("Processing"):
                for doc in docs:
                    st.write(f"Dokumentas: {doc}")

        # Check if a response (generation) was produced by the graph
        if 'generation' in state_dict and state_dict['generation']:
            response = state_dict["generation"]

            # Add the assistant's response to the chat history and display it
            st.session_state.messages.append({"role": "assistant", "content": response})
            st.chat_message("assistant").write(response)
        else:
            st.chat_message("assistant").write("Your question violates toxicity rules or contains sensitive information.")

    except Exception:
        # Display an error message in case of failure
        st.chat_message("assistant").write("Klaida: Arba per didelis kontekstas suteiktas modeliui, arba užklausų serveryje yra per daug")


def create_workflow(retriever):
    class GraphState(TypedDict):
        """
        Represents the state of our graph.
        Attributes:
            question: question
            generation: LLM generation
            search: whether to add search
            documents: list of documents
            steps: list of executed steps
            generation_count: generations count
        """
        question: Annotated[str, "Single"]  # Ensuring only one value per step
        generation: str
        search: str
        documents: List[str]
        steps: List[str]
        generation_count: int

    # Main answer generator plus a smaller checker model used for grading
    llm = ChatGroq(
        model="llama-3.3-70b-versatile",
        temperature=0.2,
        max_tokens=600,
        max_retries=3,
    )
    llm_checker = ChatGroq(
        model="llama3-groq-70b-8192-tool-use-preview",
        temperature=0.1,
        max_tokens=400,
        max_retries=3,
    )

    workflow = StateGraph(GraphState)

    # Define the nodes
    workflow.add_node("ask_question", lambda state: ask_question(state))
    workflow.add_node("retrieve", lambda state: retrieve(state, retriever))
    workflow.add_node("grade_documents", lambda state: grade_documents(state, retrieval_grader_grader(llm_checker)))
    workflow.add_node("generate", lambda state: generate(state, QA_chain(llm)))
    workflow.add_node("web_search", web_search)
    #workflow.add_node("transform_query", lambda state: transform_query(state, create_question_rewriter(llm)))

    # Build graph
    workflow.set_entry_point("ask_question")
    workflow.add_edge("ask_question", "retrieve")
    workflow.add_edge("retrieve", "grade_documents")
    #workflow.add_edge("retrieve", "generate")

    workflow.add_conditional_edges(
        "grade_documents",
        decide_to_generate,
        {
            "search": "web_search",
            "generate": "generate",
        },
    )

    workflow.add_edge("web_search", "generate")
    workflow.add_edge("generate", END)

    custom_graph = workflow.compile()

    return custom_graph

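
# A minimal sketch of driving the compiled graph directly (assumptions: the
# API keys are set in the environment, and this helper is not called anywhere
# in this module; it mirrors what handle_userinput() does through Streamlit).
def _demo_invoke(question="Koks yra PVM tarifas Lietuvoje?"):
    graph = create_workflow(create_retriever_from_chroma())
    state = graph.invoke(
        {"question": question, "steps": []},
        {"configurable": {"thread_id": str(uuid.uuid4())}},
    )
    return state.get("generation")
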
def retrieval_grader_grader(llm):
    """
    Function to create a grader object using a passed LLM model.

    Args:
        llm: The language model to be used for grading.

    Returns:
        Callable: A pipeline function that grades relevance based on the LLM.
    """
    class GradeDocuments(BaseModel):
        """Ar faktas gali būti, nors truputį, naudingas atsakant į klausimą."""
        binary_score: str = Field(
            description="Dokumentai yra aktualūs klausimui, 'yes' arba 'no'"
        )

    # Create the structured LLM grader using the passed LLM
    structured_llm_grader = llm.with_structured_output(GradeDocuments)

    # Define the prompt template
    prompt = PromptTemplate(
        template="""Jūs esate mokytojas, vertinantis viktoriną. Jums bus suteikta:
        1/ KLAUSIMAS {question}
        2/ Studento pateiktas FAKTAS {documents}

        Jūs vertinate RELEVANCE RECALL:
        yes reiškia, kad FAKTAS yra susijęs su KLAUSIMU.
        no reiškia, kad FAKTAS nesusijęs su KLAUSIMU.
        yes yra aukščiausias (geriausias) balas. no yra žemiausias balas, kurį galite duoti.
        Jeigu iš Studento pateikto FAKTO galima gauti bet kokių įžvalgų, susijusių su KLAUSIMU, duokite įvertinimą yes.

        Žingsnis po žingsnio paaiškinkite savo samprotavimus. Įsitikinkite, kad jūsų samprotavimai ir išvados yra teisingi.

        Iš pradžių venkite tiesiog nurodyti teisingą atsakymą.

        Klausimas: {question} \n
        FAKTAS: \n\n {documents} \n\n

        Suteikite dvejetainį balą „yes“ arba „no“, kad nurodytumėte, ar dokumentas yra susijęs su klausimu. \n
        Pateikite dvejetainį balą kaip JSON su vienu raktu „binary_score“ ir be įžangos ar paaiškinimo.
        """,
        input_variables=['documents', 'question'],
    )

    # Combine the prompt with the structured LLM grader
    retrieval_grader = prompt | structured_llm_grader

    # Return the grader object
    return retrieval_grader

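
# Shape of the grader's output (illustrative values; the grader returns a
# GradeDocuments instance, so callers read the decision from .binary_score):
#   grader = retrieval_grader_grader(llm)
#   verdict = grader.invoke({"question": "...", "documents": "..."})
#   verdict.binary_score  # "yes" or "no"
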
def create_toxicity_checker(llm):
    """
    Function to create a toxicity checker object using a passed LLM model.

    Args:
        llm: The language model to be used for toxicity checking.

    Returns:
        Callable: A pipeline function that checks for toxicity based on the LLM.
    """

    # Define the class for the toxicity check inside the function
    class PromptGuard(BaseModel):
        """Binary score for toxicity check on question."""
        score: str = Field(
            description="Ar klausimas yra nesaugus, 'taip' arba 'ne'"
        )

    # Create the structured LLM toxicity checker using the passed LLM
    structured_llm_toxicity_checker = llm.with_structured_output(PromptGuard)

    # Define the prompt template
    prompt = PromptTemplate(
        template="""Jūs esate klasifikatorius, kuris įvertins tekstą pagal pateiktus kriterijus.
        Čia pateikiami vertinimo kriterijai:

        (1) įžeidinėjimai

        (2) grasinimai

        (3) labai neigiami komentarai

        (4) bet kokia asmenį identifikuojanti informacija, pvz., API raktai

        Rezultatas:

        „taip“ reiškia, kad tekstas atitinka BET KOKĮ kriterijų.

        „ne“ reiškia, kad tekstas neatitinka nė vieno kriterijaus.

        Štai klausimas: {question}

        Suteikite dvejetainį balą „taip“ arba „ne“, kad nurodytumėte, ar klausimas yra nesaugus.
        Pateikite dvejetainį balą kaip JSON su vienu raktu „score“ ir be įžangos ar paaiškinimo.""",
        input_variables=["question"],
    )

    # Combine the prompt with the structured LLM toxicity checker
    toxicity_grader = prompt | structured_llm_toxicity_checker

    # Return the toxicity checker object
    return toxicity_grader

def grade_question_toxicity(state, toxicity_grader):
    """
    Grades the question for toxicity.

    Args:
        state (dict): The current graph state.

    Returns:
        str: 'good' if the question passes the toxicity check, 'bad' otherwise.
    """
    steps = state["steps"]
    steps.append("prompt guard")
    score = toxicity_grader.invoke({"question": state["question"]})
    grade = getattr(score, 'score', None)

    # The prompt asks for 'taip'/'ne' (Lithuanian), so accept both spellings
    if grade in ("yes", "taip"):
        return "bad"
    else:
        return "good"

def create_helpfulness_checker(llm):
    """
    Function to create a helpfulness checker object using a passed LLM model.

    Args:
        llm: The language model to be used for checking the helpfulness of answers.

    Returns:
        Callable: A pipeline function that checks if the student's answer is helpful.
    """

    class HelpfulnessChecker(BaseModel):
        """Binary score for the helpfulness check on an answer."""
        score: str = Field(
            description="Ar atsakymas yra naudingas?, 'taip' arba 'ne'"
        )

    # Create the structured LLM helpfulness checker using the passed LLM
    structured_llm_helpfulness_checker = llm.with_structured_output(HelpfulnessChecker)

    # Define the prompt template
    prompt = PromptTemplate(
        template="""Jums bus pateiktas KLAUSIMAS {question} ir ATSAKYMAS {generation}.
        Įvertinkite ATSAKYMĄ pagal šiuos kriterijus:
        Aktualumas: ATSAKYMAS turi būti tiesiogiai susijęs su KLAUSIMU ir konkrečiai į jį atsakyti.
        Pakankamumas: ATSAKYME turi būti pakankamai informacijos, kad būtų galima visapusiškai atsakyti į KLAUSIMĄ. Jei ATSAKYME vartojamos tokios frazės kaip „nežinau“, „neturiu pakankamai informacijos“, „pateiktuose dokumentuose apie tai neužsimenama“ ar panašūs posakiai, kuriais vengiama tiesiogiai atsakyti į KLAUSIMĄ, įvertinkite „ne“.
        Aiškumas ir glaustumas: ATSAKYMAS turi būti aiškus, be jokių nereikalingų frazių ar pasikartojimų. Jei jame yra perteklinė arba netiesioginė informacija, o ne tiesioginis atsakymas, įvertinkite „ne“.
        Balų skaičiavimo instrukcijos:
        „taip“ reiškia, kad ATSAKYMAS atitinka visus šiuos kriterijus ir yra tiesiogiai susijęs su KLAUSIMU.
        „ne“ reiškia, kad ATSAKYMAS atitinka ne visus šiuos kriterijus.
        Jei atsakyme yra tekstas, panašus į „aš nežinau“ arba „nepakanka informacijos“, balas yra „ne“.
        Pateikite balą kaip JSON su vienu raktu „score“ ir be papildomo teksto""",
        input_variables=["generation", "question"]
    )

    # Combine the prompt with the structured LLM helpfulness checker
    helpfulness_grader = prompt | structured_llm_helpfulness_checker

    # Return the helpfulness checker object
    return helpfulness_grader

def create_hallucination_checker(llm):
    """
    Function to create a hallucination checker object using a passed LLM model.

    Args:
        llm: The language model to be used for checking hallucinations in the student's answer.

    Returns:
        Callable: A pipeline function that checks if the student's answer contains hallucinations.
    """

    class HallucinationChecker(BaseModel):
        """Binary score for the hallucination check on an answer."""
        score: str = Field(
            description="Ar dokumentas yra susijęs su atsakymu?, 'taip' arba 'ne'"
        )

    # Create the structured LLM hallucination checker using the passed LLM
    structured_llm_hallucination_checker = llm.with_structured_output(HallucinationChecker)

    # Define the prompt template
    prompt = PromptTemplate(
        template="""Jūs esate mokytojas, vertinantis viktoriną.
        Jums bus pateikti FAKTAI ir MOKINIO ATSAKYMAS.
        Jūs vertinate MOKINIO ATSAKYMĄ pagal šaltinį FAKTAI. Sutelkite dėmesį į MOKINIO ATSAKYMO teisingumą ir bet kokių haliucinacijų aptikimą.
        Įsitikinkite, kad MOKINIO ATSAKYMAS atitinka šiuos kriterijus:
        (1) jame nėra informacijos, nesusijusios su FAKTAIS
        (2) MOKINIO ATSAKYMAS turi būti visiškai pagrįstas pirminiuose dokumentuose pateikta informacija
        Rezultatas:
        „taip“ reiškia, kad mokinio atsakymas atitinka visus kriterijus. Tai aukščiausias (geriausias) balas.
        „ne“ reiškia, kad mokinio atsakymas atitinka ne visus kriterijus. Tai yra žemiausias galimas balas, kurį galite duoti.
        Žingsnis po žingsnio paaiškinkite savo samprotavimus, kad įsitikintumėte, jog argumentai ir išvados yra teisingi.
        Iš pradžių venkite tiesiog nurodyti teisingą atsakymą.
        MOKINIO ATSAKYMAS: {generation} \n
        FAKTAI: \n\n {documents} \n\n

        Suteikite dvejetainį balą „taip“ arba „ne“, kad nurodytumėte, ar atsakymas yra pagrįstas dokumentais. \n
        Pateikite dvejetainį balą kaip JSON su vienu raktu „score“ ir be įžangos ar paaiškinimo.""",
        input_variables=["generation", "documents"],
    )

    # Combine the prompt with the structured LLM hallucination checker
    hallucination_grader = prompt | structured_llm_hallucination_checker

    # Return the hallucination checker object
    return hallucination_grader

def create_question_rewriter(llm):
    """
    Function to create a question rewriter object using a passed LLM model.

    Args:
        llm: The language model to be used for rewriting questions.

    Returns:
        Callable: A pipeline function that rewrites questions for optimized vector store retrieval.
    """

    # Define the prompt template for question rewriting
    re_write_prompt = PromptTemplate(
        template="""Esate klausimų perrašytojas, kurio specializacija yra Lietuvos teisė, tobulinantis klausimus, kad būtų galima optimizuoti jų paiešką teisiniuose dokumentuose. Jūsų tikslas – išaiškinti teisinę intenciją, pašalinti dviprasmiškumą ir pakoreguoti formuluotes taip, kad jos atspindėtų teisinę kalbą, daugiausia dėmesio skiriant atitinkamiems raktiniams žodžiams, siekiant užtikrinti tikslų informacijos gavimą iš Lietuvos teisės šaltinių.
        Man nereikia paaiškinimų, tik perrašyto klausimo.
        Štai pradinis klausimas: \n\n {question}. Patobulintas klausimas be paaiškinimų: \n""",
        input_variables=["question"],
    )

    # Combine the prompt with the LLM and output parser
    question_rewriter = re_write_prompt | llm | StrOutputParser()

    # Return the question rewriter object
    return question_rewriter

def transform_query(state, question_rewriter):
    """
    Transform the query to produce a better question.
    Args:
        state (dict): The current graph state
    Returns:
        state (dict): Updates question key with a re-phrased question
    """
    print("---TRANSFORM QUERY---")
    question = state["question"]
    documents = state["documents"]
    steps = state["steps"]
    steps.append("question_transformation")

    # Re-write question
    better_question = question_rewriter.invoke({"question": question})
    print(f"Transformed question: {better_question}")
    return {"documents": documents, "question": better_question}

def format_google_results_search(google_results):
    formatted_documents = []

    # Extract data from answerBox (currently not added to the documents)
    answer_box = google_results.get("answerBox", {})
    answer_box_title = answer_box.get("title", "No title")
    answer_box_answer = answer_box.get("answer", "No text")

    # Extract and add organic results as separate Documents
    for result in google_results.get("organic", []):
        title = result.get("title", "No title")
        link = result.get("link", "Nėra svetainės adreso")
        snippet = result.get("snippet", "No snippet available")

        document = Document(
            metadata={
                "Organinio rezultato pavadinimas": title,
            },
            page_content=(
                f"Pavadinimas: {title} "
                f"Straipsnio ištrauka: {snippet} "
                f"Nuoroda: {link} "
            )
        )
        formatted_documents.append(document)

    return formatted_documents

def format_google_results_news(google_results):
    formatted_documents = []

    # Loop through each organic result and create a Document for it
    for result in google_results['organic']:
        title = result.get('title', 'No title')
        link = result.get('link', 'No link')
        description = result.get('description', 'No description')
        snippet = result.get('snippet', 'No summary available')
        text = result.get('text', 'No text')

        # Create a Document object with similar metadata structure to WikipediaRetriever
        document = Document(
            metadata={
                'Title': title,
                'Description': description,
                'Text': text,
                'Snippet': snippet,
                'Source': link
            },
            page_content=snippet  # Using the snippet as the page content
        )

        formatted_documents.append(document)

    return formatted_documents

def QA_chain(llm):
    """
    Creates a question-answering chain using the provided language model.
    Args:
        llm: The language model to use for generating answers.
    Returns:
        A runnable chain configured with the question-answering prompt and the provided model.
    """
    # Define the prompt template
    prompt = PromptTemplate(
        template="""Esi teisės asistentas, kurio užduotis yra atsakyti konkrečiai, informatyviai ir glaustai, pagrindžiant atsakymą į klausimą pateiktais dokumentais.
        Atsakymas turi būti lietuvių kalba. Nesikartok.
        Jei negali atsakyti į klausimą, pasakyk: Atsiprašau, nežinau atsakymo į jūsų klausimą.
        Neužduok papildomų klausimų.

        Klausimas: {question}
        Dokumentai: {documents}
        Atsakymas:
        """,
        input_variables=["question", "documents"],
    )

    rag_chain = prompt | llm | StrOutputParser()

    return rag_chain

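
# An isolated check of the QA chain (a sketch: the sample document text is
# hypothetical, and in the app the chain is driven by the generate() node):
def _demo_qa_chain():
    chain = QA_chain(ChatGroq(model="llama-3.3-70b-versatile", temperature=0.2))
    return chain.invoke({
        "question": "Kas yra PVM?",
        "documents": [Document(page_content="PVM – pridėtinės vertės mokestis.")],
    })
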
def grade_generation_v_documents_and_question(state, hallucination_grader, answer_grader):
    """
    Determines whether the generation is grounded in the document and answers the question.
    """
    print("---CHECK HALLUCINATIONS---")
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]
    generation_count = state.get("generation_count", 0)  # Use state.get to avoid KeyError
    print(f"generation number: {generation_count}")

    # Grading hallucinations
    score = hallucination_grader.invoke(
        {"documents": documents, "generation": generation}
    )
    grade = getattr(score, 'score', None)

    # Check hallucination (the graders answer 'taip'/'ne' in Lithuanian)
    if grade in ("yes", "taip"):
        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
        # Check question-answering
        print("---GRADE GENERATION vs QUESTION---")
        score = answer_grader.invoke({"question": question, "generation": generation})
        grade = getattr(score, 'score', None)
        if grade in ("yes", "taip"):
            print("---DECISION: GENERATION ADDRESSES QUESTION---")
            return "useful"
        else:
            print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
            return "not useful"
    else:
        if generation_count > 1:
            print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, TRANSFORM QUERY---")
            # Give up after repeated attempts
            return "not useful"
        else:
            print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
            # The counter itself is incremented in generate()
            print(f"generation number: {state.get('generation_count', 0)}")
            return "not supported"

def ask_question(state):
    """
    Initialize question
    Args:
        state (dict): The current graph state
    Returns:
        state (dict): Question
    """
    steps = state["steps"]
    question = state["question"]
    generation_count = state.get("generation_count", 0)

    steps.append("question_asked")
    return {"question": question, "steps": steps, "generation_count": generation_count}

def retrieve(state, retriever):
    """
    Retrieve documents
    Args:
        state (dict): The current graph state
        retriever: The retriever object
    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    steps = state["steps"]
    question = state["question"]

    documents = retriever.invoke(question)

    steps.append("retrieve_documents")
    return {"documents": documents, "question": question, "steps": steps}

def generate(state, qa_chain):
    """
    Generate answer
    """
    question = state["question"]
    documents = state["documents"]
    # stream() returns a generator; it is stored in the state and consumed
    # when the answer is displayed in handle_userinput()
    generation = qa_chain.stream({"documents": documents, "question": question})
    steps = state["steps"]
    steps.append("generate_answer")
    generation_count = state["generation_count"]

    generation_count += 1

    return {
        "documents": documents,
        "question": question,
        "generation": generation,
        "steps": steps,
        "generation_count": generation_count  # Include generation_count in return
    }

def grade_documents(state, retrieval_grader):
    question = state["question"]
    documents = state["documents"]
    steps = state["steps"]
    steps.append("grade_document_retrieval")

    filtered_docs = []
    search = "No"

    for d in documents:
        # Call the grading function
        score = retrieval_grader.invoke({"question": question, "documents": d})
        print(f"Grader output for document: {score}")  # Detailed debugging output

        # Extract the grade
        grade = getattr(score, 'binary_score', None)
        if grade and grade.lower() in ["yes", "true", "1", "taip"]:
            filtered_docs.append(d)
        elif len(filtered_docs) < 4:
            # Too few relevant documents so far: fall back to web search
            search = "Yes"

    # Check the decision-making process
    print(f"Final decision - Perform web search: {search}")
    print(f"Filtered documents count: {len(filtered_docs)}")

    return {
        "documents": filtered_docs,
        "question": question,
        "search": search,
        "steps": steps,
    }

def clean_exa_document(doc):
    """
    Extracts and retains only the title, url, text, and summary from the exa result document.
    """
    return {
        "Pavadinimas": doc.title,
        "Apibendrinimas": doc.summary,
        "Straipsnio internetinis adresas": doc.url,
        "Tekstas": doc.text,
    }

def web_search(state):
    question = state["question"]
    documents = state.get("documents", [])
    steps = state["steps"]
    steps.append("web_search")
    web_results_list = []

    # Fetch results from exa, restricted to Lithuanian legal sources
    exa_results_raw = exa.search_and_contents(
        query=question,
        start_published_date="2018-01-01T22:00:01.000Z",
        type="keyword",
        num_results=2,
        text={"max_characters": 7000},
        summary={
            "query": f"Tell in summary a meaning about what is article written. This summary has to be written in a way to be related to {question} Provide facts, be concise. Do it in Lithuanian language."
        },
        include_domains=["infolex.lt", "vmi.lt", "lrs.lt", "e-seimas.lrs.lt", "teise.pro", "lt.wikipedia.org", "teismai.lt"],
    )
    # Extract results
    exa_results = exa_results_raw.results if hasattr(exa_results_raw, "results") else []
    cleaned_exa_results = [clean_exa_document(doc) for doc in exa_results]

    # Fall back to Google (Serper) if exa returned nothing
    if len(cleaned_exa_results) < 1:
        web_results = GoogleSerperAPIWrapper(k=2, gl="lt", hl="lt", type="search").results(question)
        formatted_documents = format_google_results_search(web_results)
        web_results_list.extend(formatted_documents if isinstance(formatted_documents, list) else [formatted_documents])

        combined_documents = documents + cleaned_exa_results + web_results_list
    else:
        combined_documents = documents + cleaned_exa_results

    return {"documents": combined_documents, "question": question, "steps": steps}

def decide_to_generate(state):
    """
    Determines whether to generate an answer, or re-generate a question.
    Args:
        state (dict): The current graph state
    Returns:
        str: Binary decision for next node to call
    """
    search = state["search"]
    if search == "Yes":
        return "search"
    else:
        return "generate"

def decide_to_generate2(state):
    """
    Determines whether to generate an answer, or re-generate a question.
    (Duplicate of decide_to_generate.)
    Args:
        state (dict): The current graph state
    Returns:
        str: Binary decision for next node to call
    """
    search = state["search"]
    if search == "Yes":
        return "search"
    else:
        return "generate"