training and plotting updates

This commit is contained in:
2025-08-04 12:44:35 -04:00
parent 5a0c479c1e
commit d998f6de4c
9 changed files with 457 additions and 71 deletions

View File

@@ -1,10 +1,12 @@
import jax
from jax import random
import jax.numpy as jnp
from train import ModelConfig, TrainConfig
from config_ import ModelConfig, TrainConfig
import optax
from functools import partial
def init_linear_layer(
key: jax.Array,
in_features: int,