File size: 769 Bytes
07423df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import os
from llm_studio.src.datasets.text_utils import get_tokenizer
from llm_studio.src.plots.text_causal_language_modeling_plots import (
Plots as TextCausalLanguageModelingPlots,
)
from llm_studio.src.plots.text_causal_language_modeling_plots import (
create_batch_prediction_df,
)
from llm_studio.src.utils.plot_utils import PlotData
class Plots(TextCausalLanguageModelingPlots):
    """Plot helpers that extend the causal language modeling plots.

    Only ``plot_batch`` is overridden here; everything else is inherited
    from ``TextCausalLanguageModelingPlots``.
    """

    @classmethod
    def plot_batch(cls, batch, cfg) -> PlotData:
        """Write a batch preview table to disk and return a handle to it.

        The table is produced by ``create_batch_prediction_df`` using the
        tokenizer obtained from ``cfg`` and with ``"prompt_input_ids"``
        passed as ``ids_for_tokenized_text``. It is saved as
        ``batch_viz.parquet`` inside ``cfg.output_directory``.

        Args:
            batch: Batch forwarded unchanged to ``create_batch_prediction_df``.
            cfg: Configuration object; passed to ``get_tokenizer`` and read
                for its ``output_directory`` attribute.

        Returns:
            A ``PlotData`` built from the parquet file path with
            ``encoding="df"``.
        """
        prediction_frame = create_batch_prediction_df(
            batch,
            get_tokenizer(cfg),
            ids_for_tokenized_text="prompt_input_ids",
        )

        # The file name is fixed, so each call overwrites the previous
        # preview in the same output directory.
        parquet_file = os.path.join(cfg.output_directory, "batch_viz.parquet")
        prediction_frame.to_parquet(parquet_file)

        return PlotData(parquet_file, encoding="df")
|