import torch
import os
from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer, AutoTokenizer
from interface import GemmaLLMInterface
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.instructor import InstructorEmbedding
import gradio as gr
from llama_index.core import Settings, VectorStoreIndex, SimpleDirectoryReader, PromptTemplate, load_index_from_storage, StorageContext
import spaces
from huggingface_hub import login
from llama_index.core.memory import ChatMemoryBuffer
from typing import Iterator, List, Any
from llama_index.core.chat_engine import CondensePlusContextChatEngine
from llama_index.core.llms import ChatMessage, MessageRole


huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
login(huggingface_token)

# Select the GPU when available, otherwise fall back to CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_id = "google/gemma-2-2b-it"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    token=True)

# Attach the tokenizer that matches the loaded checkpoint
model.tokenizer = AutoTokenizer.from_pretrained(model_id)
model.eval()

# what models will be used by LlamaIndex:
Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
Settings.llm = GemmaLLMInterface()

# Map topic keywords (matched against the lowercased user query) to their source documents
documents_paths = {
    'blockchain': 'data/blockchainprova.txt',
    'metaverse': 'data/metaverso',
    'payment': 'data/payment'
}

# Per-process session state; "index" tracks whether a topic index has already been built
session_state = {"index": False,
                 "documents_loaded": False,
                 "document_db": None,
                 "original_message": None,
                 "clarification": False}

# Directory where the vector index is persisted between queries
PERSIST_DIR = "./db"
os.makedirs(PERSIST_DIR, exist_ok=True)

############################---------------------------------

# Get the parser
parser = SentenceSplitter.from_defaults(
    chunk_size=256, chunk_overlap=64, paragraph_separator="\n\n"
)

def build_index(path: str):
    # Load documents from a file
    documents = SimpleDirectoryReader(input_files=[path]).load_data()
    # Parse the documents into nodes
    nodes = parser.get_nodes_from_documents(documents)
    # Build the vector store index from the nodes
    index = VectorStoreIndex(nodes)
    
    # Persist the index so it can be reloaded from PERSIST_DIR on later queries
    index.storage_context.persist(persist_dir=PERSIST_DIR)
    
    return index



# On Hugging Face Spaces, @spaces.GPU requests a ZeroGPU slot for each call (here up to 20 s)
@spaces.GPU(duration=20)
def handle_query(query_str: str, 
                 chat_history: list[tuple[str, str]]) -> Iterator[str]:
    
    # Rebuild the LlamaIndex chat history from the Gradio-style (user, assistant) tuples
    conversation: List[ChatMessage] = []
    for user, assistant in chat_history:
        conversation.extend([
            ChatMessage(role=MessageRole.USER, content=user),
            ChatMessage(role=MessageRole.ASSISTANT, content=assistant),
        ])
    
    if not session_state["index"]:
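        # First query of the session: try to match a known topic keyword in the query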
        matched_path = None
        words = query_str.lower()
        for key, path in documents_paths.items():
            if key in words:
                matched_path = path
                break
        if matched_path:
            index = build_index(matched_path)
            session_state["index"] = True
            
        else:
            # No known topic was mentioned: build a small index that asks the user to clarify
            index = build_index("data/chiarimento.txt")
                  
    else:
        # The index is already built, no need to rebuild it.
        storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
        index = load_index_from_storage(storage_context)    

    try:
        
        memory = ChatMemoryBuffer.from_defaults(token_limit=None)

        # Context prompt (in Italian): "Odi", a Q&A assistant created by the
        # Osservatori Digitali, answers only relevant questions, using the
        # retrieved documents and the prior chat history.
        chat_engine = index.as_chat_engine(
            chat_mode="condense_plus_context",
            memory=memory,
            similarity_top_k=4,
            response_mode="tree_summarize",  # good for summarization purposes
            context_prompt=(
                "Sei un assistente Q&A italiano di nome Odi, che risponde solo alle domande o richieste pertinenti in modo preciso."
                " Quando un utente ti chiede informazioni su di te o sul tuo creatore puoi dire che sei un assistente ricercatore creato dagli Osservatori Digitali e fornire gli argomenti di cui sei esperto."
                " Ecco i documenti rilevanti per il contesto:\n"
                "{context_str}"
                "\nIstruzione: Usa la cronologia delle chat precedenti, o il contesto sopra, per interagire e aiutare l'utente a rispondere alla sua domanda."
            ),
            verbose=False,
        )
        
        
        # Stream the answer token by token, yielding the accumulated text for display
        outputs = []
        response = chat_engine.stream_chat(query_str, conversation)
        for token in response.response_gen:
            outputs.append(token)
            yield "".join(outputs)
        
    except Exception as e:
        yield f"Error processing query: {str(e)}"