"""A script to load a dataset or set of datasets from a config for a Lilac instance.
Usage:
poetry run python -m lilac.load \
--output_dir=demo_data \
--config_path=lilac_hf_space.yml
"""
import gc
import os
import shutil

import click
import dask
import psutil
from distributed import Client

from .concepts.db_concept import DiskConceptDB, DiskConceptModelDB
from .config import EmbeddingConfig, SignalConfig, read_config
from .data.dataset_duckdb import DatasetDuckDB
from .data_loader import process_source
from .db_manager import get_dataset, list_datasets, remove_dataset_from_cache
from .schema import ROWID, PathTuple
from .tasks import TaskManager, TaskStepId
from .utils import DebugTimer, get_datasets_dir


@click.command()
@click.option(
'--output_dir', required=True, type=str, help='The output directory to write files to.')
@click.option(
'--config_path',
required=True,
type=str,
help='The path to a json or yml file describing the configuration. '
'The file contents should be an instance of `lilac.Config` or `lilac.DatasetConfig`.')
@click.option(
'--overwrite',
  help='When True, recomputes all data from scratch, overwriting existing data. When False, only '
  'loads new datasets, embeddings, and signals.',
type=bool,
is_flag=True,
default=False)
def load_command(output_dir: str, config_path: str, overwrite: bool) -> None:
"""Run the source loader as a binary."""
load(output_dir, config_path, overwrite)


def load(output_dir: str, config_path: str, overwrite: bool) -> None:
"""Run the source loader as a binary."""
old_data_path = os.environ.get('LILAC_DATA_PATH')
os.environ['LILAC_DATA_PATH'] = output_dir
  # Turn off debug logging.
  if 'DEBUG' in os.environ:
    del os.environ['DEBUG']
# Use views to avoid loading duckdb tables into RAM since we aren't query heavy.
os.environ['DUCKDB_USE_VIEWS'] = '1'
config = read_config(config_path)
# Explicitly create a dask client in sync mode.
dask.config.set({'distributed.worker.daemon': False})
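  # Give the Dask cluster a memory limit of roughly two thirds of total system RAM.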
  memory_limit_gb = psutil.virtual_memory().total / (1024**3) * 2 / 3
  task_manager = TaskManager(Client(memory_limit=f'{memory_limit_gb} GB'))
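  # When overwriting, wipe the existing datasets directory so everything is recomputed.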
if overwrite:
shutil.rmtree(get_datasets_dir(output_dir), ignore_errors=True)
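  # Record the datasets that already exist on disk so they can be skipped below.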
existing_datasets = [f'{d.namespace}/{d.dataset_name}' for d in list_datasets(output_dir)]
print()
print('*** Load datasets ***')
if overwrite:
datasets_to_load = config.datasets
else:
datasets_to_load = [
d for d in config.datasets if f'{d.namespace}/{d.name}' not in existing_datasets
]
skipped_datasets = [
d for d in config.datasets if f'{d.namespace}/{d.name}' in existing_datasets
]
  if skipped_datasets:
    print('Skipping loaded datasets:', ', '.join([d.name for d in skipped_datasets]))
with DebugTimer(f'Loading datasets: {", ".join([d.name for d in datasets_to_load])}'):
for d in datasets_to_load:
      shutil.rmtree(
        os.path.join(get_datasets_dir(output_dir), d.namespace, d.name), ignore_errors=True)
task_id = task_manager.task_id(f'Load dataset {d.namespace}/{d.name}')
task_manager.execute(task_id, process_source, output_dir, d, (task_id, 0))
task_manager.wait()
print()
total_num_rows = 0
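  # Report the number of rows loaded for each dataset.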
for d in datasets_to_load:
dataset = DatasetDuckDB(d.namespace, d.name)
num_rows = dataset.select_rows([ROWID], limit=1).total_num_rows
print(f'{d.namespace}/{d.name} loaded with {num_rows:,} rows.')
# Free up RAM.
del dataset
total_num_rows += num_rows
print(f'Done loading {len(datasets_to_load)} datasets with {total_num_rows:,} rows.')
print('*** Dataset settings ***')
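  # Apply per-dataset settings (e.g. UI media paths and preferred embedding) from the config.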
for d in config.datasets:
if d.settings:
dataset = DatasetDuckDB(d.namespace, d.name)
dataset.update_settings(d.settings)
print()
print('*** Compute embeddings ***')
with DebugTimer('Loading embeddings'):
for d in config.datasets:
dataset = DatasetDuckDB(d.namespace, d.name)
# If embeddings are explicitly set, use only those.
embeddings = d.embeddings or []
# If embeddings are not explicitly set, use the media paths and preferred embedding from
# settings.
if not embeddings:
if d.settings and d.settings.ui:
for path in d.settings.ui.media_paths or []:
if d.settings.preferred_embedding:
embeddings.append(
EmbeddingConfig(path=path, embedding=d.settings.preferred_embedding))
for e in embeddings:
if e not in dataset.config().embeddings:
          print(f'Scheduling embedding {e.embedding} on {d.name}:{e.path}.')
task_id = task_manager.task_id(f'Compute embedding {e.embedding} on {d.name}:{e.path}')
task_manager.execute(task_id, _compute_embedding, d.namespace, d.name, e, output_dir,
overwrite, (task_id, 0))
else:
print(f'Embedding {e.embedding} already exists for {d.name}:{e.path}. Skipping.')
del dataset
# Wait for all embeddings for each dataset to reduce the memory pressure.
task_manager.wait()
print()
print('*** Compute signals ***')
with DebugTimer('Computing signals'):
for d in config.datasets:
dataset = DatasetDuckDB(d.namespace, d.name)
# If signals are explicitly set, use only those.
signals = d.signals or []
# If signals are not explicitly set, use the media paths and config.signals.
if not signals:
if d.settings and d.settings.ui:
for path in d.settings.ui.media_paths or []:
for signal in config.signals or []:
signals.append(SignalConfig(path=path, signal=signal))
# Separate signals by path to avoid computing the same signal in parallel, which can cause
# issues with taking too much RAM.
path_signals: dict[PathTuple, list[SignalConfig]] = {}
for s in signals:
path_signals.setdefault(s.path, []).append(s)
      for path, signals_for_path in path_signals.items():
        for s in signals_for_path:
if s not in dataset.config().signals:
task_id = task_manager.task_id(f'Compute signal {s.signal} on {d.name}:{s.path}')
task_manager.execute(task_id, _compute_signal, d.namespace, d.name, s, output_dir,
overwrite, (task_id, 0))
else:
print(f'Signal {s.signal} already exists for {d.name}:{s.path}. Skipping.')
# Wait for all signals for each path to reduce the memory pressure.
task_manager.wait()
del dataset
print()
print('*** Compute model caches ***')
with DebugTimer('Computing model caches'):
concept_db = DiskConceptDB(output_dir)
concept_model_db = DiskConceptModelDB(concept_db)
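    # Sync a concept model cache for every concept, for each embedding listed in the config.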
if config.concept_model_cache_embeddings:
for concept_info in concept_db.list():
for embedding in config.concept_model_cache_embeddings:
concept_model_db.sync(
concept_info.namespace, concept_info.name, embedding_name=embedding, create=True)
print()
print('Done!')
if old_data_path:
os.environ['LILAC_DATA_PATH'] = old_data_path


def _compute_signal(namespace: str, name: str, signal_config: SignalConfig, output_dir: str,
overwrite: bool, task_step_id: TaskStepId) -> None:
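  """Compute a signal over a dataset column, skipping it if an identical signal already exists."""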
os.environ['LILAC_DATA_PATH'] = output_dir
os.environ['DUCKDB_USE_VIEWS'] = '1'
# Turn off debug logging.
if 'DEBUG' in os.environ:
del os.environ['DEBUG']
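  # Recompute when overwriting, or when the signal is missing or its config has changed.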
compute_signal = False
if overwrite:
compute_signal = True
dataset = get_dataset(namespace, name)
if not compute_signal:
field = dataset.manifest().data_schema.get_field(signal_config.path)
signal_field = (field.fields or {}).get(signal_config.signal.key())
if not signal_field or signal_field.signal != signal_config.signal.dict():
compute_signal = True
if compute_signal:
dataset.compute_signal(signal_config.signal, signal_config.path, task_step_id)
# Free up RAM.
remove_dataset_from_cache(namespace, name)
del dataset
gc.collect()


def _compute_embedding(namespace: str, name: str, embedding_config: EmbeddingConfig,
output_dir: str, overwrite: bool, task_step_id: TaskStepId) -> None:
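  """Compute an embedding over a dataset column, skipping it if it has already been computed."""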
os.environ['LILAC_DATA_PATH'] = output_dir
os.environ['DUCKDB_USE_VIEWS'] = '1'
# Turn off debug logging.
if 'DEBUG' in os.environ:
del os.environ['DEBUG']
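  # Recompute when overwriting, or when the embedding is missing from the dataset schema.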
compute_embedding = False
if overwrite:
compute_embedding = True
dataset = get_dataset(namespace, name)
if not compute_embedding:
field = dataset.manifest().data_schema.get_field(embedding_config.path)
embedding_field = (field.fields or {}).get(embedding_config.embedding)
if not embedding_field:
compute_embedding = True
if compute_embedding:
dataset.compute_embedding(embedding_config.embedding, embedding_config.path, task_step_id)
remove_dataset_from_cache(namespace, name)
del dataset
gc.collect()


if __name__ == '__main__':
load_command()