|
import type {
|
|
INodeInputSlot,
|
|
INodeOutputSlot,
|
|
LGraphCanvas as TLGraphCanvas,
|
|
LGraphNode as TLGraphNode,
|
|
LLink,
|
|
} from "typings/litegraph.js";
|
|
import type { ComfyNodeConstructor, ComfyObjectInfo } from "typings/comfy.js";
|
|
import { app } from "scripts/app.js";
|
|
import {
|
|
IoDirection,
|
|
addConnectionLayoutSupport,
|
|
addMenuItem,
|
|
matchLocalSlotsToServer,
|
|
replaceNode,
|
|
} from "./utils.js";
|
|
import { RgthreeBaseServerNode } from "./base_node.js";
|
|
import { SERVICE as KEY_EVENT_SERVICE } from "./services/key_events_services.js";
|
|
import { RgthreeBaseServerNodeConstructor } from "typings/rgthree.js";
|
|
import { debounce, wait } from "rgthree/common/shared_utils.js";
|
|
import { removeUnusedInputsFromEnd } from "./utils_inputs_outputs.js";
|
|
import { NodeTypesString } from "./constants.js";
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Slot types for which a plain type match is not enough and slot names (plus a few heuristics)
 * are used to pick the best slot.
 */
const NAME_MATCHED_SLOT_TYPES: readonly unknown[] = ["CONDITIONING", "INT", "STRING", "FLOAT", "COMBO"];

/**
 * Normalizes a slot type for comparison: combo-style types (an array of options, or a string
 * containing a comma) collapse to "COMBO"; every other type is returned unchanged.
 *
 * Only string types are probed with `.includes`, so non-string slot types (e.g. numeric event
 * types) no longer throw.
 */
function normalizeSlotType(type: unknown): unknown {
  if (Array.isArray(type) || (typeof type === "string" && type.includes(","))) {
    return "COMBO";
  }
  return type;
}

/**
 * Upper-cases a slot name and strips the first "OPT_" and the first "_NAME" fragment, so that
 * e.g. "sampler_name" and "sampler" compare equal. A missing name normalizes to "".
 */
function normalizeSlotName(name: string | null | undefined): string {
  return (name || "").toUpperCase().replace("OPT_", "").replace("_NAME", "");
}

/**
 * Finds the slot in `ctxSlots` that best corresponds to `otherSlot` on `otherNode`.
 *
 * For CONDITIONING/INT/STRING/FLOAT/COMBO slots the candidate must have the same (normalized)
 * type and then match by normalized name, by a few name aliases (SEED, STEP_REFINER), or by the
 * POSITIVE/NEGATIVE hint in the other node's type or title. For all other types the first slot
 * whose type is strictly equal wins.
 *
 * @param otherNode the node on the other side of the prospective connection.
 * @param otherSlot the slot on `otherNode` to find a partner for.
 * @param ctxSlots the candidate slots on the context node.
 * @returns the index into `ctxSlots`, or -1 when nothing matches.
 */
function findMatchingIndexByTypeOrName(
  otherNode: TLGraphNode,
  otherSlot: INodeInputSlot | INodeOutputSlot,
  ctxSlots: INodeInputSlot[] | INodeOutputSlot[],
) {
  const slots: Array<INodeInputSlot | INodeOutputSlot> = ctxSlots || [];
  const otherNodeType = (otherNode.type || "").toUpperCase();
  const otherNodeName = (otherNode.title || "").toUpperCase();
  const otherSlotType = normalizeSlotType(otherSlot.type);
  const otherSlotName = normalizeSlotName(otherSlot.name);

  if (!NAME_MATCHED_SLOT_TYPES.includes(otherSlotType)) {
    // Strict equality against the raw candidate type, same as `indexOf`.
    return slots.findIndex((ctxSlot) => ctxSlot.type === otherSlotType);
  }

  // These depend only on the other node, so compute them once rather than per candidate.
  const isPositiveNode = otherNodeType.includes("POSITIVE") || otherNodeName.includes("POSITIVE");
  const isNegativeNode = otherNodeType.includes("NEGATIVE") || otherNodeName.includes("NEGATIVE");

  return slots.findIndex((ctxSlot) => {
    if (normalizeSlotType(ctxSlot.type) !== otherSlotType) {
      return false;
    }
    const ctxSlotName = normalizeSlotName(ctxSlot.name);

    // Direct name match, or a known alias on the other node's slot name.
    if (
      ctxSlotName === otherSlotName ||
      (ctxSlotName === "SEED" && otherSlotName.includes("SEED")) ||
      (ctxSlotName === "STEP_REFINER" && otherSlotName.includes("AT_STEP")) ||
      (ctxSlotName === "STEP_REFINER" && otherSlotName.includes("REFINER_STEP"))
    ) {
      return true;
    }

    // Route prompt/conditioning slots by the POSITIVE/NEGATIVE hint on the other node.
    if (
      isPositiveNode &&
      ((ctxSlotName === "POSITIVE" && otherSlotType === "CONDITIONING") ||
        (ctxSlotName === "TEXT_POS_G" && otherSlotName.includes("TEXT_G")) ||
        (ctxSlotName === "TEXT_POS_L" && otherSlotName.includes("TEXT_L")))
    ) {
      return true;
    }
    if (
      isNegativeNode &&
      ((ctxSlotName === "NEGATIVE" && otherSlotType === "CONDITIONING") ||
        (ctxSlotName === "TEXT_NEG_G" && otherSlotName.includes("TEXT_G")) ||
        (ctxSlotName === "TEXT_NEG_L" && otherSlotName.includes("TEXT_L")))
    ) {
      return true;
    }
    return false;
  });
}
|
|
|
|
|
|
|
|
|
|
/**
 * Shared behavior for all rgthree context nodes: title-fitted collapsed width, "smart"
 * connect-by-type that wires every matching slot, and common registration/override setup.
 */
export class BaseContextNode extends RgthreeBaseServerNode {
  constructor(title: string) {
    super(title);
  }

  /** Backing field for the computed `_collapsed_width`. */
  ___collapsed_width: number = 0;

  /** The collapsed width as last computed by the setter. */
  override get _collapsed_width() {
    return this.___collapsed_width;
  }

  /**
   * Recomputes the collapsed width from the trimmed title measured in the canvas title font.
   * The assigned `width` value is ignored.
   */
  override set _collapsed_width(width: number) {
    const canvas = app.canvas as TLGraphCanvas;
    const ctx = canvas.canvas.getContext("2d")!;
    const oldFont = ctx.font;
    ctx.font = canvas.title_text_font;
    try {
      const title = this.title.trim();
      this.___collapsed_width = 30 + (title ? 10 + ctx.measureText(title).width : 0);
    } finally {
      // Always restore the shared context's font, even if measuring throws.
      ctx.font = oldFont;
    }
  }

  /**
   * Connects this node's output `slot` to `sourceNode` by type. It uses the parent
   * implementation when one exists, otherwise `LGraphNode.prototype.connectByType`.
   *
   * If that connects nothing and `slot` is 0, every input of `sourceNode` is connected to the
   * best-matching output of this node instead. Inputs that already have a link are skipped
   * unless Ctrl is held.
   *
   * @returns the link created by the by-type connection, or null (including when the
   *   slot-by-slot fallback ran).
   */
  override connectByType<T = any>(
    slot: string | number,
    sourceNode: TLGraphNode,
    sourceSlotType: string,
    optsIn: string,
  ): T | null {
    let canConnect =
      super.connectByType &&
      super.connectByType.call(this, slot, sourceNode, sourceSlotType, optsIn);
    if (!super.connectByType) {
      canConnect = LGraphNode.prototype.connectByType.call(
        this,
        slot,
        sourceNode,
        sourceSlotType,
        optsIn,
      );
    }
    if (!canConnect && slot === 0) {
      const ctrlKey = KEY_EVENT_SERVICE.ctrlKey;
      for (const [index, input] of (sourceNode.inputs || []).entries()) {
        if (input.link && !ctrlKey) {
          continue;
        }
        const thisOutputSlot = findMatchingIndexByTypeOrName(sourceNode, input, this.outputs);
        if (thisOutputSlot > -1) {
          this.connect(thisOutputSlot, sourceNode, index);
        }
      }
    }
    return canConnect || null;
  }

  /**
   * Mirror of `connectByType` for the other direction: connects `sourceNode`'s outputs into this
   * node's input `slot`.
   *
   * If that connects nothing and `slot` is 0, every output of `sourceNode` is connected to the
   * best-matching input of this node instead. Outputs that already have links are skipped
   * unless Ctrl is held.
   *
   * @returns the link created by the by-type connection, or null (including when the
   *   slot-by-slot fallback ran).
   */
  override connectByTypeOutput<T = any>(
    slot: string | number,
    sourceNode: TLGraphNode,
    sourceSlotType: string,
    optsIn: string,
  ): T | null {
    let canConnect =
      super.connectByTypeOutput &&
      super.connectByTypeOutput.call(this, slot, sourceNode, sourceSlotType, optsIn);
    // Fall back to the stock implementation only when the parent lacks *this* method.
    if (!super.connectByTypeOutput) {
      canConnect = LGraphNode.prototype.connectByTypeOutput.call(
        this,
        slot,
        sourceNode,
        sourceSlotType,
        optsIn,
      );
    }
    if (!canConnect && slot === 0) {
      const ctrlKey = KEY_EVENT_SERVICE.ctrlKey;
      for (const [index, output] of (sourceNode.outputs || []).entries()) {
        if (output.links?.length && !ctrlKey) {
          continue;
        }
        const thisInputSlot = findMatchingIndexByTypeOrName(sourceNode, output, this.inputs);
        if (thisInputSlot > -1) {
          sourceNode.connect(index, this, thisInputSlot);
        }
      }
    }
    return canConnect || null;
  }

  /**
   * Registers `ctxClass` as the override for `comfyClass`. It then adds the comfy class to
   * LiteGraph's `slot_types_default_out["RGTHREE_CONTEXT"]` list, at most once.
   */
  static override setUp(
    comfyClass: ComfyNodeConstructor,
    nodeData: ComfyObjectInfo,
    ctxClass: RgthreeBaseServerNodeConstructor,
  ) {
    RgthreeBaseServerNode.registerForOverride(comfyClass, nodeData, ctxClass);
    // NOTE(review): deferred by 500ms, presumably so LiteGraph's defaults exist first — confirm.
    wait(500).then(() => {
      const defaults = (LiteGraph.slot_types_default_out["RGTHREE_CONTEXT"] =
        LiteGraph.slot_types_default_out["RGTHREE_CONTEXT"] || []);
      if (!defaults.includes(comfyClass.comfyClass)) {
        defaults.push(comfyClass.comfyClass);
      }
    });
  }

  /**
   * Adds Left/Right connection-layout support to `ctxClass`. On the next tick, it copies the
   * comfy class's category onto it.
   */
  static override onRegisteredForOverride(comfyClass: any, ctxClass: any) {
    addConnectionLayoutSupport(ctxClass, app, [
      ["Left", "Right"],
      ["Right", "Left"],
    ]);
    setTimeout(() => {
      ctxClass.category = comfyClass.category;
    });
  }
}
|
|
|
|
|
|
|
|
|
|
/**
 * The original "Context" node. Its context menu offers a conversion to the Context Big node.
 */
class ContextNode extends BaseContextNode {
  static override title = NodeTypesString.CONTEXT;
  static override type = NodeTypesString.CONTEXT;
  static comfyClass = NodeTypesString.CONTEXT;

  /** @param title display title; defaults to the class title. */
  constructor(title: string = ContextNode.title) {
    super(title);
  }

  /** Hands registration off to the shared context-node setup, with this class as the override. */
  static override setUp(comfyClass: ComfyNodeConstructor, nodeData: ComfyObjectInfo) {
    BaseContextNode.setUp(comfyClass, nodeData, ContextNode);
  }

  /** Applies the shared override setup, then registers the "Convert To Context Big" entry. */
  static override onRegisteredForOverride(comfyClass: any, ctxClass: any) {
    BaseContextNode.onRegisteredForOverride(comfyClass, ctxClass);
    addMenuItem(ContextNode, app, {
      name: "Convert To Context Big",
      callback(node) {
        replaceNode(node, ContextBigNode.type);
      },
    });
  }
}
|
|
|
|
|
|
|
|
|
|
/**
 * The "Context Big" node. Its context menu offers a conversion back to the original Context node.
 */
class ContextBigNode extends BaseContextNode {
  static override title = NodeTypesString.CONTEXT_BIG;
  static override type = NodeTypesString.CONTEXT_BIG;
  static comfyClass = NodeTypesString.CONTEXT_BIG;

  /** @param title display title; defaults to the class title. */
  constructor(title: string = ContextBigNode.title) {
    super(title);
  }

  /** Hands registration off to the shared context-node setup, with this class as the override. */
  static override setUp(comfyClass: ComfyNodeConstructor, nodeData: ComfyObjectInfo) {
    BaseContextNode.setUp(comfyClass, nodeData, ContextBigNode);
  }

  /** Applies the shared override setup, then registers the "Convert To Context (Original)" entry. */
  static override onRegisteredForOverride(comfyClass: any, ctxClass: any) {
    BaseContextNode.onRegisteredForOverride(comfyClass, ctxClass);
    addMenuItem(ContextBigNode, app, {
      name: "Convert To Context (Original)",
      callback(node) {
        replaceNode(node, ContextNode.type);
      },
    });
  }
}
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Base for context nodes that take a variable number of `ctx_NN` context inputs (Switch/Merge).
 * It starts with five inputs. After any input connection change, unused trailing inputs are
 * trimmed and one fresh input is appended.
 */
class BaseContextMultiCtxInputNode extends BaseContextNode {
  /**
   * `stabilize` bound once. Presumably the shared `debounce` helper keys on function identity,
   * so the same reference must be passed on every call — TODO confirm against shared_utils.
   */
  private stabilizeBound = this.stabilize.bind(this);

  constructor(title: string) {
    super(title);
    this.addContextInput(5);
  }

  /**
   * Appends `num` RGTHREE_CONTEXT inputs. Each is named `ctx_NN`, where NN is the zero-padded
   * 1-based position it lands at.
   */
  private addContextInput(num = 1) {
    for (let i = 0; i < num; i++) {
      const position = (this.inputs?.length ?? 0) + 1;
      this.addInput(`ctx_${String(position).padStart(2, "0")}`, "RGTHREE_CONTEXT");
    }
  }

  /** Forwards to the parent handler, then schedules a stabilize pass on any input change. */
  override onConnectionsChange(
    type: number,
    slotIndex: number,
    isConnected: boolean,
    link: LLink,
    ioSlot: INodeInputSlot | INodeOutputSlot,
  ): void {
    super.onConnectionsChange?.apply(this, [...arguments] as any);
    if (type === LiteGraph.INPUT) {
      this.scheduleStabilize();
    }
  }

  /** Debounces `stabilize` by `ms` milliseconds (previously `ms` was ignored and 64 was used). */
  private scheduleStabilize(ms = 64) {
    return debounce(this.stabilizeBound, ms);
  }

  /**
   * Removes unused inputs from the end, then adds one new context input. The helper is passed
   * 4, presumably a minimum input count to keep — confirm against `removeUnusedInputsFromEnd`.
   */
  private stabilize() {
    removeUnusedInputsFromEnd(this, 4);
    this.addContextInput();
  }
}
|
|
|
|
|
|
|
|
|
|
/**
 * The "Context Switch" node. Its context menu offers a conversion to the Context Switch Big node.
 */
class ContextSwitchNode extends BaseContextMultiCtxInputNode {
  static override title = NodeTypesString.CONTEXT_SWITCH;
  static override type = NodeTypesString.CONTEXT_SWITCH;
  static comfyClass = NodeTypesString.CONTEXT_SWITCH;

  /** @param title display title; defaults to the class title. */
  constructor(title: string = ContextSwitchNode.title) {
    super(title);
  }

  /** Hands registration off to the shared context-node setup, with this class as the override. */
  static override setUp(comfyClass: ComfyNodeConstructor, nodeData: ComfyObjectInfo) {
    BaseContextNode.setUp(comfyClass, nodeData, ContextSwitchNode);
  }

  /** Applies the shared override setup, then registers the "Convert To Context Switch Big" entry. */
  static override onRegisteredForOverride(comfyClass: any, ctxClass: any) {
    BaseContextNode.onRegisteredForOverride(comfyClass, ctxClass);
    addMenuItem(ContextSwitchNode, app, {
      name: "Convert To Context Switch Big",
      callback(node) {
        replaceNode(node, ContextSwitchBigNode.type);
      },
    });
  }
}
|
|
|
|
|
|
|
|
|
|
/**
 * The "Context Switch Big" node. Its context menu offers a conversion back to Context Switch.
 */
class ContextSwitchBigNode extends BaseContextMultiCtxInputNode {
  static override title = NodeTypesString.CONTEXT_SWITCH_BIG;
  static override type = NodeTypesString.CONTEXT_SWITCH_BIG;
  static comfyClass = NodeTypesString.CONTEXT_SWITCH_BIG;

  /** @param title display title; defaults to the class title. */
  constructor(title: string = ContextSwitchBigNode.title) {
    super(title);
  }

  /** Hands registration off to the shared context-node setup, with this class as the override. */
  static override setUp(comfyClass: ComfyNodeConstructor, nodeData: ComfyObjectInfo) {
    BaseContextNode.setUp(comfyClass, nodeData, ContextSwitchBigNode);
  }

  /** Applies the shared override setup, then registers the "Convert To Context Switch" entry. */
  static override onRegisteredForOverride(comfyClass: any, ctxClass: any) {
    BaseContextNode.onRegisteredForOverride(comfyClass, ctxClass);
    addMenuItem(ContextSwitchBigNode, app, {
      name: "Convert To Context Switch",
      callback(node) {
        replaceNode(node, ContextSwitchNode.type);
      },
    });
  }
}
|
|
|
|
|
|
|
|
|
|
/**
 * The "Context Merge" node. Its context menu offers a conversion to the Context Merge Big node.
 */
class ContextMergeNode extends BaseContextMultiCtxInputNode {
  static override title = NodeTypesString.CONTEXT_MERGE;
  static override type = NodeTypesString.CONTEXT_MERGE;
  static comfyClass = NodeTypesString.CONTEXT_MERGE;

  /** @param title display title; defaults to the class title. */
  constructor(title: string = ContextMergeNode.title) {
    super(title);
  }

  /** Hands registration off to the shared context-node setup, with this class as the override. */
  static override setUp(comfyClass: ComfyNodeConstructor, nodeData: ComfyObjectInfo) {
    BaseContextNode.setUp(comfyClass, nodeData, ContextMergeNode);
  }

  /** Applies the shared override setup, then registers the "Convert To Context Merge Big" entry. */
  static override onRegisteredForOverride(comfyClass: any, ctxClass: any) {
    BaseContextNode.onRegisteredForOverride(comfyClass, ctxClass);
    addMenuItem(ContextMergeNode, app, {
      name: "Convert To Context Merge Big",
      callback(node) {
        replaceNode(node, ContextMergeBigNode.type);
      },
    });
  }
}
|
|
|
|
|
|
|
|
|
|
/**
 * The "Context Merge Big" node. Its context menu offers a conversion back to Context Merge.
 */
class ContextMergeBigNode extends BaseContextMultiCtxInputNode {
  static override title = NodeTypesString.CONTEXT_MERGE_BIG;
  static override type = NodeTypesString.CONTEXT_MERGE_BIG;
  static comfyClass = NodeTypesString.CONTEXT_MERGE_BIG;

  /** @param title display title; defaults to the class title. */
  constructor(title = ContextMergeBigNode.title) {
    super(title);
  }

  /** Hands registration off to the shared context-node setup, with this class as the override. */
  static override setUp(comfyClass: ComfyNodeConstructor, nodeData: ComfyObjectInfo) {
    BaseContextNode.setUp(comfyClass, nodeData, ContextMergeBigNode);
  }

  /**
   * Applies the shared override setup, then registers the menu entry that converts this node to
   * a Context Merge node. The label previously said "Context Switch", which is not what the
   * callback does.
   */
  static override onRegisteredForOverride(comfyClass: any, ctxClass: any) {
    BaseContextNode.onRegisteredForOverride(comfyClass, ctxClass);
    addMenuItem(ContextMergeBigNode, app, {
      name: "Convert To Context Merge",
      callback: (node) => {
        replaceNode(node, ContextMergeNode.type);
      },
    });
  }
}
|
|
|
|
/** Every context node class handled by this extension, matched by `type` against server nodes. */
const contextNodes = [
  ContextNode,
  ContextBigNode,
  ContextSwitchNode,
  ContextSwitchBigNode,
  ContextMergeNode,
  ContextMergeBigNode,
];

/** Server-side node definitions recorded in `beforeRegisterNodeDef`, keyed by node type. */
const contextTypeToServerDef: Record<string, ComfyObjectInfo> = {};
|
|
|
|
/**
 * Repairs known-bad slot names on a node in place. Currently it renames any "CLIP_HEIGTH"
 * output (misspelled; presumably persisted by an older version) to "CLIP_HEIGHT".
 *
 * @param node the context node to repair; a node without outputs is left untouched.
 */
function fixBadConfigs(node: ContextNode) {
  for (const output of node.outputs || []) {
    if (output.name === "CLIP_HEIGTH") {
      output.name = "CLIP_HEIGHT";
    }
  }
}
|
|
|
|
/**
 * Aligns a context node with the server definition recorded for its type, if any.
 * - Known bad output names are fixed first.
 * - Outputs are always matched to the server definition.
 * - Inputs are matched only when the type name contains neither "Switch" nor "Merge".
 *   Presumably this is because those nodes manage their own dynamic `ctx_NN` inputs.
 */
function syncContextNodeWithServer(node: TLGraphNode) {
  const type: string | undefined = node.type || (node.constructor as any).type;
  if (!type) {
    return;
  }
  const serverDef = contextTypeToServerDef[type];
  if (!serverDef) {
    return;
  }
  fixBadConfigs(node as ContextNode);
  matchLocalSlotsToServer(node, IoDirection.OUTPUT, serverDef);
  if (!type.includes("Switch") && !type.includes("Merge")) {
    matchLocalSlotsToServer(node, IoDirection.INPUT, serverDef);
  }
}

app.registerExtension({
  name: "rgthree.Context",

  /** Records the server definition for each context node type and sets up its override class. */
  async beforeRegisterNodeDef(nodeType: ComfyNodeConstructor, nodeData: ComfyObjectInfo) {
    const ctxClass = contextNodes.find((candidate) => candidate.type === nodeData.name);
    if (ctxClass) {
      contextTypeToServerDef[ctxClass.type] = nodeData;
      ctxClass.setUp(nodeType, nodeData);
    }
  },

  /** Syncs freshly created context nodes with their server definition. */
  async nodeCreated(node: TLGraphNode) {
    syncContextNodeWithServer(node);
  },

  /** Syncs context nodes loaded from a saved graph with their server definition. */
  async loadedGraphNode(node: TLGraphNode) {
    syncContextNodeWithServer(node);
  },
});
|
|
|