gufett0 committed
Commit 643e1b9
1 Parent(s): 57ae88a

first app files

Files changed (6)
  1. .gitignore +1 -0
  2. app.py +15 -0
  3. backend.py +87 -0
  4. data/blockchainprova.txt +0 -0
  5. interface.py +44 -0
  6. requirements.txt +10 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ /myenv
app.py ADDED
@@ -0,0 +1,15 @@
+ from backend import handle_query
+ import gradio as gr
+
+
+ iface = gr.ChatInterface(
+     fn=handle_query,
+     title="PDF Information and Inference",
+     description="Retrieval-Augmented Generation - Ask me anything about the content of the PDF.",
+     #examples=["What is the main topic of the document?", "Can you summarize the key points?"],
+     #cache_examples=True,
+ )
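+ # handle_query is a generator, so ChatInterface streams whatever it yields back to the chat window.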
+
+
+ if __name__ == "__main__":
+     iface.launch()
backend.py ADDED
@@ -0,0 +1,87 @@
+ import torch
+ import os
+ from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer
+ from interface import GemmaLLMInterface
+ from llama_index.core.node_parser import SentenceSplitter
+ from llama_index.embeddings.instructor import InstructorEmbedding
+ import gradio as gr
+ from llama_index.core import ChatPromptTemplate
+ from llama_index.core import Settings, VectorStoreIndex, SimpleDirectoryReader, PromptTemplate, load_index_from_storage
+
+
+ model_id = "google/gemma-2-2b-it"
+ tokenizer = GemmaTokenizerFast.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     device_map="auto",
+     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+ )
+ # These settings define which models LlamaIndex will use:
+ Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
+ Settings.llm = GemmaLLMInterface(model=model, tokenizer=tokenizer)
+
+
+ """os.environ["KAGGLE_USERNAME"] = "middi0"
+ os.environ["KAGGLE_KEY"] = "b7eed1ea5cfb30e8eb13b085af2e427b"
+
+ # Let's load Gemma using Keras
+ gemma_model_id = "gemma2_instruct_2b_en"
+ gemma = keras_nlp.models.GemmaCausalLM.from_preset(gemma_model_id)
+
+ # These settings define what models will be used by LlamaIndex
+ Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
+ Settings.llm = GemmaLLMInterface(model=gemma)"""
+
+
+ ############################---------------------------------
+
+ # CHUNKING
+ # Reading documents from disk
+ documents = SimpleDirectoryReader(input_files=["data/blockchainprova.txt"]).load_data()
+
+ # Splitting the document into chunks with
+ # predefined size and overlap
+ parser = SentenceSplitter.from_defaults(
+     chunk_size=256, chunk_overlap=64, paragraph_separator="\n\n"
+ )
+ nodes = parser.get_nodes_from_documents(documents)
+ #print(nodes[6].text)
+
+ # BUILD A VECTOR STORE
+ index = VectorStoreIndex(nodes)
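+ # VectorStoreIndex embeds every chunk with the Instructor model configured above,
+ # so semantically similar passages can be retrieved at query time.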
+
+
+ def handle_query(query_str, chathistory):
+
+     qa_prompt_str = (
+         "Context information is below.\n"
+         "---------------------\n"
+         "{context_str}\n"
+         "---------------------\n"
+         "Given the context information and not prior knowledge, "
+         "answer the question: {query_str}\n"
+     )
+
+     # Text QA Prompt
+     chat_text_qa_msgs = [
+         (
+             "system",
+             # Italian: "You are an Italian assistant named Tizio who only answers pertinent questions or requests."
+             "Sei un assistente italiano di nome Tizio che risponde solo alle domande o richieste pertinenti. ",
+         ),
+         ("user", qa_prompt_str),
+     ]
+     text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)
+
+     # Reuse the module-level index built above rather than re-embedding the nodes on every query.
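+     # as_query_engine() embeds query_str, retrieves the most similar chunks, and fills
+     # {context_str} / {query_str} in text_qa_template before calling the Gemma LLM.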
+     result = index.as_query_engine(text_qa_template=text_qa_template).query(query_str)
+     response_text = result.response
+
+     # Remove any unwanted tokens like <end_of_turn>
+     cleaned_result = response_text.replace("<end_of_turn>", "").strip()
+
+     yield cleaned_result
data/blockchainprova.txt ADDED
The diff for this file is too large to render. See raw diff
 
interface.py ADDED
@@ -0,0 +1,44 @@
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from llama_index.core.llms import CustomLLM, LLMMetadata, CompletionResponse, CompletionResponseGen
+ from llama_index.core.llms.callbacks import llm_completion_callback
+ from typing import Any
+
+
+ class GemmaLLMInterface(CustomLLM):
+     model: Any
+     tokenizer: Any
+     context_window: int = 8192
+     num_output: int = 2048
+     model_name: str = "gemma_2"
+
+     class Config:
+         protected_namespaces = ()
+
+     def _format_prompt(self, message: str) -> str:
+         # Wrap the message in the <start_of_turn>/<end_of_turn> chat markers that Gemma
+         # instruction-tuned models expect; the open "model" turn cues the model to reply.
+         return (
+             f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
+         )
+
+     @property
+     def metadata(self) -> LLMMetadata:
+         # Get LLM metadata.
+         return LLMMetadata(
+             context_window=self.context_window,
+             num_output=self.num_output,
+             model_name=self.model_name,
+         )
+
+     @llm_completion_callback()
+     def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
+         prompt = self._format_prompt(prompt)
+         inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+         output = self.model.generate(**inputs, max_new_tokens=self.num_output)
+         # Decode only the newly generated tokens so the echoed prompt is not returned
+         response = self.tokenizer.decode(output[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
+         return CompletionResponse(text=response)
+
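+     # Note: stream_complete below pseudo-streams by generating the full completion first and
+     # then yielding it back piece by piece; transformers' TextIteratorStreamer would be needed
+     # for true token-by-token streaming.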
+     @llm_completion_callback()
+     def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
+         response = self.complete(prompt).text
+         accumulated = ""
+         for token in response:
+             accumulated += token
+             yield CompletionResponse(text=accumulated, delta=token)
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ python-dotenv
+ llama-index
+ llama-index-embeddings-huggingface
+ llama-index-llms-huggingface
+ llama-index-embeddings-instructor
+ sentence-transformers==2.2.2
+ llama-index-readers-web
+ llama-index-readers-file
+ gradio
+ transformers