/**
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

import {
    ComponentProps,
    Streamlit,
    withStreamlitConnection,
} from 'streamlit-component-lib'
import React, { useEffect, useMemo, useRef, useState } from 'react';
import * as d3 from 'd3';

import {
    Label,
    Point,
} from './common';
import './LlmViewer.css';

/** Pixel sizes used when laying out and drawing the graph. */
export const renderParams = {
    cellH: 32,
    cellW: 32,
    attnSize: 8,
    afterFfnSize: 8,
    ffnSize: 6,
    tokenSelectorSize: 16,
    layerCornerRadius: 6,
}

/** Grid position of a node: which layer and which token it belongs to. */
interface Cell {
    layer: number
    token: number
}

/** Kind of node that can live inside a cell. */
enum CellItem {
    AfterAttn = 'after_attn',
    AfterFfn = 'after_ffn',
    Ffn = 'ffn',
    Original = 'original',  // They will only be at level = 0
}

/** A graph node; both fields are null when the node name could not be parsed. */
interface Node {
    cell: Cell | null
    item: CellItem | null
}

/** Everything needed to draw a single node. */
interface NodeProps {
    node: Node
    pos: Point
    isActive: boolean
}

/** Edge as received from the Streamlit side: node names plus a weight. */
interface EdgeRaw {
    weight: number
    source: string
    target: string
}

/** Edge with parsed endpoints and resolved drawing coordinates. */
interface Edge {
    weight: number
    from: Node
    to: Node
    fromPos: Point
    toPos: Point
    isSelectable: boolean
    isFfn: boolean
}

/** Current user selection, also reported back via Streamlit.setComponentValue. */
interface Selection {
    node: Node | null
    edge: Edge | null
}

/** A node together with the point it is drawn at. */
interface PlacedNode {
    node: Node
    pos: Point
}

// Shared "nothing selected" value. Keeping one instance gives it a stable
// identity, so hooks that depend on the current selection do not re-run
// just because an equivalent empty object was re-created.
const EMPTY_SELECTION: Selection = {
    node: null,
    edge: null,
}

// Maps the one-letter prefix of a node name to the node kind.
const TYPE_TO_CELL_ITEM = new Map<string, CellItem>([
    ['A', CellItem.AfterAttn],
    ['I', CellItem.AfterFfn],
    ['M', CellItem.Ffn],
    ['X', CellItem.Original],
])

// The scales below do not depend on props or state, so they are built once
// at module level instead of on every render.
const colorScale = d3.scaleLinear(
    [0.0, 0.5, 1.0],
    ['#9eba66', 'darkolivegreen', 'darkolivegreen']
)
const ffnEdgeColorScale = d3.scaleLinear(
    [0.0, 0.5, 1.0],
    ['orchid', 'purple', 'purple']
)
const edgeWidthScale = d3.scaleLinear([0.0, 0.5, 1.0], [2.0, 3.0, 3.0])

/**
 * Builds the `points` attribute of the triangular token selector.
 *
 * @param origin Centre of the selector.
 * @returns Comma-separated coordinates of a triangle pointing down.
 */
function tokenPointerPolygon(origin: Point): string {
    const r = renderParams.tokenSelectorSize / 2
    const dy = r / 2
    const dx = r * Math.sqrt(3.0) / 2
    // Draw an arrow looking down
    return [
        [origin.x, origin.y + r],
        [origin.x + dx, origin.y - dy],
        [origin.x - dx, origin.y - dy],
    ].toString()
}

/** True when both cells are present and have the same layer and token. */
function isSameCell(cell1: Cell | null, cell2: Cell | null): boolean {
    if (cell1 == null || cell2 == null) {
        return false
    }
    return cell1.layer === cell2.layer && cell1.token === cell2.token
}

/** True when both nodes are present, share a cell and have the same item kind. */
function isSameNode(node1: Node | null, node2: Node | null): boolean {
    if (node1 == null || node2 == null) {
        return false
    }
    return isSameCell(node1.cell, node2.cell) && node1.item === node2.item
}

/** True when both edges are present and connect the same pair of nodes. */
function isSameEdge(edge1: Edge | null, edge2: Edge | null): boolean {
    if (edge1 == null || edge2 == null) {
        return false
    }
    return isSameNode(edge1.from, edge2.from) && isSameNode(edge1.to, edge2.to)
}

/**
 * Parses a node name of the form `<type><layer>_<token>`, where type is one
 * of A, I, M, X.
 *
 * @returns The parsed node, or a node with null cell and item when the name
 *   does not match.
 */
function nodeFromString(name: string): Node {
    const match = name.match(/([AIMX])(\d+)_(\d+)/)
    if (match == null) {
        return {
            cell: null,
            item: null,
        }
    }
    const [, type, layerStr, tokenStr] = match
    return {
        cell: {
            layer: +layerStr,
            token: +tokenStr,
        },
        item: TYPE_TO_CELL_ITEM.get(type) ?? null,
    }
}

/**
 * Serialises a node into a string usable as a Map/Set key.
 *
 * The key order is fixed here so that nodes built in different places
 * always produce the same key.
 */
function nodeKey(node: Node): string {
    const cell = node.cell == null
        ? null
        : { layer: node.cell.layer, token: node.cell.token }
    return JSON.stringify({ cell: cell, item: node.item })
}

/** A node without a cell is always valid; otherwise its cell must be in range. */
function isValidNode(node: Node, nLayers: number, nTokens: number): boolean {
    if (node.cell === null) {
        return true
    }
    return node.cell.layer < nLayers && node.cell.token < nTokens
}

/**
 * Checks that the selection still refers to nodes inside the current grid.
 * When a node is selected only the node is checked; otherwise both ends of
 * the selected edge are checked. An empty selection is valid.
 */
function isValidSelection(selection: Selection, nLayers: number, nTokens: number): boolean {
    if (selection.node !== null) {
        return isValidNode(selection.node, nLayers, nTokens)
    }
    if (selection.edge !== null) {
        return isValidNode(selection.edge.from, nLayers, nTokens) &&
            isValidNode(selection.edge.to, nLayers, nTokens)
    }
    return true
}

/**
 * Streamlit component that draws the contribution graph with d3.
 *
 * Reads `model_info`, `tokens` and `edges_per_token` from `args`, renders
 * them into an SVG and reports the selected node/edge back through
 * `Streamlit.setComponentValue`.
 */
const ContributionGraph = ({ args }: ComponentProps) => {
    const modelInfo = args['model_info']
    const tokens: string[] | null | undefined = args['tokens']
    const edgesRaw: EdgeRaw[][] | null | undefined = args['edges_per_token']
    // Treat missing inputs (null or undefined) as an empty graph.
    const nLayers: number = modelInfo?.n_layers ?? 0
    const nTokens: number = tokens?.length ?? 0

    const [selection, setSelection] = useState<Selection>(EMPTY_SELECTION)
    // The stored selection may point outside the grid after the inputs
    // change. Draw with an empty selection right away and reset the state
    // (and notify Streamlit) from an effect rather than during render.
    const selectionIsValid = isValidSelection(selection, nLayers, nTokens)
    const curSelection = selectionIsValid ? selection : EMPTY_SELECTION
    useEffect(() => {
        if (!selectionIsValid) {
            setSelection(EMPTY_SELECTION)
            Streamlit.setComponentValue(EMPTY_SELECTION)
        }
    }, [selectionIsValid])

    const [startToken, setStartToken] = useState(nTokens - 1)
    // We have startToken state var, but it won't be updated till next render, so use
    // this var in the current render.
    let curStartToken = startToken
    // Clamp when the token list shrank, and recover from the initial -1 that
    // is stored when the component first rendered without any tokens.
    if (startToken >= nTokens || (startToken < 0 && nTokens > 0)) {
        curStartToken = nTokens - 1
        setStartToken(curStartToken)
    }

    const handleRepresentationClick = (node: Node) => {
        const newSelection: Selection = {
            node: node,
            edge: null,
        }
        setSelection(newSelection)
        Streamlit.setComponentValue(newSelection)
    }

    const handleEdgeClick = (edge: Edge) => {
        if (!edge.isSelectable) {
            return
        }
        const newSelection: Selection = {
            node: edge.to,
            edge: edge,
        }
        setSelection(newSelection)
        Streamlit.setComponentValue(newSelection)
    }

    const handleTokenClick = (t: number) => {
        setStartToken(t)
    }

    const [xScale, yScale] = useMemo(() => {
        const x = d3.scaleLinear()
            .domain([-2, nTokens - 1])
            .range([0, renderParams.cellW * (nTokens + 2)])
        const y = d3.scaleLinear()
            .domain([-1, nLayers])
            .range([renderParams.cellH * (nLayers + 2), 0])
        return [x, y]
    }, [nLayers, nTokens])

    const cells = useMemo(() => {
        const result: Cell[] = []
        for (let l = 0; l < nLayers; l++) {
            for (let t = 0; t < nTokens; t++) {
                result.push({
                    layer: l,
                    token: t,
                })
            }
        }
        return result
    }, [nLayers, nTokens])

    // Every drawable node keyed by nodeKey(), in insertion order.
    const placedNodes = useMemo(() => {
        const result = new Map<string, PlacedNode>()
        const place = (cell: Cell, item: CellItem, pos: Point) => {
            const node: Node = { cell: cell, item: item }
            result.set(nodeKey(node), { node: node, pos: pos })
        }
        const w = renderParams.cellW
        const h = renderParams.cellH
        for (const cell of cells) {
            const cx = xScale(cell.token + 0.5)
            const cy = yScale(cell.layer - 0.5)
            place(cell, CellItem.AfterAttn, { x: cx, y: cy + h / 4 })
            place(cell, CellItem.AfterFfn, { x: cx, y: cy - h / 4 })
            place(cell, CellItem.Ffn, { x: cx + 5 * w / 16, y: cy })
        }
        // Original (input) nodes sit one row below layer 0.
        for (let t = 0; t < nTokens; t++) {
            const cell: Cell = {
                layer: 0,
                token: t,
            }
            const cx = xScale(cell.token + 0.5)
            const cy = yScale(cell.layer - 1.0)
            place(cell, CellItem.Original, { x: cx, y: cy + h / 4 })
        }
        return result
    }, [cells, nTokens, xScale, yScale])

    const edges: Edge[][] = useMemo(() => {
        const positionOf = (node: Node): Point =>
            placedNodes.get(nodeKey(node))?.pos ?? { x: 0, y: 0 }
        return (edgesRaw ?? []).map((edgeList) => edgeList.map((edge): Edge => {
            const u = nodeFromString(edge.source)
            const v = nodeFromString(edge.target)
            const bothParsed = u.cell !== null && v.cell !== null
            return {
                weight: edge.weight,
                from: u,
                to: v,
                fromPos: positionOf(u),
                toPos: positionOf(v),
                isSelectable: bothParsed && v.item === CellItem.AfterAttn,
                isFfn: bothParsed && (
                    u.item === CellItem.Ffn || v.item === CellItem.Ffn
                ),
            }
        }))
    }, [edgesRaw, placedNodes])

    // Edges for the currently selected start token. Empty when there is no
    // such entry (e.g. no tokens yet), instead of iterating over undefined.
    const currentEdges: Edge[] = useMemo(
        () => edges[curStartToken] ?? [],
        [edges, curStartToken],
    )

    const activeNodes = useMemo(() => {
        const result = new Set<string>()
        for (const edge of currentEdges) {
            result.add(nodeKey(edge.from))
            result.add(nodeKey(edge.to))
        }
        return result
    }, [currentEdges])

    const nodeProps = useMemo(() => {
        const result: NodeProps[] = []
        placedNodes.forEach((placed, key) => {
            result.push({
                node: placed.node,
                pos: placed.pos,
                isActive: activeNodes.has(key),
            })
        })
        return result
    }, [placedNodes, activeNodes])

    const tokenLabels: Label[] = useMemo(() => {
        if (!tokens) {
            return []
        }
        return tokens.map((s: string, i: number) => ({
            // Show spaces as a middle dot so they stay visible in labels.
            text: s.replace(/ /g, '\u00b7'),
            pos: {
                x: xScale(i + 0.5),
                y: yScale(-1.5),
            },
        }))
    }, [tokens, xScale, yScale])

    const layerLabels: Label[] = useMemo(() => {
        return Array.from({ length: nLayers }, (_, i) => ({
            text: 'L' + i,
            pos: {
                x: xScale(-0.25),
                y: yScale(i - 0.5),
            },
        }))
    }, [nLayers, xScale, yScale])

    const tokenSelectors = useMemo(() => {
        return Array.from({ length: nTokens }, (_, i): [number, Point] => [
            i,
            {
                x: xScale(i + 0.5),
                y: yScale(nLayers - 0.5),
            },
        ])
    }, [nTokens, nLayers, xScale, yScale])

    const totalW = xScale(nTokens + 2)
    const totalH = yScale(-4)
    useEffect(() => {
        Streamlit.setFrameHeight(totalH)
    }, [totalH])

    const svgRef = useRef<SVGSVGElement>(null)

    // Redraw the whole SVG whenever any of the drawn data changes.
    useEffect(() => {
        const getNodeStyle = (p: NodeProps, type: string) => {
            if (isSameNode(p.node, curSelection.node)) {
                return 'selectable-item selection'
            }
            if (p.isActive) {
                return 'selectable-item active-' + type + '-node'
            }
            return 'selectable-item inactive-node'
        }

        const svg = d3.select(svgRef.current)
        svg.selectAll('*').remove()

        // Background stripe behind every odd layer.
        const oddLayers = Array.from({ length: nLayers }, (_, i) => i)
            .filter((x) => x % 2 === 1)
        svg
            .selectAll('layers')
            .data(oddLayers)
            .enter()
            .append('rect')
            .attr('class', 'layer-highlight')
            .attr('x', xScale(-1.0))
            .attr('y', (layer) => yScale(layer))
            .attr('width', xScale(nTokens + 0.25) - xScale(-1.0))
            .attr('height', (layer) => yScale(layer) - yScale(layer + 1))
            .attr('rx', renderParams.layerCornerRadius)

        svg
            .selectAll('edges')
            .data(currentEdges)
            .enter()
            .append('line')
            .style('stroke', (edge: Edge) => {
                if (isSameEdge(edge, curSelection.edge)) {
                    return 'orange'
                }
                if (edge.isFfn) {
                    return ffnEdgeColorScale(edge.weight)
                }
                return colorScale(edge.weight)
            })
            .attr('class', (edge: Edge) => edge.isSelectable ? 'selectable-edge' : '')
            .style('stroke-width', (edge: Edge) => edgeWidthScale(edge.weight))
            .attr('x1', (edge: Edge) => edge.fromPos.x)
            .attr('y1', (edge: Edge) => edge.fromPos.y)
            .attr('x2', (edge: Edge) => edge.toPos.x)
            .attr('y2', (edge: Edge) => edge.toPos.y)
            .on('click', (_event: PointerEvent, edge) => {
                handleEdgeClick(edge)
            })

        svg
            .selectAll('residual')
            .data(nodeProps)
            .enter()
            .filter((p) => {
                return p.node.item === CellItem.AfterAttn
                    || p.node.item === CellItem.AfterFfn
            })
            .append('circle')
            .attr('class', (p) => getNodeStyle(p, 'residual'))
            .attr('cx', (p) => p.pos.x)
            .attr('cy', (p) => p.pos.y)
            .attr('r', renderParams.attnSize / 2)
            .on('click', (_event: PointerEvent, p) => {
                handleRepresentationClick(p.node)
            })

        // FFN nodes are drawn only when they take part in the current edges.
        svg
            .selectAll('ffn')
            .data(nodeProps)
            .enter()
            .filter((p) => p.node.item === CellItem.Ffn && p.isActive)
            .append('rect')
            .attr('class', (p) => getNodeStyle(p, 'ffn'))
            .attr('x', (p) => p.pos.x - renderParams.ffnSize / 2)
            .attr('y', (p) => p.pos.y - renderParams.ffnSize / 2)
            .attr('width', renderParams.ffnSize)
            .attr('height', renderParams.ffnSize)
            .on('click', (_event: PointerEvent, p) => {
                handleRepresentationClick(p.node)
            })

        svg
            .selectAll('token_labels')
            .data(tokenLabels)
            .enter()
            .append('text')
            .attr('x', (label: Label) => label.pos.x)
            .attr('y', (label: Label) => label.pos.y)
            .attr('text-anchor', 'end')
            .attr('dominant-baseline', 'middle')
            .attr('alignment-baseline', 'top')
            .attr('transform', (label: Label) =>
                'rotate(-40, ' + label.pos.x + ', ' + label.pos.y + ')')
            .text((label: Label) => label.text)

        svg
            .selectAll('layer_labels')
            .data(layerLabels)
            .enter()
            .append('text')
            .attr('x', (label: Label) => label.pos.x)
            .attr('y', (label: Label) => label.pos.y)
            .attr('text-anchor', 'middle')
            .attr('alignment-baseline', 'middle')
            .text((label: Label) => label.text)

        svg
            .selectAll('token_selectors')
            .data(tokenSelectors)
            .enter()
            .append('polygon')
            .attr('class', ([i,]) => (
                curStartToken === i
                    ? 'selectable-item selection'
                    : 'selectable-item token-selector'
            ))
            .attr('points', ([, p]) => tokenPointerPolygon(p))
            .on('click', (_event: PointerEvent, [i,]) => {
                handleTokenClick(i)
            })
    }, [
        currentEdges,
        nodeProps,
        tokenLabels,
        layerLabels,
        tokenSelectors,
        curStartToken,
        curSelection,
        nLayers,
        nTokens,
        xScale,
        yScale
    ])

    // Render the <svg> that the drawing effect above fills in through svgRef.
    // React.createElement is used so this is valid in both .ts and .tsx files.
    return React.createElement('svg', {
        ref: svgRef,
        width: totalW,
        height: totalH,
    })
}

export default withStreamlitConnection(ContributionGraph)