Updated data generation code

2025-09-01 14:46:34 -04:00
parent d998f6de4c
commit e018238935
14 changed files with 709 additions and 123 deletions


@@ -51,7 +51,7 @@ def prepare_data_for_model(key: jax.Array, trajectories: np.ndarray, train_confi
     all_targets = all_targets[all_indices]
     if train_config.noise_type != NoiseType.NONE and train_config.noise_level > 0:
-        noise_shape = full_dataset_inputs.shape
+        noise_shape = all_inputs.shape
         if train_config.noise_type == NoiseType.NORMAL:
             noise = jax.random.normal(key, noise_shape) * train_config.noise_level
         elif train_config.noise_type == NoiseType.UNIFORM:
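For context, a minimal self-contained sketch of the noise-injection path these hunks converge on. The NoiseType enum body and the add_noise wrapper are assumptions for illustration; only the member names NONE, NORMAL, and UNIFORM appear in the diff:

from enum import Enum

import jax
import numpy as np

class NoiseType(Enum):
    # Hypothetical enum body; only the member names are shown in the diff.
    NONE = "none"
    NORMAL = "normal"
    UNIFORM = "uniform"

def add_noise(key: jax.Array, all_inputs: np.ndarray,
              noise_type: NoiseType, noise_level: float) -> np.ndarray:
    # Hypothetical standalone wrapper around the logic in the hunk above.
    if noise_type == NoiseType.NONE or noise_level <= 0:
        return all_inputs
    noise_shape = all_inputs.shape
    if noise_type == NoiseType.NORMAL:
        noise = jax.random.normal(key, noise_shape) * noise_level
    else:  # NoiseType.UNIFORM
        noise = jax.random.uniform(key, noise_shape,
                                   minval=-noise_level, maxval=noise_level)
    # Convert the JAX array to numpy before mixing with the numpy inputs.
    return all_inputs + np.array(noise)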
@@ -60,18 +60,18 @@ def prepare_data_for_model(key: jax.Array, trajectories: np.ndarray, train_confi
                 minval=-train_config.noise_level,
                 maxval=train_config.noise_level
             )
-        full_dataset_inputs += np.array(noise)  # Add noise to inputs
+        all_inputs += np.array(noise)  # Add noise to inputs
 
-    full_dataset_inputs = np.expand_dims(all_inputs, axis=-1)
+    all_inputs = np.expand_dims(all_inputs, axis=-1)
     full_dataset_targets = np.expand_dims(all_targets, axis=-1)
 
     # Create batches
-    num_samples = full_dataset_inputs.shape[0]
+    num_samples = all_inputs.shape[0]
     num_batches = num_samples // batch_size
 
     # Truncate to full batches
-    truncated_inputs = full_dataset_inputs[:num_batches * batch_size]
+    truncated_inputs = all_inputs[:num_batches * batch_size]
     truncated_targets = full_dataset_targets[:num_batches * batch_size]
 
     # Reshape into batches
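The truncate-and-reshape step can be exercised in isolation. A sketch assuming the (num_samples, seq_len, 1) input shape produced by the expand_dims calls above; batchify and the exact reshape are hypothetical, since the reshape lines themselves are outside this hunk:

import numpy as np

def batchify(all_inputs, full_dataset_targets, batch_size):
    # Drop the ragged tail, then fold flat samples into full batches.
    num_samples = all_inputs.shape[0]
    num_batches = num_samples // batch_size
    truncated_inputs = all_inputs[:num_batches * batch_size]
    truncated_targets = full_dataset_targets[:num_batches * batch_size]
    batched_inputs = truncated_inputs.reshape(
        num_batches, batch_size, *truncated_inputs.shape[1:])
    batched_targets = truncated_targets.reshape(
        num_batches, batch_size, *truncated_targets.shape[1:])
    return batched_inputs, batched_targets

# 10 samples with batch_size=4 -> 2 full batches, 2 samples dropped.
bi, bt = batchify(np.zeros((10, 5, 1)), np.zeros((10, 5, 1)), batch_size=4)
assert bi.shape == (2, 4, 5, 1)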
@@ -81,6 +81,40 @@ def prepare_data_for_model(key: jax.Array, trajectories: np.ndarray, train_confi
     return batched_inputs, batched_targets
 
+def f1_score_np(y_true: np.ndarray, y_pred: np.ndarray) -> float:
+    """
+    Compute the F1 score between two numpy arrays.
+
+    Parameters
+    ----------
+    y_true : np.ndarray
+        Ground truth (correct) labels.
+    y_pred : np.ndarray
+        Predicted labels.
+
+    Returns
+    -------
+    float
+        The F1 score.
+    """
+    # Ensure binary arrays (0 or 1)
+    y_true = np.asarray(y_true).astype(int)
+    y_pred = np.asarray(y_pred).astype(int)
+
+    # Compute True Positives, False Positives, and False Negatives
+    tp = np.sum((y_true == 1) & (y_pred == 1))
+    fp = np.sum((y_true == 0) & (y_pred == 1))
+    fn = np.sum((y_true == 1) & (y_pred == 0))
+
+    # Precision and Recall
+    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
+    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
+
+    # F1 Score
+    if precision + recall == 0:
+        return 0.0
+    return 2 * (precision * recall) / (precision + recall)
 
 def calculate_f1_score(
     params: dict,
     model_config: ModelConfig,
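The new helper is easy to sanity-check by hand. With the arrays below, tp = 2, fp = 1, fn = 1, so precision = recall = 2/3 and F1 = 2/3:

import numpy as np

y_true = np.array([1, 1, 1, 0, 0])
y_pred = np.array([1, 1, 0, 1, 0])
# 2 * (2/3 * 2/3) / (2/3 + 2/3) = 2/3
print(f1_score_np(y_true, y_pred))  # 0.666...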
@@ -111,7 +145,7 @@ def calculate_f1_score(
     true_flat = true_graph.flatten()
     pred_flat = predicted_graph.flatten()
-    return f1_score(true_flat, pred_flat)
+    return f1_score_np(true_flat, pred_flat)
 
 def main():
     """Main script to run the training and evaluation pipeline."""
@@ -125,7 +159,7 @@ def main():
print(f"Please run the data generation script for '{train_config.data_directory}' first.")
return
print(f"🚀 Starting training pipeline for '{train_config.data_directory}' data.")
print(f"Starting training pipeline for '{train_config.data_directory}' data.")
# Get sorted list of agent directories
agent_dirs = sorted(
@@ -146,7 +180,8 @@ def main():
         os.makedirs(results_dir, exist_ok=True)
         subdir = str(train_config.noise_type)
-        sub_results_dir = os.path.join(results_dir, subdir)
+        subsubdir = str(train_config.noise_level)
+        sub_results_dir = os.path.join(results_dir, subdir, subsubdir)
         os.makedirs(sub_results_dir, exist_ok=True)
 
         print(f"\nProcessing {len(graph_files)} graphs for {agent_dir_name}...")