import torch
import os
from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer, AutoTokenizer
from interface import GemmaLLMInterface
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.instructor import InstructorEmbedding
import gradio as gr
from llama_index.core import (
    ChatPromptTemplate,
    Settings,
    VectorStoreIndex,
    SimpleDirectoryReader,
    PromptTemplate,
    load_index_from_storage,
)
from huggingface_hub import hf_hub_download, login
from llama_cpp import Llama
import spaces


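# Authenticate with the Hugging Face Hub; HUGGINGFACE_TOKEN must be set in the
# environment (needed to download gated models such as Gemma).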
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
login(huggingface_token)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

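# Load the Gemma 2 2B instruction-tuned model, in bfloat16 when a GPU is available.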
model_id = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    token=True
)
model.eval()
# Models used by LlamaIndex:
Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
Settings.llm = GemmaLLMInterface(model=model, tokenizer=tokenizer)


############################---------------------------------

# Sentence-level parser that splits documents into overlapping chunks
parser = SentenceSplitter.from_defaults(
    chunk_size=256, chunk_overlap=64, paragraph_separator="\n\n"
)



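# On Hugging Face Spaces (ZeroGPU), this decorator requests GPU hardware for
# each call to handle_query, for up to ~20 seconds.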
@spaces.GPU(duration=20)
def handle_query(query_str, chathistory):

    # Build an in-memory vector index over the local document on every request
    documents = SimpleDirectoryReader(input_files=["data/blockchainprova.txt"]).load_data()
    nodes = parser.get_nodes_from_documents(documents)
    index = VectorStoreIndex(nodes)
  
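    # RAG prompt: retrieved context is injected above the user question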
    qa_prompt_str = (
        "Context information is below.\n"
        "---------------------\n"
        "{context_str}\n"
        "---------------------\n"
        "Given the context information and not prior knowledge, "
        "answer the question: {query_str}\n"
    )

    # Text QA Prompt
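    # The system message below is Italian; roughly: "You are an Italian
    # assistant named Ossy who answers only pertinent questions or requests."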
    chat_text_qa_msgs = [
        (
            "system",
            "Sei un assistente italiano di nome Ossy che risponde solo alle domande o richieste pertinenti. ",
        ),
        ("user", qa_prompt_str),
    ]
    text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)

    try:
        # Query with streaming enabled; with streaming=True the engine returns a
        # StreamingResponse whose .response is not populated up front, so iterate
        # response_gen and yield the answer incrementally, stripping Gemma's
        # <end_of_turn> marker as we go.
        streaming_response = index.as_query_engine(
            text_qa_template=text_qa_template, streaming=True
        ).query(query_str)

        partial = ""
        for token in streaming_response.response_gen:
            partial += token
            yield partial.replace("<end_of_turn>", "").strip()
    except Exception as e:
        yield f"Error processing query: {str(e)}"