File size: 769 Bytes
07423df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import os

from llm_studio.src.datasets.text_utils import get_tokenizer
from llm_studio.src.plots.text_causal_language_modeling_plots import (
    Plots as TextCausalLanguageModelingPlots,
)
from llm_studio.src.plots.text_causal_language_modeling_plots import (
    create_batch_prediction_df,
)
from llm_studio.src.utils.plot_utils import PlotData


class Plots(TextCausalLanguageModelingPlots):
    @classmethod
    def plot_batch(cls, batch, cfg) -> PlotData:
        tokenizer = get_tokenizer(cfg)
        df = create_batch_prediction_df(
            batch, tokenizer, ids_for_tokenized_text="prompt_input_ids"
        )
        path = os.path.join(cfg.output_directory, "batch_viz.parquet")
        df.to_parquet(path)
        return PlotData(path, encoding="df")