import tensorflow as tf

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dense, Flatten


def create_text_neural_network(vocab_size, embedding_dim, input_length, num_classes):
    """Build a text-classification expert: embedding -> stacked LSTMs -> dense head."""
    model = Sequential([
        # Map integer token IDs to dense vectors. Note: the `input_length`
        # argument is accepted by tf.keras in TF 2.x but is deprecated in
        # recent Keras releases.
        Embedding(input_dim=vocab_size, output_dim=embedding_dim, input_length=input_length),
        # First LSTM returns the full sequence so the second LSTM can consume it.
        LSTM(128, return_sequences=True),
        # Second LSTM returns only its final hidden state.
        LSTM(128),
        Dense(64, activation='relu'),
        # Softmax over classes; pairs with the sparse (integer-label) loss below.
        Dense(num_classes, activation='softmax')
    ])
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model


def create_gating_network(input_shape, num_experts):
    """Build a gating network that outputs a softmax weighting over the experts."""
    model = Sequential([
        Flatten(input_shape=input_shape),
        Dense(128, activation='relu'),
        # One probability per expert.
        Dense(num_experts, activation='softmax')
    ])
    # Targets here are one-hot expert assignments, hence categorical (not sparse) loss.
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model
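

# A minimal usage sketch (not part of the original listing): it assumes the two
# builders above are intended for a mixture-of-experts arrangement, with several
# text experts weighted by the gating network's softmax output. All sizes below
# are illustrative placeholder values, not figures taken from the source.
if __name__ == "__main__":
    VOCAB_SIZE = 10000      # assumed vocabulary size
    EMBEDDING_DIM = 64      # assumed embedding width
    INPUT_LENGTH = 100      # assumed (padded) sequence length
    NUM_CLASSES = 5         # assumed number of output classes
    NUM_EXPERTS = 3         # assumed number of experts

    # One expert model per expert slot.
    experts = [
        create_text_neural_network(VOCAB_SIZE, EMBEDDING_DIM, INPUT_LENGTH, NUM_CLASSES)
        for _ in range(NUM_EXPERTS)
    ]

    # The gate sees the same token-ID sequence and emits one weight per expert.
    gating_network = create_gating_network(input_shape=(INPUT_LENGTH,), num_experts=NUM_EXPERTS)

    experts[0].summary()
    gating_network.summary()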