109 lines
4.1 KiB
Python
109 lines
4.1 KiB
Python
import torch
|
|
import pandas as pd
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
import sys
|
|
|
|
# Add the path and import the weight functions
|
|
sys.path.append("/home/judson/Neural-Networks-in-GNC/inverted_pendulum")
|
|
from training.time_weighting_functions import weight_functions
|
|
|
|
# Evaluation grid shared by every weighting function: t_points evenly spaced
# samples covering [t_start, t_end].
t_start = 0
t_end = 10
t_points = 1000
t_span = torch.linspace(start=t_start, end=t_end, steps=t_points)

# Spacing between the first two samples (linspace produces an evenly spaced grid).
dt = torch.diff(t_span)[0]
|
|
|
|
def _cumulative_trapezoid(y, x):
    """Return the running trapezoidal integral of ``y`` over ``x``, starting at 0.

    The result has the same length as ``x``; its last element is the full
    trapezoidal integral (equal to ``torch.trapz(y, x)`` up to round-off).
    """
    increments = 0.5 * (y[1:] + y[:-1]) * (x[1:] - x[:-1])
    return torch.cat([increments.new_zeros(1), torch.cumsum(increments, dim=0)])


def _integral_up_to(y, x, cum, x_stop):
    """Return the trapezoidal integral of ``y`` from ``x[0]`` up to ``x_stop``.

    Whole segments left of ``x_stop`` are read from ``cum`` (the output of
    ``_cumulative_trapezoid(y, x)``). The segment containing ``x_stop``
    contributes a partial trapezoid, with ``y`` linearly interpolated at
    ``x_stop``. Assumes ``x`` is 1-D and increasing.
    """
    n_before = int((x <= x_stop).sum())
    if n_before == 0:
        # x_stop lies before the first sample: nothing has been accumulated.
        return cum.new_zeros(())
    if n_before == len(x):
        # x_stop lies at or after the last sample: the whole integral.
        return cum[-1]
    i0, i1 = n_before - 1, n_before
    frac = (x_stop - x[i0]) / (x[i1] - x[i0])
    y_stop = y[i0] + frac * (y[i1] - y[i0])
    return cum[i0] + 0.5 * (y[i0] + y_stop) * (x_stop - x[i0])


def compute_metrics(weight_fn, t_span, t_end, min_val=0.01):
    """Summarise where a time-weighting function puts its weight.

    Args:
        weight_fn: Callable invoked as
            ``weight_fn(t_span, t_max=t_end, min_val=min_val)``. It must return
            a 1-D tensor of weights aligned with ``t_span``. The median and
            ratio logic assume the weights are non-negative.
        t_span: 1-D tensor of increasing sample times. It does not need to be
            uniformly spaced; step sizes are taken from ``t_span`` itself.
        t_end: Forwarded to ``weight_fn`` as ``t_max``. The early/late split
            for ``R`` is at ``t_end / 2``.
        min_val: Forwarded unchanged to ``weight_fn``.

    Returns:
        ``(t_mean, t_median, R)`` as Python floats:
          * ``t_mean``: ``integral(t * w) / integral(w)``.
          * ``t_median``: the time at which the running integral of ``w``
            reaches half of the total, linearly interpolated between samples.
          * ``R``: the integral of ``w`` after ``t_end / 2`` divided by the
            integral before it. ``R`` is ``inf`` when only the late part is
            positive and ``nan`` when neither part is.

    Raises:
        IndexError: if the total weight is nan (no sample can reach half of it).
    """
    weights = weight_fn(t_span, t_max=t_end, min_val=min_val)

    # A single cumulative trapezoidal integral drives the total, the median
    # search and the early/late split, so all three use the same quadrature.
    cum_integral = _cumulative_trapezoid(weights, t_span)
    total_weight = cum_integral[-1]

    # Weighted mean time.
    t_mean = torch.trapz(t_span * weights, t_span) / total_weight

    # Weighted median: the first sample whose running integral reaches half of
    # the total, refined by linear interpolation inside the bracketing segment.
    half_total = total_weight / 2.0
    idx = int((cum_integral >= half_total).nonzero()[0, 0])
    if idx == 0:
        t_median = t_span[0].item()
    else:
        # Here cum_integral[idx - 1] < half_total <= cum_integral[idx], so the
        # denominator below is strictly positive.
        c_lo, c_hi = cum_integral[idx - 1], cum_integral[idx]
        frac = (half_total - c_lo) / (c_hi - c_lo)
        t_median = (t_span[idx - 1] + frac * (t_span[idx] - t_span[idx - 1])).item()

    # Late-to-early ratio R, split at t_end / 2. The segment straddling the
    # split is divided between the two halves rather than dropped.
    early_integral = _integral_up_to(weights, t_span, cum_integral, t_end / 2)
    late_integral = total_weight - early_integral
    # Guard the denominator: an all-early weighting yields 0, not nan.
    if early_integral > 0:
        R = (late_integral / early_integral).item()
    elif late_integral > 0:
        R = float("inf")
    else:
        R = float("nan")

    return t_mean.item(), t_median, R
|
|
|
|
# Display order for the table: "constant" first, then each weighting shape
# immediately followed by its "_mirrored" variant. Every entry is used as a
# key into `weight_functions`.
ordered_functions = [
    "constant",
    "linear",
    "linear_mirrored",
    "quadratic",
    "quadratic_mirrored",
    "cubic",
    "cubic_mirrored",
    "square_root",
    "square_root_mirrored",
    "cubic_root",
    "cubic_root_mirrored",
    "inverse",
    "inverse_mirrored",
    "inverse_squared",
    "inverse_squared_mirrored",
    "inverse_cubed",
    "inverse_cubed_mirrored",
]
|
|
|
|
# Evaluate every weighting function in display order, one table row each.
results = []
for func_name in ordered_functions:
    weight_fn = weight_functions[func_name]
    t_mean, t_median, R = compute_metrics(weight_fn, t_span, t_end, min_val=0.01)
    results.append(
        {
            "Function": func_name,
            "t_mean": t_mean,
            "t_median": t_median,
            "R (Late/Early)": R,
        }
    )

# Tabulate the rows and print them without the positional index column.
df = pd.DataFrame(results)
print(df.to_string(index=False))
|
|
|
|
# ---------------- Sorted Plotting ----------------
# One panel per metric. The rows are sorted by that metric so the scatter reads
# from smallest to largest left-to-right, and the x tick labels name each
# weighting function.
# Each entry: (DataFrame column, panel title, y-axis label, marker colour, log y-axis?)
metric_panels = [
    ("t_mean", r'Weighted Mean Time ($t_{mean}$)', r'$t_{mean}$', 'blue', False),
    ("t_median", r'Weighted Median Time ($t_{median}$)', r'$t_{median}$', 'green', False),
    ("R (Late/Early)", 'Late-to-Early Ratio (R)', 'R (Late/Early)', 'red', True),
]

fig, axes = plt.subplots(len(metric_panels), 1, figsize=(12, 14))

for ax, (column, title, ylabel, color, log_y) in zip(axes, metric_panels):
    df_sorted = df.sort_values(by=column).reset_index(drop=True)
    x_positions = np.arange(len(df_sorted))
    ax.scatter(x_positions, df_sorted[column], color=color, zorder=3)
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    if log_y:
        # R spans orders of magnitude across the weighting functions.
        ax.set_yscale("log")
    ax.grid(True, linestyle='--', alpha=0.5)
    ax.set_xticks(x_positions)
    ax.set_xticklabels(df_sorted['Function'], rotation=45, ha='right')

# Only the bottom panel carries the x-axis label.
axes[-1].set_xlabel('Weighting Function')

plt.tight_layout()
plt.savefig("time_weighting_centroids.png")
|