How to train a DropoutMLP
#88
-
I can train a normal MLP, but how should I train a DropoutMLP?
-
One way to adapt the vanilla MLP training test to support a `DropoutMLP` is to construct the `RandomStream` inside the loss function using the `rng` argument:

```python
# Imports assumed for a self-contained example (exact module paths may vary
# between Penzai versions):
import jax
import jax.numpy as jnp
import optax
from penzai import pz
from penzai.models import simple_mlp
from penzai.toolshed import basic_training

mlp = simple_mlp.DropoutMLP.from_config(  # <- using a DropoutMLP instead of a regular MLP
    name="mlp",
    init_base_rng=jax.random.key(0),
    feature_sizes=[2, 32, 32, 2],
    drop_rate=0.2,
)

const_xor_inputs = pz.nx.wrap(
    jnp.array([[-1, -1], [-1, 1], [1, -1], [1, 1]], dtype=jnp.float32),
    "batch",
    "features",
)
const_xor_labels = jnp.array(
    [[0, 1], [1, 0], [1, 0], [0, 1]], dtype=jnp.float32
)


def loss_fn(model, rng, state, xor_inputs, xor_labels):
  # Change starts here!
  # First build a random stream from the `rng`, then pass it to the model.
  random_stream = pz.RandomStream.from_base_key(rng)
  scale = 1 + jax.random.uniform(random_stream.next_key(), shape=(1,))
  model_out = model(xor_inputs, random_stream=random_stream)
  # (end change)
  log_probs = jax.nn.log_softmax(
      model_out.unwrap("batch", "features"), axis=-1
  )
  losses = -scale * log_probs * xor_labels
  loss = jnp.sum(losses) / 4
  return (loss, state + 1, {"loss": loss, "count": state})


trainer = basic_training.StatefulTrainer.build(
    root_rng=jax.random.key(42),
    model=mlp,
    optimizer_def=optax.adam(0.1),
    loss_fn=loss_fn,
    initial_loss_fn_state=100,
)

outputs = []
for _ in range(100):
  outputs.append(
      trainer.step(xor_inputs=const_xor_inputs, xor_labels=const_xor_labels)
  )
```
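For a quick sanity check afterwards, you can look at the recorded losses. This assumes, as in the vanilla MLP training test, that `trainer.step` returns the auxiliary dictionary produced by `loss_fn`:

```python
# Sketch of a post-training check: `outputs` collects the aux dict
# ({"loss": ..., "count": ...}) returned by loss_fn at every step, so the
# loss after 100 steps should be much smaller than at the start.
print("first loss:", outputs[0]["loss"])
print("final loss:", outputs[-1]["loss"])
assert outputs[-1]["loss"] < outputs[0]["loss"]
```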
It sounds like maybe you were trying to pass a RandomStream through …

Adding this as an additional test is a good suggestion, thanks!