Updated data generation code
This commit is contained in:
@@ -51,7 +51,7 @@ def prepare_data_for_model(key: jax.Array, trajectories: np.ndarray, train_confi
|
||||
all_targets = all_targets[all_indices]
|
||||
|
||||
if train_config.noise_type != NoiseType.NONE and train_config.noise_level > 0:
|
||||
noise_shape = full_dataset_inputs.shape
|
||||
noise_shape = all_inputs.shape
|
||||
if train_config.noise_type == NoiseType.NORMAL:
|
||||
noise = jax.random.normal(key, noise_shape) * train_config.noise_level
|
||||
elif train_config.noise_type == NoiseType.UNIFORM:
|
||||
@@ -60,18 +60,18 @@ def prepare_data_for_model(key: jax.Array, trajectories: np.ndarray, train_confi
|
||||
minval=-train_config.noise_level,
|
||||
maxval=train_config.noise_level
|
||||
)
|
||||
full_dataset_inputs += np.array(noise) # Add noise to inputs
|
||||
all_inputs += np.array(noise) # Add noise to inputs
|
||||
|
||||
|
||||
full_dataset_inputs = np.expand_dims(all_inputs, axis=-1)
|
||||
all_inputs = np.expand_dims(all_inputs, axis=-1)
|
||||
full_dataset_targets = np.expand_dims(all_targets, axis=-1)
|
||||
|
||||
# Create batches
|
||||
num_samples = full_dataset_inputs.shape[0]
|
||||
num_samples = all_inputs.shape[0]
|
||||
num_batches = num_samples // batch_size
|
||||
|
||||
# Truncate to full batches
|
||||
truncated_inputs = full_dataset_inputs[:num_batches * batch_size]
|
||||
truncated_inputs = all_inputs[:num_batches * batch_size]
|
||||
truncated_targets = full_dataset_targets[:num_batches * batch_size]
|
||||
|
||||
# Reshape into batches
|
||||
@@ -81,6 +81,40 @@ def prepare_data_for_model(key: jax.Array, trajectories: np.ndarray, train_confi
|
||||
|
||||
return batched_inputs, batched_targets
|
||||
|
||||
def f1_score_np(y_true: np.ndarray, y_pred: np.ndarray) -> float:
|
||||
"""
|
||||
Compute the F1 score between two numpy arrays.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
y_true : np.ndarray
|
||||
Ground truth (correct) labels.
|
||||
y_pred : np.ndarray
|
||||
Predicted labels.
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
The F1 score.
|
||||
"""
|
||||
# Ensure binary arrays (0 or 1)
|
||||
y_true = np.asarray(y_true).astype(int)
|
||||
y_pred = np.asarray(y_pred).astype(int)
|
||||
|
||||
# Compute True Positives, False Positives, and False Negatives
|
||||
tp = np.sum((y_true == 1) & (y_pred == 1))
|
||||
fp = np.sum((y_true == 0) & (y_pred == 1))
|
||||
fn = np.sum((y_true == 1) & (y_pred == 0))
|
||||
|
||||
# Precision and Recall
|
||||
precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
|
||||
recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
|
||||
|
||||
# F1 Score
|
||||
if precision + recall == 0:
|
||||
return 0.0
|
||||
return 2 * (precision * recall) / (precision + recall)
|
||||
|
||||
def calculate_f1_score(
|
||||
params: dict,
|
||||
model_config: ModelConfig,
|
||||
@@ -111,7 +145,7 @@ def calculate_f1_score(
|
||||
true_flat = true_graph.flatten()
|
||||
pred_flat = predicted_graph.flatten()
|
||||
|
||||
return f1_score(true_flat, pred_flat)
|
||||
return f1_score_np(true_flat, pred_flat)
|
||||
|
||||
def main():
|
||||
"""Main script to run the training and evaluation pipeline."""
|
||||
@@ -125,7 +159,7 @@ def main():
|
||||
print(f"Please run the data generation script for '{train_config.data_directory}' first.")
|
||||
return
|
||||
|
||||
print(f"🚀 Starting training pipeline for '{train_config.data_directory}' data.")
|
||||
print(f"Starting training pipeline for '{train_config.data_directory}' data.")
|
||||
|
||||
# Get sorted list of agent directories
|
||||
agent_dirs = sorted(
|
||||
@@ -146,7 +180,8 @@ def main():
|
||||
os.makedirs(results_dir, exist_ok=True)
|
||||
|
||||
subdir = str(train_config.noise_type)
|
||||
sub_results_dir = os.path.join(results_dir, subdir)
|
||||
subsubdir = str(train_config.noise_level)
|
||||
sub_results_dir = os.path.join(results_dir, subdir, subsubdir)
|
||||
os.makedirs(sub_results_dir, exist_ok=True)
|
||||
|
||||
print(f"\nProcessing {len(graph_files)} graphs for {agent_dir_name}...")
|
||||
|
||||
Reference in New Issue
Block a user