File size: 1,756 Bytes
b4f5da6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
    column,
)
import os
from llama_index.core import Settings, VectorStoreIndex
from llama_index.core import SQLDatabase
from llama_index.llms.ollama import Ollama
from llama_index.core.query_engine import NLSQLTableQueryEngine
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
from rich.console import Console
from rich.theme import Theme

# Rich theme defining two named styles ("title" and "text") that can be
# referenced from markup or via `style=` when printing through `console`.
custom_theme = Theme(
    dict(
        title="bold white on orchid1",
        text="dim chartreuse1",
    )
)

# Console bound to the theme above.
# NOTE(review): `console` does not appear to be used elsewhere in this script.
console = Console(theme=custom_theme)

# Global LlamaIndex models: a local Ollama "phi3" LLM (generous timeout for slow
# local inference) and a HuggingFace BGE model for embedding table descriptions.
Settings.llm = Ollama(model="phi3", request_timeout=360.0)
Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-base-en-v1.5")

# Location of the SQLite database. The default is relative to the current
# working directory; set LAPS_DB_PATH to run the script from anywhere else.
db_path = os.environ.get("LAPS_DB_PATH", "multi-agents-analysis/data/laps.db")
if not os.path.isfile(db_path):
    # Fail early with an actionable message: for a missing file, SQLite would
    # otherwise create an empty database (or fail with an opaque "unable to
    # open database file"), and the problem would only surface later as a
    # confusing missing-table error.
    raise FileNotFoundError(
        f"SQLite database not found at {db_path!r} (cwd: {os.getcwd()!r}); "
        "set LAPS_DB_PATH to its location."
    )

engine = create_engine(f"sqlite:///{db_path}")

# LlamaIndex wrapper around the SQLAlchemy engine, used both for describing
# table schemas and for executing the generated SQL.
sql_database = SQLDatabase(engine)

# Extra natural-language context attached to the "laps" table schema, passed
# as `context_str` so it accompanies the table description.
laps_table_context = """This table gives information regarding the performance in a race about each driver.
The time is split into 3 different sectors.
The speed is split into SpeedI1, SpeedI2, SpeedFL and SpeedST"""

table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
    SQLTableSchema(table_name="laps", context_str=laps_table_context),
]

# Index the table descriptions so a retriever can select the table relevant
# to a question. Only "laps" is indexed, so top_k=1 is sufficient.
obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
)

query_engine = SQLTableRetrieverQueryEngine(
    sql_database,
    obj_index.as_retriever(similarity_top_k=1),
)
response = query_engine.query("Which driver had the lowest time in sector 1?")
print(response)