Spaces:
Sleeping
Sleeping
async () => { | |
// Expose testFn() on globalThis so inline HTML `onclick` attributes can call it.
globalThis.testFn = () => {
  // Writes a fixed greeting into the #demo element.
  document.getElementById('demo').innerHTML = "Hello?";
};
// Two D3 majors are loaded side by side: `d37` (v7) and `d3` (v5). They are not
// interchangeable: since D3 v6 event listeners receive (event, datum), whereas
// v5 listeners receive (datum, index, nodes). The imports are independent, so
// they are fetched in parallel.
const [d37, d3] = await Promise.all([
  import("https://cdn.jsdelivr.net/npm/d3@7/+esm"),
  import("https://cdn.jsdelivr.net/npm/d3@5/+esm"),
]);
// jQuery's dist build is a classic script, not an ES module: importing it is
// done for its side effect of installing `jQuery` on the global object, and the
// module namespace it resolves to carries no usable exports.
const $ = await import("https://cdn.jsdelivr.net/npm/jquery@3.7.1/dist/jquery.min.js");
// Publish the real jQuery function. Assigning the namespace object directly
// would overwrite the global `$` with something that is not callable; fall back
// to the namespace only if the script did not install `jQuery`.
globalThis.$ = globalThis.jQuery ?? $;
globalThis.d3 = d3;
// Demo helper exposed globally: appends an <svg> holding a 50x50 black square
// to #viz; the square turns red while hovered and black again on mouseout.
globalThis.d3Fn = () => {
  d3.select('#viz').append('svg')
    .append('rect')
    .attr('width', 50)
    .attr('height', 50)
    .attr('fill', 'black')
    // `function` (not arrow) so that `this` is the hovered <rect>.
    .on('mouseover', function () { d3.select(this).attr('fill', 'red'); })
    .on('mouseout', function () { d3.select(this).attr('fill', 'black'); });
};
// Debug pass-through exposed globally: logs `val` and returns both arguments
// unchanged as a two-element array `[val, radio_c]`.
globalThis.testFn_out = (val, radio_c) => {
  console.log(val);
  return [val, radio_c];
};
/**
 * Renders every visualization on the page from one result payload.
 *
 * Slots of `data` read by this function:
 *   - data[1][0]: flat list of beam-search nodes, each with `id` and `parentId`
 *                 (`parentId === null` marks the root); drawn with Tree().
 *   - data[1][1]: rows handed to TextGrid().
 *   - data[1][2], data[1][3]: HTML strings injected into #d3_tok / #d3_tok_target.
 *   - data[2]: embeddings, indexed as data[2][input|output][tokens|words], each
 *              entry providing `tnse` and `similar_queries`.
 *   - data[3], data[4], data[5]: parameter objects for attViz().
 *
 * @param {Array} data - payload laid out as described above.
 * @returns {Array} always `['string', {}]`.
 */
globalThis.testFn_out_json = (data) => {
  console.log(data);
  // Locals only: the previous version leaked all of these as implicit globals.
  const beamNodes = data[1][0];
  const tokenProbs = data[1][1];
  const inputsHtml = data[1][2];
  const targetHtml = data[1][3];
  const embeddings = data[2];

  attViz(data[3]);
  attViz(data[4]);
  attViz(data[5]);
  console.log(beamNodes);

  // Turn the flat node list into a nested tree by attaching every node to its
  // parent's `children` array (this mutates the node objects in `data`).
  const indexById = {};
  beamNodes.forEach((node, i) => {
    indexById[node.id] = i;
  });
  let root;
  beamNodes.forEach((node) => {
    if (node.parentId === null) {
      root = node;
      return;
    }
    const parentNode = beamNodes[indexById[node.parentId]];
    if (parentNode === undefined) {
      // An unknown parent id used to throw and abort all rendering.
      console.warn("beam node references unknown parentId", node);
      return;
    }
    parentNode.children = [...(parentNode.children || []), node];
  });

  // Beam-search tree.
  const beamDiv = d3.select('#d3_beam_search');
  beamDiv.html("");
  beamDiv.append(() => Tree(root));

  // Per-step token probabilities.
  const gridDiv = d3.select('#d3_text_grid');
  gridDiv.html("");
  gridDiv.append(() => TextGrid(tokenProbs));

  // Tokenization.
  d3.select('#d3_tok').html(inputsHtml);
  d3.select('#d3_tok_target').html(targetHtml);

  // Embeddings: reveal the words/tokens selector and sync the panels with it.
  d3.select("#d3_embeds_source").html("here");
  const typeSelect = d3.select("#select_type");
  console.log(typeSelect.node().value);
  typeSelect.attr("hidden", null);
  typeSelect.on("change", change);
  change();

  // One network plot per (input|output) x (tokens|words) combination.
  ['input', 'output'].forEach((textType) => {
    ['tokens', 'words'].forEach((textKey) => {
      const entry = embeddings[textType][textKey];
      // Positional argument: JavaScript has no keyword arguments, so the old
      // `type=...` form was really an assignment to a global named `type`.
      embeddings_network([], entry['tnse'], entry['similar_queries'], textType + "_" + textKey);
    });
  });

  return ['string', {}];
};
/**
 * Shows only the embedding and graph panels that match the value currently
 * selected in #select_type, hiding every other `.d3_embed` / `.d3_graph` panel.
 */
function change() {
  // Local binding: this used to leak as the implicit global `show_type`.
  const showType = d3.select("#select_type").node().value;

  // Hide everything first ...
  d3.selectAll(".d3_embed").attr("hidden", '');
  d3.selectAll(".d3_graph").attr("hidden", '');

  // ... then reveal the four panels of the chosen type.
  const panelPrefixes = [
    "#d3_embeds_input_",
    "#d3_embeds_output_",
    "#d3_graph_input_",
    "#d3_graph_output_",
  ];
  for (const prefix of panelPrefixes) {
    d3.select(prefix + showType).attr("hidden", null);
  }
}
/**
 * Builds a node/link graph from the similar-token lists and renders it with
 * networkPlot() into the element `#d3_graph_<type>`.
 *
 * Graph construction:
 *   - every key of `similar_vocab_queries` becomes a 'sentence' node (type_i 0);
 *   - every id in its `similar_topk` becomes a 'similar' node (type_i 1) unless
 *     a node with that id already exists, linked to the sentence node with the
 *     matching entry of `distance` as weight;
 *   - consecutive keys are chained with links of weight 1.
 *
 * @param {*} tokens_text - not used.
 * @param {Object} dict_projected_embds - per-token records; index 3 is read as
 *   the node label here.
 * @param {Object} similar_vocab_queries - `{key: {similar_topk: [], distance: []}}`.
 * @param {string} [type="source"] - suffix selecting the target element.
 */
function embeddings_network(tokens_text, dict_projected_embds, similar_vocab_queries, type = "source") {
  console.log("Each token is a node; distance if in similar list", type);
  console.log(tokens_text, dict_projected_embds, similar_vocab_queries);

  const nodeHash = {};
  const nodes = []; // [{id, label, type, type_i}]
  const edges = []; // [{source: node, target: node, weight}]
  const edges_ids = []; // same links expressed with ids, kept for the debug log

  // Returns the node registered for `id`, creating it on first sight.
  const ensureNode = (id, kind, kindIndex) => {
    if (!nodeHash[id]) {
      nodeHash[id] = { id: id, label: dict_projected_embds[id][3], type: kind, type_i: kindIndex };
      nodes.push(nodeHash[id]);
    }
    return nodeHash[id];
  };

  // `null` (not '') marks "no previous token" so that any key can be chained.
  let prevToken = null;
  for (const [sentToken, value] of Object.entries(similar_vocab_queries)) {
    const sentNode = ensureNode(sentToken, 'sentence', 0);
    const simTokens = value['similar_topk'];
    const distTokens = value['distance'];
    simTokens.forEach((sim, index) => {
      const simNode = ensureNode(sim, 'similar', 1);
      edges.push({ source: sentNode, target: simNode, weight: distTokens[index] });
      edges_ids.push({ source: sentToken, target: sim, weight: distTokens[index] });
    });
    if (prevToken !== null) {
      edges.push({ source: nodeHash[prevToken], target: sentNode, weight: 1 });
      edges_ids.push({ source: prevToken, target: sentToken, weight: 1 });
    }
    prevToken = sentToken;
  }
  console.log("TYPE", type, edges, nodes, edges_ids, similar_vocab_queries);

  // Replace whatever was rendered before in the target element.
  const container = d3.select('#d3_graph_' + type);
  container.html("");
  // `type` is passed positionally; the old `div_type=type` form assigned a global.
  container.append(() => networkPlot({ nodes: nodes, links: edges }, similar_vocab_queries, dict_projected_embds, type));
}
/**
 * Draws a force-directed graph of tokens and returns the detached <svg> node.
 * Uses the D3 v7 namespace (`d37`) throughout, so event listeners receive
 * (event, datum).
 *
 * @param {{nodes: Array, links: Array}} data - nodes carry `id`, `label` and
 *   `type_i` (0 = sentence token, otherwise similar token); links carry
 *   `source`, `target` and `weight`.
 * @param {Object} similar_vocab_queries - `{tokenId: {similar_topk, distance}}`.
 * @param {Object} dict_proj - per-token records; index 3 is used as display
 *   text and index 4 as the CSS class suffix.
 * @param {string} [div_type="source"] - appended to the node/text class names
 *   and used to pick the `#similar_<div_type>` legend element.
 * @param {Object} [options]
 * @param {number} [options.width=400] - drawing width before margins.
 * @param {number} [options.height=400] - drawing height before margins
 *   (previously any passed value was overwritten with 400).
 * @param {number} [options.r=3] - accepted but currently unused.
 * @param {number} [options.padding=1] - accepted but currently unused.
 * @returns {SVGSVGElement} the rendered graph.
 */
function networkPlot(data, similar_vocab_queries, dict_proj, div_type = "source", {
  width = 400,
  height = 400,
  r = 3,
  padding = 1,
} = {}) {
  console.log("data, similar_vocab_queries, div_type");
  console.log(data, similar_vocab_queries, div_type);

  const margin = { top: 10, right: 10, bottom: 30, left: 50 };
  const svg = d37.create("svg")
    .attr("width", width + margin.left + margin.right)
    .attr("height", height + margin.top + margin.bottom);

  // Links.
  const link = svg
    .selectAll("line")
    .data(data.links)
    .enter()
    .append("line")
    .style("fill", d => d.weight == 1 ? "#dfd5d5" : "#000000")
    .style("stroke", "#aaa");

  // Labels.
  const text = svg
    .selectAll("text")
    .data(data.nodes)
    .enter()
    .append("text")
    .style("text-anchor", "middle")
    .attr("y", 15)
    .attr("class", d => 'text_token-' + dict_proj[d.id][4] + div_type)
    .attr("div-type", div_type)
    .text(d => d.label);

  // Nodes: red for similar tokens, blue for sentence tokens.
  const node = svg
    .selectAll("circle")
    .data(data.nodes)
    .enter()
    .append("circle")
    .attr("r", 6)
    .attr("class", d => 'node_token-' + dict_proj[d.id][4] + div_type)
    .attr("div-type", div_type)
    .style("fill", d => d.type_i ? "#e85252" : "#6689c6")
    .on('mouseover', highlight_mouseover)
    .on('mouseout', highlight_mouseout)
    .on('click', change_legend);

  d37.forceSimulation(data.nodes)
    .force("link", d37.forceLink()
      .id(d => d.id)
      .links(data.links))
    // forceManyBody() takes no arguments, so the `-400` formerly passed here was
    // ignored and D3's default strength applied. To actually change the
    // repulsion use `.strength(value)`.
    .force("charge", d37.forceManyBody())
    .force("center", d37.forceCenter(width / 2, height / 2))
    // Positions are written once, when the simulation has finished.
    .on("end", ticked);

  // Copies the simulated coordinates onto the SVG elements.
  function ticked() {
    link
      .attr("x1", d => d.source.x)
      .attr("y1", d => d.source.y)
      .attr("x2", d => d.target.x)
      .attr("y2", d => d.target.y);
    node
      .attr("cx", d => d.x + 3)
      .attr("cy", d => d.y - 3);
    text
      .attr("transform", d => "translate(" + d.x + "," + d.y + ")");
  }

  // Applies radius/opacity to every circle that represents one of `similarIds`
  // in a plot of the same div type.
  function styleSimilar(similarIds, radius, opacity) {
    similarIds.forEach((similarToken) => {
      const classSuffix = dict_proj[similarToken][4];
      d37.selectAll('.node_token-' + classSuffix + div_type)
        .attr("r", radius)
        .style('opacity', opacity);
    });
  }

  // Hovering a sentence token enlarges it together with its similar tokens.
  function highlight_mouseover(event, d) {
    if (d.type_i !== 0) return;
    const similarIds = similar_vocab_queries[d.id]['similar_topk'];
    d37.select(this).transition()
      .duration(50)
      .style('opacity', '1')
      .attr("r", 12);
    styleSimilar(similarIds, 12, '1');
  }

  // Undoes highlight_mouseover.
  function highlight_mouseout(event, d) {
    if (d.type_i !== 0) return;
    const similarIds = similar_vocab_queries[d.id]['similar_topk'];
    d37.select(this).transition()
      .duration(50)
      .style('opacity', '.7')
      .attr("r", 6);
    styleSimilar(similarIds, 6, '.7');
    if (similarIds.length > 0) {
      // Re-append the circles so they are drawn last, i.e. on top.
      d37.selectAll("circle").raise();
    }
  }

  // Clicking a node lists its similar tokens in the legend of this plot.
  function change_legend(event, d) {
    console.log(event, d, dict_proj);
    if (d['id'] in dict_proj) {
      // `div_type` comes from the closure; the old code read a global `type`
      // that was only set as a side effect of hovering a sentence token.
      show_similar_tokens(d['id'], '#similar_' + div_type);
      console.log(dict_proj[d['id']]);
    } else {
      console.log("no sentence");
    }
  }

  // Fills `div_name_similar` with one line per similar token: text + distance.
  function show_similar_tokens(token, div_name_similar = '#similar_input_tokens') {
    const target = d37.select(div_name_similar);
    target.html("");
    const tokenData = similar_vocab_queries[token];
    console.log(token, tokenData);
    if (!tokenData) {
      // Nodes that are not keys of similar_vocab_queries have nothing to list;
      // this used to throw a TypeError.
      return;
    }
    const decForm = d37.format(".3f");
    target
      .selectAll("p")
      .data(tokenData['similar_topk'])
      .enter()
      .append("p").append('text')
      .attr('class_id', d => d)
      // Ids may be numbers while `token` is an object key (string).
      .style("background", d => (String(d) === String(token) ? "yellow" : null))
      .text((d, i) => dict_proj[d][3] + " " + decForm(tokenData['distance'][i]) + " ");
  }

  return svg.node();
}
// Copyright 2021 Observable, Inc. | |
// Released under the ISC license. | |
// https://observablehq.com/@d3/tree | |
/**
 * Tidy-tree renderer adapted from the Observable D3 tree example. Layout,
 * hierarchy and element creation use D3 v5 (`d3`); the link curve and link
 * path generator use D3 v7 (`d37`).
 *
 * @param {Object|Array} data - nested `{children}` object, or tabular rows when
 *   `id`/`parentId` or `path` are given.
 * @param {Object} [options]
 * @param {Function} [options.path] - alternative to id/parentId for stratify.
 * @param {Function} [options.id] - row -> unique id (tabular data).
 * @param {Function} [options.parentId] - row -> parent id (tabular data).
 * @param {Function} [options.children] - datum -> children (hierarchical data).
 * @param {Function} [options.tree=d3.tree] - layout factory.
 * @param {Function} [options.sort] - comparator applied before layout.
 * @param {Function} [options.label] - (datum, node) -> label text; null disables labels.
 * @param {Function} [options.title] - (datum, node) -> hover text; defaults to
 *   name followed by prob. Previously this option was overwritten inside the
 *   function; it is now honoured, and null disables titles.
 * @param {Function} [options.link] - (datum, node) -> href for the node anchor.
 * @param {string} [options.linkTarget="_blank"]
 * @param {number} [options.width=800] - minimum total width in pixels.
 * @param {number} [options.height] - defaults to the vertical extent of the layout.
 * @param {number} [options.r=3] - node radius.
 * @param {number} [options.padding=1] - horizontal padding, in layout columns.
 * @param {string} [options.fill="#999"] - leaf node fill.
 * @param {*} [options.fillOpacity] - accepted but currently unused.
 * @param {string} [options.stroke="#555"] - link stroke and inner node fill.
 * @param {number} [options.strokeWidth=2]
 * @param {number} [options.strokeOpacity=0.4]
 * @param {string} [options.strokeLinejoin]
 * @param {string} [options.strokeLinecap]
 * @param {string} [options.halo="#fff"] - outline color behind labels.
 * @param {number} [options.haloWidth=3]
 * @param {Function} [options.curve=d37.curveBumpX] - link curve factory.
 * @returns {SVGSVGElement} the rendered tree.
 * @throws {Error} when `curve` is not a function.
 */
function Tree(data, {
  path,
  id = Array.isArray(data) ? d => d.id : null,
  parentId = Array.isArray(data) ? d => d.parentId : null,
  children,
  tree = d3.tree,
  sort,
  label = d => d.name,
  title = d => (d.name + (d.prob)),
  link,
  linkTarget = "_blank",
  width = 800,
  height,
  r = 3,
  padding = 1,
  fill = "#999",
  fillOpacity,
  stroke = "#555",
  strokeWidth = 2,
  strokeOpacity = 0.4,
  strokeLinejoin,
  strokeLinecap,
  halo = "#fff",
  haloWidth = 3,
  curve = d37.curveBumpX,
} = {}) {
  // Tabular input goes through d3.stratify; nested input through d3.hierarchy.
  const root = path != null ? d3.stratify().path(path)(data)
    : id != null || parentId != null ? d3.stratify().id(id).parentId(parentId)(data)
    : d3.hierarchy(data, children);

  if (sort != null) root.sort(sort);

  const descendants = root.descendants();
  const L = label == null ? null : descendants.map(d => label(d.data, d));

  // Grow the drawing with the node count (10px per node), never below `width`.
  const descWidth = 10;
  const realWidth = descWidth * descendants.length;
  const totalWidth = (realWidth > width) ? realWidth : width;
  const dx = 25;
  const dy = totalWidth / (root.height + padding);
  tree().nodeSize([dx, dy])(root);

  // Vertical extent of the layout (layout x becomes screen y).
  let x0 = Infinity;
  let x1 = -x0;
  root.each(d => {
    if (d.x > x1) x1 = d.x;
    if (d.x < x0) x0 = d.x;
  });

  if (height === undefined) height = x1 - x0 + dx * 2;

  if (typeof curve !== "function") throw new Error(`Unsupported curve`);

  const parent = d3.create("div");
  const body = parent.append("div")
    .style("overflow-x", "scroll")
    .style("-webkit-overflow-scrolling", "touch");

  const svg = body.append("svg")
    .attr("viewBox", [-dy * padding / 2, x0 - dx, totalWidth, height])
    .attr("width", totalWidth)
    .attr("height", height)
    .attr("style", "max-width: 100%; height: auto; height: intrinsic;")
    .attr("font-family", "sans-serif")
    .attr("font-size", 12);

  // Links.
  svg.append("g")
    .attr("fill", "none")
    .attr("stroke", stroke)
    .attr("stroke-opacity", strokeOpacity)
    .attr("stroke-linecap", strokeLinecap)
    .attr("stroke-linejoin", strokeLinejoin)
    .attr("stroke-width", strokeWidth)
    .selectAll("path")
    .data(root.links())
    .join("path")
    .attr("d", d37.link(curve)
      .x(d => d.y)
      .y(d => d.x));

  // Nodes, each wrapped in an anchor so `link` can make it clickable.
  const node = svg.append("g")
    .selectAll("a")
    .data(root.descendants())
    .join("a")
    .attr("xlink:href", link == null ? null : d => link(d.data, d))
    .attr("target", link == null ? null : linkTarget)
    .attr("transform", d => `translate(${d.y},${d.x})`);

  node.append("circle")
    .attr("fill", d => d.children ? stroke : fill)
    .attr("r", r);

  if (title != null) node.append("title")
    .text(d => title(d.data, d));

  // Labels; nodes whose datum has prob == 1 are drawn in red.
  if (L) node.append("text")
    .attr("dy", "0.32em")
    .attr("x", d => d.children ? -6 : 6)
    .attr("text-anchor", d => d.children ? "end" : "start")
    .attr("paint-order", "stroke")
    .attr("stroke", halo)
    .attr("fill", d => d.data.prob == 1 ? ('red') : ('black'))
    .attr("stroke-width", haloWidth)
    .text((d, i) => L[i]);

  // NOTE(review): only the <svg> is returned, so once the caller attaches it
  // elsewhere the scrollable wrapper built above is left behind — confirm
  // whether returning `parent.node()` (as TextGrid does) was intended.
  body.node().scrollBy(totalWidth, 0);
  return svg.node();
}
/**
 * Renders a grid of candidate words: one 70px-wide column per entry of `data`,
 * one 20px row per word, each cell shaded by its score. Built with D3 v5 (`d3`).
 *
 * @param {Array} data - one entry per column; entry[2] is the list of words and
 *   entry[1] the list of matching scores (used as fill opacity).
 * @param {*} div_name - not used.
 * @param {Object} [options]
 * @param {number} [options.width=640] - minimum SVG width in pixels.
 * @param {number} [options.height] - SVG height; defaults to 10 rows of 20px plus 10.
 * @param {number} [options.r=3] - accepted but currently unused.
 * @param {number} [options.padding=1] - accepted but currently unused.
 * @returns {HTMLDivElement} wrapper containing a horizontally scrollable <svg>.
 */
function TextGrid(data, div_name, {
  width = 640,
  height,
  r = 3,
  padding = 1,
} = {}) {
  // All of these were implicit globals before.
  const topk = 10;
  const wordLength = 20;
  const rectWidth = 60;
  const rectTotal = 70;

  const realWidth = rectTotal * data.length;
  const totalWidth = (realWidth > width) ? realWidth : width;
  const svgHeight = height === undefined ? topk * wordLength + 10 : height;

  const parent = d3.create("div");
  const body = parent.append("div")
    .style("overflow-x", "scroll")
    .style("-webkit-overflow-scrolling", "touch");
  const svg = body.append("svg")
    .attr("width", totalWidth)
    .attr("height", svgHeight)
    .style("display", "block")
    .attr("font-family", "sans-serif")
    .attr("font-size", 10);

  data.forEach((wordsList, column) => {
    const columnX = column * rectTotal;
    const words = wordsList[2];
    const scores = wordsList[1];
    const wordsScore = words.map((word, i) => ({ t: word, p: scores[i] }));

    // selectAll(null) yields an empty selection, so every datum enters and a
    // fresh <g> is appended per word without rebinding earlier columns.
    const cells = svg.selectAll(null)
      .data(wordsScore)
      .join('g');

    cells.append("rect")
      .attr("x", columnX)
      .attr("y", (d, i) => 10 + i * 20)
      .attr('width', rectWidth)
      .attr('height', 15)
      .attr("color", 'gray')
      .attr("fill", "gray")
      .attr("fill-opacity", d => d.p)
      .attr("stroke-opacity", 0.8)
      .append("svg:title")
      .text(d => d.t + ":" + d.p);

    cells.append("text")
      .text(d => d.t)
      .attr("x", columnX)
      .attr("y", (d, i) => 20 + i * 20)
      .attr("font-weight", 700);
  });

  body.node().scrollBy(totalWidth, 0);
  return parent.node();
}
/**
 * Renders an interactive attention view (tokens on the left and right, one
 * colored line per head and token pair) inside the element whose id is
 * `params.root_div_id`. Uses jQuery and D3 v5 (`d3`), whose event listeners
 * receive (datum, index).
 *
 * Fields read from PYTHON_PARAMS:
 *   - attention: `{filterName: {attn, left_text, right_text}}`, with `attn`
 *     indexed as attn[layer][head][leftToken][rightToken];
 *   - default_filter: key of `attention` shown first;
 *   - root_div_id: id of the container holding #layer, #filter and #vis;
 *   - include_layers: layer numbers offered in the #layer select;
 *   - heads (optional): indices of the heads enabled initially (default: all);
 *   - layer (optional): layer selected initially (default: first).
 *
 * @param {Object} PYTHON_PARAMS - parameter object described above.
 */
function attViz(PYTHON_PARAMS) {
  const $ = jQuery;
  const params = PYTHON_PARAMS;
  const TEXT_SIZE = 15;
  const BOXWIDTH = 110;
  const BOXHEIGHT = 22.5;
  const MATRIX_WIDTH = 115;
  const CHECKBOX_SIZE = 20;
  const TEXT_TOP = 30;
  // jQuery event namespace, so handlers from an earlier attViz() call on the
  // same container can be removed without touching anyone else's listeners.
  const EVENT_NS = 'change.attViz';

  console.log("d3 version in ffuntions", d3.version);
  let headColors;
  try {
    headColors = d3.scaleOrdinal(d3.schemeCategory10);
  } catch (err) {
    console.log('Older d3 version');
    headColors = d3.scale.category10();
  }
  const config = {};
  initialize();
  renderVis();

  // Copies the parameters into `config` and wires the layer/filter selects.
  function initialize() {
    console.log("init");
    config.attention = params['attention'];
    config.filter = params['default_filter'];
    config.rootDivId = params['root_div_id'];
    config.nLayers = config.attention[config.filter]['attn'].length;
    config.nHeads = config.attention[config.filter]['attn'][0].length;
    config.layers = params['include_layers'];
    if (params['heads']) {
      config.headVis = new Array(config.nHeads).fill(false);
      params['heads'].forEach(x => config.headVis[x] = true);
    } else {
      config.headVis = new Array(config.nHeads).fill(true);
    }
    config.initialTextLength = config.attention[config.filter].right_text.length;
    const requestedSeq = params['layer'] == null
      ? 0
      : config.layers.findIndex(layer => params['layer'] === layer);
    // A layer missing from include_layers gives -1; fall back to the first one.
    config.layer_seq = Math.max(requestedSeq, 0);
    config.layer = config.layers[config.layer_seq];

    const layerEl = $('#' + config.rootDivId + ' #layer');
    // Drop the handler of any previous call first: otherwise each call stacks
    // another listener that re-renders stale data.
    layerEl.off(EVENT_NS).empty();
    console.log(layerEl);
    for (const layer of config.layers) {
      layerEl.append($("<option />").val(layer).text(layer));
    }
    layerEl.val(config.layer).trigger('change');
    layerEl.on(EVENT_NS, function (e) {
      config.layer = +e.currentTarget.value;
      config.layer_seq = config.layers.findIndex(layer => config.layer === layer);
      renderVis();
    });

    $('#' + config.rootDivId + ' #filter').off(EVENT_NS).on(EVENT_NS, function (e) {
      config.filter = e.currentTarget.value;
      renderVis();
    });
  }

  // Rebuilds the whole SVG for the current filter and layer.
  function renderVis() {
    const attnData = config.attention[config.filter];
    const leftText = attnData.left_text;
    const rightText = attnData.right_text;
    const layerAttention = attnData.attn[config.layer_seq];

    $('#' + config.rootDivId + ' #vis').empty();

    const height = Math.max(leftText.length, rightText.length) * BOXHEIGHT + TEXT_TOP;
    const svg = d3.select('#' + config.rootDivId + ' #vis')
      .append('svg')
      .attr("width", "100%")
      .attr("height", height + "px");

    renderText(svg, leftText, true, layerAttention, 0);
    renderText(svg, rightText, false, layerAttention, MATRIX_WIDTH + BOXWIDTH);
    renderAttention(svg, layerAttention);
    // One colored square per head along the top edge.
    drawCheckboxes(0, svg);
  }

  // Draws one token column plus the (initially invisible) per-head highlight
  // boxes, and installs the hover behavior for it.
  function renderText(svg, text, isLeft, attention, leftPos) {
    const textContainer = svg.append("svg:g")
      .attr("id", isLeft ? "left" : "right");

    // Highlight boxes superimposed over the words, grouped by head.
    textContainer.append("g")
      .classed("attentionBoxes", true)
      .selectAll("g")
      .data(attention)
      .enter()
      .append("g")
      .attr("head-index", (d, i) => i)
      .selectAll("rect")
      // The right column needs right-to-left weights, hence the transpose.
      .data(d => isLeft ? d : transpose(d))
      .enter()
      .append("rect")
      .attr("x", function () {
        const headIndex = +this.parentNode.getAttribute("head-index");
        return leftPos + boxOffsets(headIndex);
      })
      .attr("y", (+1) * BOXHEIGHT)
      .attr("width", BOXWIDTH / activeHeads())
      .attr("height", BOXHEIGHT)
      .attr("fill", function () {
        return headColors(+this.parentNode.getAttribute("head-index"));
      })
      .style("opacity", 0.0);

    const tokenContainer = textContainer.append("g").selectAll("g")
      .data(text)
      .enter()
      .append("g");

    // Gray background shown behind the hovered token.
    tokenContainer.append("rect")
      .classed("background", true)
      .style("opacity", 0.0)
      .attr("fill", "lightgray")
      .attr("x", leftPos)
      .attr("y", (d, i) => TEXT_TOP + i * BOXHEIGHT)
      .attr("width", BOXWIDTH)
      .attr("height", BOXHEIGHT);

    const textEl = tokenContainer.append("text")
      .text(d => d)
      .attr("font-size", TEXT_SIZE + "px")
      .style("cursor", "default")
      .style("-webkit-user-select", "none")
      .attr("x", leftPos)
      .attr("y", (d, i) => TEXT_TOP + i * BOXHEIGHT);

    if (isLeft) {
      textEl.style("text-anchor", "end")
        .attr("dx", BOXWIDTH - 0.5 * TEXT_SIZE)
        .attr("dy", TEXT_SIZE);
    } else {
      textEl.style("text-anchor", "start")
        .attr("dx", +0.5 * TEXT_SIZE)
        .attr("dy", TEXT_SIZE);
    }

    // D3 v5 listener signature: (datum, index).
    tokenContainer.on("mouseover", function (d, index) {
      // Gray background for the hovered token only.
      textContainer.selectAll(".background")
        .style("opacity", (d, i) => i === index ? 1.0 : 0.0);

      // Clear lines highlighted by a previous hover, then hide the whole group
      // and re-show only the lines attached to this token.
      svg.select("#attention")
        .selectAll("line[visibility='visible']")
        .attr("visibility", null);
      svg.select("#attention").attr("visibility", "hidden");
      const indexAttr = isLeft ? "left-token-index" : "right-token-index";
      svg.select("#attention")
        .selectAll("line[" + indexAttr + "='" + index + "']")
        .attr("visibility", "visible");

      // Update the highlight boxes on the opposite column.
      const id = isLeft ? "right" : "left";
      const oppositePos = isLeft ? MATRIX_WIDTH + BOXWIDTH : 0;
      svg.select("#" + id)
        .selectAll(".attentionBoxes")
        .selectAll("g")
        .attr("head-index", (d, i) => i)
        .selectAll("rect")
        .attr("x", function () {
          const headIndex = +this.parentNode.getAttribute("head-index");
          return oppositePos + boxOffsets(headIndex);
        })
        .attr("y", (d, i) => TEXT_TOP + i * BOXHEIGHT)
        .attr("width", BOXWIDTH / activeHeads())
        .attr("height", BOXHEIGHT)
        .style("opacity", function (d) {
          const headIndex = +this.parentNode.getAttribute("head-index");
          return config.headVis[headIndex] && d ? d[index] : 0.0;
        });
    });

    textContainer.on("mouseleave", function () {
      d3.select(this).selectAll(".background")
        .style("opacity", 0.0);
      svg.select("#attention")
        .selectAll("line[visibility='visible']")
        .attr("visibility", null);
      svg.select("#attention").attr("visibility", "visible");
      svg.selectAll(".attentionBoxes")
        .selectAll("g")
        .selectAll("rect")
        .style("opacity", 0.0);
    });
  }

  // Draws one line per (head, left token, right token).
  function renderAttention(svg, attention) {
    svg.select("#attention").remove();

    svg.append("g")
      .attr("id", "attention")
      .selectAll(".headAttention")
      .data(attention)
      .enter()
      .append("g")
      .classed("headAttention", true) // one group per head
      .attr("head-index", (d, i) => i)
      .selectAll(".tokenAttention")
      .data(d => d)
      .enter()
      .append("g")
      .classed("tokenAttention", true) // one group per left token
      .attr("left-token-index", (d, i) => i)
      .selectAll("line")
      .data(d => d)
      .enter()
      .append("line")
      .attr("x1", BOXWIDTH)
      .attr("y1", function () {
        const leftTokenIndex = +this.parentNode.getAttribute("left-token-index");
        return TEXT_TOP + leftTokenIndex * BOXHEIGHT + (BOXHEIGHT / 2);
      })
      .attr("x2", BOXWIDTH + MATRIX_WIDTH)
      .attr("y2", (d, rightTokenIndex) => TEXT_TOP + rightTokenIndex * BOXHEIGHT + (BOXHEIGHT / 2))
      .attr("stroke-width", 2)
      .attr("stroke", function () {
        const headIndex = +this.parentNode.parentNode.getAttribute("head-index");
        return headColors(headIndex);
      })
      .attr("left-token-index", function () {
        return +this.parentNode.getAttribute("left-token-index");
      })
      .attr("right-token-index", (d, i) => i);

    updateAttention(svg);
  }

  // Sets each line's opacity to its weight divided by the number of enabled
  // heads, or 0 when its head is disabled.
  function updateAttention(svg) {
    svg.select("#attention")
      .selectAll("line")
      .attr("stroke-opacity", function (d) {
        const headIndex = +this.parentNode.parentNode.getAttribute("head-index");
        return config.headVis[headIndex] ? d / activeHeads() : 0.0;
      });
  }

  // Horizontal offset of head `i`'s highlight box: enabled heads before it
  // times the width of one box.
  function boxOffsets(i) {
    const numHeadsAbove = config.headVis.slice(0, i).filter(Boolean).length;
    return numHeadsAbove * (BOXWIDTH / activeHeads());
  }

  // Number of heads currently enabled.
  function activeHeads() {
    return config.headVis.filter(Boolean).length;
  }

  // Draws the head toggles: click toggles one head (the last enabled head
  // cannot be switched off); double click isolates a head, or re-enables all
  // heads when it is already the only one.
  function drawCheckboxes(top, svg) {
    const checkboxContainer = svg.append("g");
    const checkbox = checkboxContainer.selectAll("rect")
      .data(config.headVis)
      .enter()
      .append("rect")
      .attr("fill", (d, i) => headColors(i))
      .attr("x", (d, i) => i * CHECKBOX_SIZE)
      .attr("y", top)
      .attr("width", CHECKBOX_SIZE)
      .attr("height", CHECKBOX_SIZE);

    function updateCheckboxes() {
      checkboxContainer.selectAll("rect")
        .data(config.headVis)
        .attr("fill", (d, i) => d ? headColors(i) : lighten(headColors(i)));
    }

    updateCheckboxes();

    checkbox.on("click", function (d, i) {
      if (config.headVis[i] && activeHeads() === 1) return;
      config.headVis[i] = !config.headVis[i];
      updateCheckboxes();
      updateAttention(svg);
    });

    checkbox.on("dblclick", function (d, i) {
      if (config.headVis[i] && activeHeads() === 1) {
        config.headVis = new Array(config.nHeads).fill(true);
      } else {
        config.headVis = new Array(config.nHeads).fill(false);
        config.headVis[i] = true;
      }
      updateCheckboxes();
      updateAttention(svg);
    });
  }

  // Returns a lighter, less saturated variant of `color` for disabled heads.
  function lighten(color) {
    const c = d3.hsl(color);
    const increment = (1 - c.l) * 0.6;
    c.l += increment;
    c.s -= increment;
    return c;
  }

  // Matrix transpose; an empty matrix transposes to an empty matrix instead of
  // throwing on `mat[0]`.
  function transpose(mat) {
    if (mat.length === 0) return [];
    return mat[0].map((col, i) => mat.map(row => row[i]));
  }
}
} | |