@@ -1,10 +1,12 @@
import jax
from jax import random
import jax.numpy as jnp
from train import ModelConfig, TrainConfig
from config_ import ModelConfig, TrainConfig
import optax
from functools import partial
def init_linear_layer(
key: jax.Array,
in_features: int,
The note is not visible to the blocked user.