|
|
|
import asyncio |
|
import pickle as pk |
|
import time |
|
import warnings |
|
|
|
import matplotlib as mpl |
|
import matplotlib.pyplot as plt |
|
import mpl_toolkits.mplot3d.art3d as art3d |
|
import numpy as np |
|
import torch |
|
from matplotlib import cm |
|
from matplotlib.animation import FuncAnimation |
|
from matplotlib.gridspec import GridSpec |
|
from matplotlib.patches import Circle, PathPatch |
|
from mpl_toolkits.mplot3d import Axes3D, axes3d |
|
from sklearn.decomposition import PCA |
|
|
|
# Silence every UserWarning for the whole process (this mutates the global
# ``warnings`` filter list, so it is not limited to this module).
warnings.filterwarnings(action="ignore", category=UserWarning)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _style_axes(ax):
    """Hide ticks, grid, axis lines and the x/y panes; tint the z-axis pane."""
    ax.xaxis.pane.fill = False
    ax.yaxis.pane.fill = False
    # ``zaxis`` replaces the ``w_zaxis`` alias, which Matplotlib 3.8 removed.
    ax.zaxis.set_pane_color((0.87, 0.91, 0.94, 0.8))
    transparent = (1.0, 1.0, 1.0, 0.0)
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.line.set_color(transparent)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])
    ax.grid(False)


def _word_coordinates(vector_list, score):
    """Return ``(x, y, z)`` coordinate arrays, with the first vector as the anchor at index 0."""
    points = np.asarray(vector_list)
    if len(points) > 1:
        # Every other point is re-placed relative to the first one, at a
        # distance derived from its score (see norm_distance_v).
        anchor = np.repeat(points[0][None, :], len(points) - 1, axis=0)
        return norm_distance_v(anchor, points[1:], score[1:])
    # A single vector is plotted at its own (first three) coordinates.
    return points[:, :3].T


def display_words(words, vector_list, score, bold, output_path="3d_rotation.gif"):
    """Render words as a rotating 3-D scatter plot and save it as a GIF.

    Each word is drawn as a point coloured by its score on the "magma"
    colormap over the fixed range 0-10, with a horizontal colorbar. When more
    than one vector is given, the first vector stays where it is and every
    other point is moved along its direction from the first one to a
    distance of ``10 - score`` (via ``norm_distance_v``). With a single
    vector its raw coordinates are used.

    Side effects: turns Matplotlib interactive mode off, sets
    ``rcParams["image.cmap"]`` to "magma", writes the GIF, and closes the
    figure (also when saving fails).

    Args:
        words: Labels, one per vector.
        vector_list: Sequence of points with at least three coordinates each;
            only the first three are plotted. Presumably PCA-reduced
            embeddings -- TODO confirm against the caller.
        score: One score per word; mapped to colour over the range 0-10.
        bold: Index of the word whose label is drawn larger and heavier.
        output_path: Destination of the GIF; defaults to the previously
            hard-coded "3d_rotation.gif".

    Raises:
        ValueError: If ``vector_list`` is empty.
    """
    if len(vector_list) == 0:
        raise ValueError("display_words needs at least one vector")

    plt.ioff()
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111, projection="3d")
        plt.rcParams["image.cmap"] = "magma"
        # ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9.
        colormap = plt.get_cmap("magma")

        score = np.asarray(score)
        norm = plt.Normalize(0, 10)
        colors = colormap(norm(score))

        _style_axes(ax)

        x, y, z = _word_coordinates(vector_list, score)
        center = np.array([x[0], y[0], z[0]])

        # Half-width of the cubic view is the distance of the farthest finite
        # point from the anchor, so every point stays inside the axes. Deriving
        # it from the score range would mix score and coordinate units, and
        # would collapse to zero when all scores are equal. Fall back to 1 when
        # every point coincides with the anchor.
        offsets = np.linalg.norm(np.column_stack([x, y, z]) - center, axis=1)
        offsets = offsets[np.isfinite(offsets)]
        radius = float(offsets.max()) if offsets.size and offsets.max() > 0 else 1.0

        for i, word in enumerate(words):
            emphasised = i == bold
            ax.text(
                x[i],
                y[i],
                z[i] + 0.05,
                word,
                fontsize="large" if emphasised else "medium",
                fontweight="demibold" if emphasised else "normal",
                alpha=1,
            )

        # A slightly larger black marker drawn first acts as an outline for
        # the coloured marker on top. ``c`` is not numeric in either call, so
        # cmap/vmin/vmax would be ignored and are not passed.
        ax.scatter(x, y, z, c="black", marker="o", s=75)
        ax.scatter(x, y, z, c=colors, marker="o", s=70)

        fig.colorbar(
            cm.ScalarMappable(norm=norm, cmap=colormap),
            ax=ax,
            orientation="horizontal",
        )

        # The limits do not depend on the frame, so set them once; the
        # animation only rotates the camera.
        ax.set_xlim(center[0] - radius, center[0] + radius)
        ax.set_ylim(center[1] - radius, center[1] + radius)
        ax.set_zlim(center[2] - radius, center[2] + radius)

        def update(frame):
            ax.view_init(elev=20, azim=frame)

        animation = FuncAnimation(fig, update, frames=np.arange(0, 360, 5), interval=120)
        animation.save(output_path, writer="pillow", dpi=140)
    finally:
        plt.close(fig)
|
|
|
|
|
|
|
def norm_distance_v(orig, points, distances, max_distance=10):
    """Move each point along its direction from an origin to a value-derived distance.

    Point ``i`` is placed at
    ``origin + unit(points[i] - origin) * (max_distance - distances[i])``,
    so larger ``distances`` values land closer to the origin; values above
    ``max_distance`` flip the point to the opposite side. A point that
    coincides with its origin has no direction and stays on the origin,
    instead of becoming NaN.

    Args:
        orig: The origin, either one point of shape ``(d,)`` or an ``(n, d)``
            array (one origin row per point). In the 2-D form row 0 is the
            origin prepended to the output.
        points: ``(n, d)`` array of points to move.
        distances: ``n`` values, each subtracted from ``max_distance`` to get
            the point's distance from the origin.
        max_distance: Distance given to a point whose value is 0; defaults to
            10, the constant this function used to hard-code.

    Returns:
        Array ``[xs, ys, zs]`` of shape ``(3, n + 1)`` holding the first three
        coordinates: the origin in column 0, then the moved points in input
        order.
    """
    orig = np.asarray(orig, dtype=float)
    points = np.atleast_2d(np.asarray(points, dtype=float))
    distances = np.asarray(distances, dtype=float).reshape(-1, 1)

    offsets = points - orig
    lengths = np.linalg.norm(offsets, axis=1, keepdims=True)
    # Avoid 0/0: a zero offset keeps a zero direction, so the point stays on
    # the origin.
    directions = offsets / np.where(lengths == 0, 1.0, lengths)

    placed = orig + directions * (max_distance - distances)

    # atleast_2d makes both the (d,) and (n, d) origin forms yield one point.
    origin = np.atleast_2d(orig)[0]
    placed = np.vstack([origin, placed])

    return placed[:, :3].T
|
|