# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional

import networkx as nx
import streamlit as st
import torch
import transformers

import llm_transparency_tool.routes.graph
from llm_transparency_tool.models.tlens_model import TransformerLensTransparentLlm
from llm_transparency_tool.models.transparent_llm import TransparentLlm

GPU = "gpu"
CPU = "cpu"

# The app runs inference on one sentence at a time, so the batch index is
# always 0. `B0` is just a more readable name for that constant.
B0 = 0


def possible_devices() -> List[str]:
    devices = []
    if torch.cuda.is_available():
        devices.append(GPU)
    devices.append(CPU)
    return devices


def load_dataset(filename: str) -> List[str]:
    """Read the dataset as a list of sentences, one sentence per line."""
    with open(filename) as f:
        dataset = [s.strip("\n") for s in f.readlines()]
    print(f"Loaded {len(dataset)} sentences from {filename}")
    return dataset


@st.cache_resource(
    hash_funcs={
        TransformerLensTransparentLlm: id
    }
)
def load_model(
    model_name: str,
    _device: str,
    _model_path: Optional[str] = None,
    _dtype: torch.dtype = torch.float32,
) -> TransparentLlm:
    """
    Returns the loaded model along with its key. The key is just a unique string which
    can be used later to identify if the model has changed.
    """
    assert _device in possible_devices()

    causal_lm = None
    tokenizer = None

    tl_lm = TransformerLensTransparentLlm(
        model_name=model_name,
        hf_model=causal_lm,
        tokenizer=tokenizer,
        device=_device,
        dtype=_dtype,
    )

    return tl_lm


def run_model(model: TransparentLlm, sentence: str) -> None:
    print(f"Running inference for '{sentence}'")
    model.run([sentence])


def load_model_with_session_caching(
    **kwargs,
) -> TransparentLlm:
    return load_model(**kwargs)


def run_model_with_session_caching(
    _model: TransparentLlm,
    model_key: str,
    sentence: str,
) -> None:
    LAST_RUN_MODEL_KEY = "last_run_model_key"
    LAST_RUN_SENTENCE = "last_run_sentence"
    state = st.session_state

    if (
        state.get(LAST_RUN_MODEL_KEY, None) == model_key
        and state.get(LAST_RUN_SENTENCE, None) == sentence
    ):
        return

    run_model(_model, sentence)
    state[LAST_RUN_MODEL_KEY] = model_key
    state[LAST_RUN_SENTENCE] = sentence


@st.cache_resource(
    hash_funcs={
        TransformerLensTransparentLlm: id
    }
)
def get_contribution_graph(
    _model: TransparentLlm,
    model_key: str,
    tokens: List[str],
    threshold: float,
) -> nx.Graph:
    """
    The `model_key` and `tokens` are used only for caching. The model itself is
    not part of Streamlit's cache key, hence the `_` prefix.
    """
    return llm_transparency_tool.routes.graph.build_full_graph(
        _model,
        B0,
        threshold,
    )


def st_placeholder(
    text: str,
    container=st,
    border: bool = True,
    height: Optional[int] = 500,
):
    """Show `text` in a bordered box and return the st.empty slot holding it,
    so the caller can later overwrite the placeholder with real content."""
    empty = container.empty()
    empty.container(border=border, height=height).write(
        f"<small>{text}</small>", unsafe_allow_html=True
    )
    return empty
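

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the app): one way the helpers above could be
# combined in a Streamlit page. The model name, the threshold values, and the
# whitespace "tokenization" are stand-ins; a real page would derive the token
# list from the model's tokenizer.
if __name__ == "__main__":
    device = st.selectbox("Device", possible_devices())
    sentence = st.text_input("Sentence", "The quick brown fox jumps over the lazy dog")
    threshold = st.slider("Contribution threshold", 0.0, 0.1, 0.04)

    model_name = "facebook/opt-125m"  # stand-in model name
    model = load_model(
        model_name=model_name,
        _device=device,
        _dtype=torch.float32,
    )
    # Any string that changes when the loaded model changes works as a key.
    model_key = f"{model_name}/{device}"

    run_model_with_session_caching(model, model_key, sentence)

    # Whitespace split is only a placeholder; `tokens` is used purely for caching.
    graph = get_contribution_graph(model, model_key, sentence.split(), threshold)
    st_placeholder(f"Contribution graph has {graph.number_of_nodes()} nodes.")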