import os

import gradio as gr
import spaces
import torch
from huggingface_hub import hf_hub_download, login
from llama_cpp import Llama
from llama_index.core import (
    ChatPromptTemplate,
    PromptTemplate,
    Settings,
    SimpleDirectoryReader,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.instructor import InstructorEmbedding
from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaTokenizerFast, TextIteratorStreamer

from interface import GemmaLLMInterface


# Authenticate with the Hugging Face Hub (required to download the gated Gemma model).
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
login(huggingface_token)

# Detect CUDA; note that `device` is not referenced again in this file (the transformers
# path below, now commented out, relies on device_map instead).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

"""model_id = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto", ## change this back to auto!!!
    torch_dtype= torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    token=True)
model.eval()"""

#from accelerate import disk_offload
#disk_offload(model=model, offload_dir="offload")

# Models used by LlamaIndex: the Instructor embedding model and the Gemma LLM wrapper.
# Earlier variant (passing an in-process transformers model), kept for reference:
"""Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
Settings.llm = GemmaLLMInterface(model=model, tokenizer=tokenizer)"""
Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
Settings.llm = GemmaLLMInterface(model_id="google/gemma-2-2b-it")

############################---------------------------------

# Sentence-level splitter used to chunk documents into nodes.
parser = SentenceSplitter.from_defaults(
    chunk_size=256, chunk_overlap=64, paragraph_separator="\n\n"
)

def build_index():
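    """Load the local text file, split it into nodes, and return an in-memory VectorStoreIndex."""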
    # Load documents from a file
    documents = SimpleDirectoryReader(input_files=["data/blockchainprova.txt"]).load_data()
    # Parse the documents into nodes
    nodes = parser.get_nodes_from_documents(documents)
    # Build the vector store index from the nodes
    index = VectorStoreIndex(nodes)
    
    return index
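
# load_index_from_storage is imported above but never called; presumably index persistence
# was planned at some point. A minimal sketch of that pattern, under the assumption that a
# local ./storage directory is acceptable (PERSIST_DIR and get_or_build_index are
# illustrative names, not part of the original file):
PERSIST_DIR = "./storage"

def get_or_build_index():
    # Reuse a previously persisted index when available; otherwise build and persist it.
    from llama_index.core import StorageContext  # imported here to keep the sketch self-contained
    if os.path.isdir(PERSIST_DIR):
        storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
        return load_index_from_storage(storage_context)
    index = build_index()
    index.storage_context.persist(persist_dir=PERSIST_DIR)
    return index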


@spaces.GPU(duration=20)
def handle_query(query_str, chathistory):
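    """Answer query_str with a context chat engine over the indexed documents, yielding tokens as they arrive.

    The vector index is rebuilt on every call, which keeps the code simple for a single small file.
    """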
    
    index = build_index()
  
    qa_prompt_str = (
        "Context information is below.\n"
        "---------------------\n"
        "{context_str}\n"
        "---------------------\n"
        "Given the context information and not prior knowledge, "
        "answer the question: {query_str}\n"
    )

    # Text QA Prompt
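    # System message (in Italian): "You are an Italian assistant named Ossy who only answers
    # pertinent questions or requests." The same text is reused below as the chat engine's system_prompt.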
    chat_text_qa_msgs = [
        (
            "system",
            "Sei un assistente italiano di nome Ossy che risponde solo alle domande o richieste pertinenti. ",
        ),
        ("user", qa_prompt_str),
    ]
    text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)
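    # NOTE: text_qa_template is only consumed by the disabled query-engine path inside the try
    # block below; the active chat-engine path uses system_prompt instead.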

    try:
        # Alternative, currently disabled: a one-shot query engine driven by text_qa_template.
        """query_engine = index.as_query_engine(text_qa_template=text_qa_template, streaming=False, similarity_top_k=1)
        
        # Execute the query
        streaming_response = query_engine.query(query_str)
        
        r = streaming_response.response
        cleaned_result = r.replace("<end_of_turn>", "").strip()
        yield cleaned_result"""
      
        # Alternative, currently disabled: stream incrementally from the query engine's response_gen.
        """outputs = []
        for text in streaming_response.response_gen:
            
            outputs.append(str(text))
            yield "".join(outputs)"""
          
        # Active path: a context chat engine that injects retrieved nodes into every turn
        # and keeps a rolling conversation memory.
        memory = ChatMemoryBuffer.from_defaults(token_limit=1500)
        chat_engine = index.as_chat_engine(
            chat_mode="context",
            memory=memory,
            system_prompt=(
                "Sei un assistente italiano di nome Ossy che risponde solo alle domande o richieste pertinenti. "
            ),
        )

        # Stream tokens back to the caller as they are generated.
        response = chat_engine.stream_chat(query_str)
        #response = chat_engine.chat(query_str)
        for token in response.response_gen:
            yield token
          
        
    except Exception as e:
        yield f"Error processing query: {str(e)}"