mahnerak's picture
Initial Commit πŸš€
ce00289
raw
history blame
No virus
15.5 kB
/**
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
import {
ComponentProps,
Streamlit,
withStreamlitConnection,
} from 'streamlit-component-lib'
import React, { useEffect, useMemo, useRef, useState } from 'react';
import * as d3 from 'd3';
import {
Label,
Point,
} from './common';
import './LlmViewer.css';
export const renderParams = {
cellH: 32,
cellW: 32,
attnSize: 8,
afterFfnSize: 8,
ffnSize: 6,
tokenSelectorSize: 16,
layerCornerRadius: 6,
}
interface Cell {
layer: number
token: number
}
enum CellItem {
AfterAttn = 'after_attn',
AfterFfn = 'after_ffn',
Ffn = 'ffn',
Original = 'original', // They will only be at level = 0
}
interface Node {
cell: Cell | null
item: CellItem | null
}
interface NodeProps {
node: Node
pos: Point
isActive: boolean
}
interface EdgeRaw {
weight: number
source: string
target: string
}
interface Edge {
weight: number
from: Node
to: Node
fromPos: Point
toPos: Point
isSelectable: boolean
isFfn: boolean
}
interface Selection {
node: Node | null
edge: Edge | null
}
function tokenPointerPolygon(origin: Point) {
const r = renderParams.tokenSelectorSize / 2
const dy = r / 2
const dx = r * Math.sqrt(3.0) / 2
// Draw an arrow looking down
return [
[origin.x, origin.y + r],
[origin.x + dx, origin.y - dy],
[origin.x - dx, origin.y - dy],
].toString()
}
function isSameCell(cell1: Cell | null, cell2: Cell | null) {
if (cell1 == null || cell2 == null) {
return false
}
return cell1.layer === cell2.layer && cell1.token === cell2.token
}
function isSameNode(node1: Node | null, node2: Node | null) {
if (node1 === null || node2 === null) {
return false
}
return isSameCell(node1.cell, node2.cell)
&& node1.item === node2.item;
}
function isSameEdge(edge1: Edge | null, edge2: Edge | null) {
if (edge1 === null || edge2 === null) {
return false
}
return isSameNode(edge1.from, edge2.from) && isSameNode(edge1.to, edge2.to);
}
function nodeFromString(name: string) {
const match = name.match(/([AIMX])(\d+)_(\d+)/)
if (match == null) {
return {
cell: null,
item: null,
}
}
const [, type, layerStr, tokenStr] = match
const layer = +layerStr
const token = +tokenStr
const typeToCellItem = new Map<string, CellItem>([
['A', CellItem.AfterAttn],
['I', CellItem.AfterFfn],
['M', CellItem.Ffn],
['X', CellItem.Original],
])
return {
cell: {
layer: layer,
token: token,
},
item: typeToCellItem.get(type) ?? null,
}
}
function isValidNode(node: Node, nLayers: number, nTokens: number) {
if (node.cell === null) {
return true
}
return node.cell.layer < nLayers && node.cell.token < nTokens
}
function isValidSelection(selection: Selection, nLayers: number, nTokens: number) {
if (selection.node !== null) {
return isValidNode(selection.node, nLayers, nTokens)
}
if (selection.edge !== null) {
return isValidNode(selection.edge.from, nLayers, nTokens) &&
isValidNode(selection.edge.to, nLayers, nTokens)
}
return true
}
const ContributionGraph = ({ args }: ComponentProps) => {
const modelInfo = args['model_info']
const tokens = args['tokens']
const edgesRaw: EdgeRaw[][] = args['edges_per_token']
const nLayers = modelInfo === null ? 0 : modelInfo.n_layers
const nTokens = tokens === null ? 0 : tokens.length
const [selection, setSelection] = useState<Selection>({
node: null,
edge: null,
})
var curSelection = selection
if (!isValidSelection(selection, nLayers, nTokens)) {
curSelection = {
node: null,
edge: null,
}
setSelection(curSelection)
Streamlit.setComponentValue(curSelection)
}
const [startToken, setStartToken] = useState<number>(nTokens - 1)
// We have startToken state var, but it won't be updated till next render, so use
// this var in the current render.
var curStartToken = startToken
if (startToken >= nTokens) {
curStartToken = nTokens - 1
setStartToken(curStartToken)
}
const handleRepresentationClick = (node: Node) => {
const newSelection: Selection = {
node: node,
edge: null,
}
setSelection(newSelection)
Streamlit.setComponentValue(newSelection)
}
const handleEdgeClick = (edge: Edge) => {
if (!edge.isSelectable) {
return
}
const newSelection: Selection = {
node: edge.to,
edge: edge,
}
setSelection(newSelection)
Streamlit.setComponentValue(newSelection)
}
const handleTokenClick = (t: number) => {
setStartToken(t)
}
const [xScale, yScale] = useMemo(() => {
const x = d3.scaleLinear()
.domain([-2, nTokens - 1])
.range([0, renderParams.cellW * (nTokens + 2)])
const y = d3.scaleLinear()
.domain([-1, nLayers])
.range([renderParams.cellH * (nLayers + 2), 0])
return [x, y]
}, [nLayers, nTokens])
const cells = useMemo(() => {
let result: Cell[] = []
for (let l = 0; l < nLayers; l++) {
for (let t = 0; t < nTokens; t++) {
result.push({
layer: l,
token: t,
})
}
}
return result
}, [nLayers, nTokens])
const nodeCoords = useMemo(() => {
let result = new Map<string, Point>()
const w = renderParams.cellW
const h = renderParams.cellH
for (var cell of cells) {
const cx = xScale(cell.token + 0.5)
const cy = yScale(cell.layer - 0.5)
result.set(
JSON.stringify({ cell: cell, item: CellItem.AfterAttn }),
{ x: cx, y: cy + h / 4 },
)
result.set(
JSON.stringify({ cell: cell, item: CellItem.AfterFfn }),
{ x: cx, y: cy - h / 4 },
)
result.set(
JSON.stringify({ cell: cell, item: CellItem.Ffn }),
{ x: cx + 5 * w / 16, y: cy },
)
}
for (let t = 0; t < nTokens; t++) {
cell = {
layer: 0,
token: t,
}
const cx = xScale(cell.token + 0.5)
const cy = yScale(cell.layer - 1.0)
result.set(
JSON.stringify({ cell: cell, item: CellItem.Original }),
{ x: cx, y: cy + h / 4 },
)
}
return result
}, [cells, nTokens, xScale, yScale])
const edges: Edge[][] = useMemo(() => {
let result = []
for (var edgeList of edgesRaw) {
let edgesPerStartToken = []
for (var edge of edgeList) {
const u = nodeFromString(edge.source)
const v = nodeFromString(edge.target)
var isSelectable = (
u.cell !== null && v.cell !== null && v.item === CellItem.AfterAttn
)
var isFfn = (
u.cell !== null && v.cell !== null && (
u.item === CellItem.Ffn || v.item === CellItem.Ffn
)
)
edgesPerStartToken.push({
weight: edge.weight,
from: u,
to: v,
fromPos: nodeCoords.get(JSON.stringify(u)) ?? { 'x': 0, 'y': 0 },
toPos: nodeCoords.get(JSON.stringify(v)) ?? { 'x': 0, 'y': 0 },
isSelectable: isSelectable,
isFfn: isFfn,
})
}
result.push(edgesPerStartToken)
}
return result
}, [edgesRaw, nodeCoords])
const activeNodes = useMemo(() => {
let result = new Set<string>()
for (var edge of edges[curStartToken]) {
const u = JSON.stringify(edge.from)
const v = JSON.stringify(edge.to)
result.add(u)
result.add(v)
}
return result
}, [edges, curStartToken])
const nodeProps = useMemo(() => {
let result: Array<NodeProps> = []
nodeCoords.forEach((p: Point, node: string) => {
result.push({
node: JSON.parse(node),
pos: p,
isActive: activeNodes.has(node),
})
})
return result
}, [nodeCoords, activeNodes])
const tokenLabels: Label[] = useMemo(() => {
if (!tokens) {
return []
}
return tokens.map((s: string, i: number) => ({
text: s.replace(/ /g, 'Β·'),
pos: {
x: xScale(i + 0.5),
y: yScale(-1.5),
},
}))
}, [tokens, xScale, yScale])
const layerLabels: Label[] = useMemo(() => {
return Array.from(Array(nLayers).keys()).map(i => ({
text: 'L' + i,
pos: {
x: xScale(-0.25),
y: yScale(i - 0.5),
},
}))
}, [nLayers, xScale, yScale])
const tokenSelectors: Array<[number, Point]> = useMemo(() => {
return Array.from(Array(nTokens).keys()).map(i => ([
i,
{
x: xScale(i + 0.5),
y: yScale(nLayers - 0.5),
}
]))
}, [nTokens, nLayers, xScale, yScale])
const totalW = xScale(nTokens + 2)
const totalH = yScale(-4)
useEffect(() => {
Streamlit.setFrameHeight(totalH)
}, [totalH])
const colorScale = d3.scaleLinear(
[0.0, 0.5, 1.0],
['#9eba66', 'darkolivegreen', 'darkolivegreen']
)
const ffnEdgeColorScale = d3.scaleLinear(
[0.0, 0.5, 1.0],
['orchid', 'purple', 'purple']
)
const edgeWidthScale = d3.scaleLinear([0.0, 0.5, 1.0], [2.0, 3.0, 3.0])
const svgRef = useRef(null);
useEffect(() => {
const getNodeStyle = (p: NodeProps, type: string) => {
if (isSameNode(p.node, curSelection.node)) {
return 'selectable-item selection'
}
if (p.isActive) {
return 'selectable-item active-' + type + '-node'
}
return 'selectable-item inactive-node'
}
const svg = d3.select(svgRef.current)
svg.selectAll('*').remove()
svg
.selectAll('layers')
.data(Array.from(Array(nLayers).keys()).filter((x) => x % 2 === 1))
.enter()
.append('rect')
.attr('class', 'layer-highlight')
.attr('x', xScale(-1.0))
.attr('y', (layer) => yScale(layer))
.attr('width', xScale(nTokens + 0.25) - xScale(-1.0))
.attr('height', (layer) => yScale(layer) - yScale(layer + 1))
.attr('rx', renderParams.layerCornerRadius)
svg
.selectAll('edges')
.data(edges[curStartToken])
.enter()
.append('line')
.style('stroke', (edge: Edge) => {
if (isSameEdge(edge, curSelection.edge)) {
return 'orange'
}
if (edge.isFfn) {
return ffnEdgeColorScale(edge.weight)
}
return colorScale(edge.weight)
})
.attr('class', (edge: Edge) => edge.isSelectable ? 'selectable-edge' : '')
.style('stroke-width', (edge: Edge) => edgeWidthScale(edge.weight))
.attr('x1', (edge: Edge) => edge.fromPos.x)
.attr('y1', (edge: Edge) => edge.fromPos.y)
.attr('x2', (edge: Edge) => edge.toPos.x)
.attr('y2', (edge: Edge) => edge.toPos.y)
.on('click', (event: PointerEvent, edge) => {
handleEdgeClick(edge)
})
svg
.selectAll('residual')
.data(nodeProps)
.enter()
.filter((p) => {
return p.node.item === CellItem.AfterAttn
|| p.node.item === CellItem.AfterFfn
})
.append('circle')
.attr('class', (p) => getNodeStyle(p, 'residual'))
.attr('cx', (p) => p.pos.x)
.attr('cy', (p) => p.pos.y)
.attr('r', renderParams.attnSize / 2)
.on('click', (event: PointerEvent, p) => {
handleRepresentationClick(p.node)
})
svg
.selectAll('ffn')
.data(nodeProps)
.enter()
.filter((p) => p.node.item === CellItem.Ffn && p.isActive)
.append('rect')
.attr('class', (p) => getNodeStyle(p, 'ffn'))
.attr('x', (p) => p.pos.x - renderParams.ffnSize / 2)
.attr('y', (p) => p.pos.y - renderParams.ffnSize / 2)
.attr('width', renderParams.ffnSize)
.attr('height', renderParams.ffnSize)
.on('click', (event: PointerEvent, p) => {
handleRepresentationClick(p.node)
})
svg
.selectAll('token_labels')
.data(tokenLabels)
.enter()
.append('text')
.attr('x', (label: Label) => label.pos.x)
.attr('y', (label: Label) => label.pos.y)
.attr('text-anchor', 'end')
.attr('dominant-baseline', 'middle')
.attr('alignment-baseline', 'top')
.attr('transform', (label: Label) =>
'rotate(-40, ' + label.pos.x + ', ' + label.pos.y + ')')
.text((label: Label) => label.text)
svg
.selectAll('layer_labels')
.data(layerLabels)
.enter()
.append('text')
.attr('x', (label: Label) => label.pos.x)
.attr('y', (label: Label) => label.pos.y)
.attr('text-anchor', 'middle')
.attr('alignment-baseline', 'middle')
.text((label: Label) => label.text)
svg
.selectAll('token_selectors')
.data(tokenSelectors)
.enter()
.append('polygon')
.attr('class', ([i,]) => (
curStartToken === i
? 'selectable-item selection'
: 'selectable-item token-selector'
))
.attr('points', ([, p]) => tokenPointerPolygon(p))
.attr('r', renderParams.tokenSelectorSize / 2)
.on('click', (event: PointerEvent, [i,]) => {
handleTokenClick(i)
})
}, [
cells,
edges,
nodeProps,
tokenLabels,
layerLabels,
tokenSelectors,
curStartToken,
curSelection,
colorScale,
ffnEdgeColorScale,
edgeWidthScale,
nLayers,
nTokens,
xScale,
yScale
])
return <svg ref={svgRef} width={totalW} height={totalH}></svg>
}
export default withStreamlitConnection(ContributionGraph)