
How to train a DropoutMLP #88

Answered by danieldjohnson
wlib asked this question in Q&A


One way to adapt the vanilla MLP training test to support a DropoutMLP is to construct the RandomStream inside the loss function using the rng argument:

import jax
import jax.numpy as jnp
from penzai import pz
from penzai.models import simple_mlp  # module path as of penzai v2; adjust if your version differs

mlp = simple_mlp.DropoutMLP.from_config(  # <- using a DropoutMLP instead of a regular MLP
    name="mlp",
    init_base_rng=jax.random.key(0),
    feature_sizes=[2, 32, 32, 2],
    drop_rate=0.2,
)

const_xor_inputs = pz.nx.wrap(
    jnp.array([[-1, -1], [-1, 1], [1, -1], [1, 1]], dtype=jnp.float32),
    "batch",
    "features",
)
const_xor_labels = jnp.array(
    [[0, 1], [1, 0], [1, 0], [0, 1]], dtype=jnp.float32
)

def loss_fn(model, rng, state, xor_inputs, xor_labels):
  # Change starts here!
  # First build a random stream from the rng argument:
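  # (The answer is truncated at this point; the lines below are a sketch of
  # the likely continuation, assuming penzai's pz.RandomStream helper and the
  # same softmax cross-entropy loss as the vanilla MLP training example.)
  random_stream = pz.RandomStream.from_base_key(rng)
  # Pass the stream as a side input so each dropout layer can draw fresh keys:
  model_out = model(xor_inputs, random_stream=random_stream)
  log_probs = jax.nn.log_softmax(
      model_out.unwrap("batch", "features"), axis=-1
  )
  # Average softmax cross-entropy against the one-hot XOR labels:
  loss = -(log_probs * xor_labels).sum(axis=-1).mean()
  return loss, state, {"loss": loss}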

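From there, training proceeds exactly as for the vanilla MLP: each call to the trainer's step method threads a fresh rng into loss_fn, so the dropout masks differ on every step. A minimal sketch of the loop, assuming the StatefulTrainer helper from penzai.toolshed.basic_training and an optax optimizer (the learning rate and step count here are illustrative, not from the original answer):

import optax
from penzai.toolshed import basic_training

trainer = basic_training.StatefulTrainer.build(
    root_rng=jax.random.key(42),
    model=mlp,
    optimizer_def=optax.adam(0.1),
    loss_fn=loss_fn,
)
for _ in range(1000):
  outputs = trainer.step(
      xor_inputs=const_xor_inputs, xor_labels=const_xor_labels
  )
print(outputs["loss"])  # aux outputs returned by loss_fn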