# Inverted-Pendulum-Neural-Ne.../analysis/time_weighting/time_weighting_centroid_plotter.py
#
# 109 lines
# 4.1 KiB
# Python

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import sys
# Add the path and import the weight functions
sys.path.append("/home/judson/Neural-Networks-in-GNC/inverted_pendulum")
from training.time_weighting_functions import weight_functions
# Uniform evaluation grid shared by every weighting function below.
t_start = 0
t_end = 10
t_points = 1000
t_span = torch.linspace(t_start, t_end, steps=t_points)
# Spacing between neighbouring samples (nominally constant for linspace).
dt = t_span[1].sub(t_span[0])
def compute_metrics(weight_fn, t_span, t_end, min_val=0.01):
    """Summarize where along the time axis a weighting function puts its mass.

    Args:
        weight_fn: Callable invoked as ``weight_fn(t_span, t_max=t_end, min_val=min_val)``;
            assumed to return one weight per sample of ``t_span``.
        t_span: 1-D tensor of increasing sample times (need not be uniformly spaced).
        t_end: Horizon forwarded to ``weight_fn`` as ``t_max``; the early/late split
            for R is at ``t_end / 2``.
        min_val: Forwarded unchanged to ``weight_fn``.

    Returns:
        Tuple ``(t_mean, t_median, R)`` of Python floats:
          * t_mean   -- weight-averaged time (centroid of the weight curve).
          * t_median -- first sample time at which the cumulative (trapezoidal)
                        integral of the weights reaches half the total.
          * R        -- integral of the weights after ``t_end / 2`` divided by the
                        integral before it; NaN when either part is not positive.
    """
    weights = weight_fn(t_span, t_max=t_end, min_val=min_val)

    # Total weighted area (integral) using the trapezoidal rule.
    total_weight = torch.trapz(weights, t_span)

    # Weighted mean time.
    t_mean = torch.trapz(t_span * weights, t_span) / total_weight

    # Cumulative trapezoidal integral: cum_integral[i] = integral from t_span[0] to t_span[i].
    # It is derived from t_span itself (not a module-level step size) and uses the same rule
    # as total_weight, so its last entry agrees with total_weight.
    segment_areas = 0.5 * (weights[1:] + weights[:-1]) * torch.diff(t_span)
    leading_zero = torch.zeros(1, dtype=segment_areas.dtype, device=segment_areas.device)
    cum_integral = torch.cat((leading_zero, torch.cumsum(segment_areas, dim=0)))

    # Median: first sample where the cumulative integral reaches half the total weight.
    half_total = total_weight / 2.0
    idx = (cum_integral >= half_total).nonzero()[0, 0]
    t_median = t_span[idx].item()

    # Late-to-Early Ratio R, split at t_end / 2. Interpolating the cumulative integral at the
    # split point makes early + late cover the whole span; masking samples on either side
    # would drop the grid interval that straddles the split.
    t_split = t_end / 2
    cum_np = cum_integral.detach().cpu().numpy()
    early_integral = float(np.interp(t_split, t_span.detach().cpu().numpy(), cum_np))
    late_integral = float(cum_np[-1]) - early_integral
    # Guard the divisor too: an empty early half would otherwise divide by zero.
    if late_integral > 0 and early_integral > 0:
        R = late_integral / early_integral
    else:
        R = np.nan
    return t_mean.item(), t_median, R
# Display order: "constant" first, then each base weighting followed directly by its mirror.
ordered_functions = ["constant"] + [
    name
    for pair in (
        ("linear", "linear_mirrored"),
        ("quadratic", "quadratic_mirrored"),
        ("cubic", "cubic_mirrored"),
        ("square_root", "square_root_mirrored"),
        ("cubic_root", "cubic_root_mirrored"),
        ("inverse", "inverse_mirrored"),
        ("inverse_squared", "inverse_squared_mirrored"),
        ("inverse_cubed", "inverse_cubed_mirrored"),
    )
    for name in pair
]
# One table row per weighting function, evaluated in display order.
results = [
    {
        "Function": name,
        "t_mean": mean_time,
        "t_median": median_time,
        "R (Late/Early)": ratio,
    }
    for name in ordered_functions
    for mean_time, median_time, ratio in [
        compute_metrics(weight_functions[name], t_span, t_end, min_val=0.01)
    ]
]
# Tabulate the collected rows (one per weighting function).
df = pd.DataFrame(data=results)
# Print without the positional index so the Function column leads.
summary = df.to_string(index=False)
print(summary)
# ---------------- Sorted Plotting ----------------
def _plot_sorted_metric(ax, table, column, title, ylabel, color, log_scale=False):
    """Scatter ``table[column]`` on ``ax`` in ascending order, labelling ticks by Function."""
    ranked = table.sort_values(by=column).reset_index(drop=True)
    positions = np.arange(len(ranked))
    ax.scatter(positions, ranked[column], color=color, zorder=3)
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    if log_scale:
        ax.set_yscale("log")
    ax.grid(True, linestyle='--', alpha=0.5)
    ax.set_xticks(positions)
    ax.set_xticklabels(ranked['Function'], rotation=45, ha='right')


# For each metric, sort the DataFrame and draw a scatter plot on its own axis.
fig, axes = plt.subplots(3, 1, figsize=(12, 14))
# The mean title needs a closing "$"; with an odd number of "$" matplotlib
# renders the string literally instead of as mathtext.
_plot_sorted_metric(axes[0], df, "t_mean",
                    r'Weighted Mean Time ($t_{mean}$)', r'$t_{mean}$', 'blue')
_plot_sorted_metric(axes[1], df, "t_median",
                    r'Weighted Median Time ($t_{median}$)', r'$t_{median}$', 'green')
_plot_sorted_metric(axes[2], df, "R (Late/Early)",
                    'Late-to-Early Ratio (R)', 'R (Late/Early)', 'red', log_scale=True)
axes[2].set_xlabel('Weighting Function')
plt.tight_layout()
plt.savefig("time_weighting_centroids.png")