Updated data generation code

This commit is contained in:
2025-09-01 14:46:34 -04:00
parent d998f6de4c
commit e018238935
14 changed files with 709 additions and 123 deletions

View File

@@ -46,7 +46,7 @@ def init_fn(key: jax.Array, config: ModelConfig):
def forward(params: dict, input_timesteps: jax.Array, config: ModelConfig):
"""
Model's forward function. Takes in the parameters and inptu timesteps, returns predictions
Model's forward function. Takes in the parameters and input timesteps, returns predictions
"""
batch_size, num_agents, _ = input_timesteps.shape
@@ -119,6 +119,7 @@ def train_model(config: ModelConfig, inputs: jax.Array, targets: jax.Array,
params, opt_state, loss_val = update_step(params, opt_state, x, y, config)
running_loss += loss_val
epoch_loss = running_loss / num_batches
loss_history[f"epoch_{epoch}"].append(epoch_loss)