# Inverted-Pendulum-Neural-Ne.../analysis/time_weighting/centroid_convergence_plotter.py
import os
import sys
import json
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from adjustText import adjust_text # For automatic label adjustment
# ---------------------------------------------------------------------
# Access your custom modules
# ---------------------------------------------------------------------
sys.path.append("/home/judson/Neural-Networks-in-GNC/inverted_pendulum")
from training.time_weighting_functions import weight_functions
from analysis.analysis_conditions import analysis_conditions
from numpy import pi  # (imported but unused in this script)
# ---------------------------------------------------------------------
# User parameters
# ---------------------------------------------------------------------
condition_base_dir = "/home/judson/Neural-Networks-in-GNC/inverted_pendulum/analysis/time_weighting"
final_epoch = 25 # The epoch at which to measure final loss
convergence_ref_epoch = 100 # The epoch used to define "final constant loss"
convergence_threshold_percentage = 10 # e.g., 10%
time_start, time_end, time_points = 0, 10, 1000
t_span = torch.linspace(time_start, time_end, time_points)
dt = t_span[1] - t_span[0]
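# Note: compute_metrics below relies on this uniform dt to turn a cumulative
# sum of weights into an approximate cumulative integral.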
# ---------------------------------------------------------------------
# 1) Compute weighting metrics (t_mean, t_median, R) for each function
# ---------------------------------------------------------------------
def compute_metrics(weight_fn, t_span, t_end, min_val=0.01):
    """
    Returns (t_mean, t_median, R) for a given weighting function.
    Here, R is computed as:
        R = (integral from t_end/2 to t_end) / (integral from 0 to t_end/2),
    so that a smaller R indicates more emphasis on early times.
    """
    weights = weight_fn(t_span, t_max=t_end, min_val=min_val)
    total_weight = torch.trapz(weights, t_span)
    # Weighted mean of t under the weighting function
    t_mean = torch.trapz(t_span * weights, t_span) / total_weight
    # Weighted median: first t at which the cumulative weight reaches half the
    # total (uses the module-level uniform grid spacing dt)
    cum_integral = torch.cumsum(weights, dim=0) * dt
    half_total = total_weight / 2.0
    idx = (cum_integral >= half_total).nonzero()[0, 0]
    t_median = t_span[idx].item()
    # Late-vs-early weight ratio: a lower value means more early weighting
    early_mask = t_span <= (t_end / 2)
    late_mask = t_span > (t_end / 2)
    early_integral = torch.trapz(weights[early_mask], t_span[early_mask])
    late_integral = torch.trapz(weights[late_mask], t_span[late_mask])
    R = (late_integral / early_integral).item() if early_integral > 0 else np.nan
    return t_mean.item(), t_median, R
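# Sanity check (assuming the weight functions behave as their names suggest):
# "constant" should give t_mean ~ t_median ~ t_end/2 = 5.0 and R ~ 1, while an
# early-weighted function such as "inverse" should give t_median < 5 and R < 1.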
# Ordered weighting functions
ordered_functions = [
    "constant",
    "linear", "linear_mirrored",
    "quadratic", "quadratic_mirrored",
    "cubic", "cubic_mirrored",
    "square_root", "square_root_mirrored",
    "cubic_root", "cubic_root_mirrored",
    "inverse", "inverse_mirrored",
    "inverse_squared", "inverse_squared_mirrored",
    "inverse_cubed", "inverse_cubed_mirrored",
]
metrics_dict = {}
for fn_name in ordered_functions:
    fn = weight_functions[fn_name]
    t_mean, t_median, R = compute_metrics(fn, t_span, time_end, min_val=0.01)
    metrics_dict[fn_name] = {"t_mean": t_mean, "t_median": t_median, "R": R}
# ---------------------------------------------------------------------
# 2) Helper to compute loss using constant weighting
# ---------------------------------------------------------------------
def compute_constant_loss(theta_array, time_array, desired_theta):
    """
    Computes the mean squared error loss: mean((theta - desired_theta)^2).
    (This is a simplified loss; adjust if you wish to include weighting.
    time_array is currently unused but kept for interface symmetry.)
    """
    theta_tensor = torch.tensor(theta_array, dtype=torch.float32)
    desired_tensor = torch.full_like(theta_tensor, desired_theta)
    loss_val = torch.mean((theta_tensor - desired_tensor) ** 2)
    return loss_val.item()
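# Note: every weighting scheme is evaluated with this same unweighted MSE, so
# final losses are directly comparable across schemes; the time weighting only
# affects training, not this evaluation metric.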
# ---------------------------------------------------------------------
# 3) Parse each condition folder and gather final loss & convergence data
# ---------------------------------------------------------------------
def get_final_loss_and_convergence(data_dict, desired_theta, final_epoch, ref_epoch):
    """
    data_dict: loaded JSON for a weighting function with keys:
        "epochs", "theta_over_epochs" (list of arrays), "time" (array)
    Returns the final loss (computed at final_epoch) and a placeholder None.
    (ref_epoch is currently unused; the convergence search is done separately
    by find_convergence_epoch below.)
    """
    epochs = data_dict["epochs"]
    theta_over_epochs = data_dict["theta_over_epochs"]
    time_array = data_dict["time"]
    if final_epoch not in epochs:
        return None, None
    final_idx = epochs.index(final_epoch)
    final_theta = theta_over_epochs[final_idx]
    final_loss = compute_constant_loss(final_theta, time_array, desired_theta)
    return final_loss, None

def find_convergence_epoch(data_dict, desired_theta, threshold):
    """
    Returns the first epoch whose loss (computed with the constant MSE) is <= threshold.
    If none is found, returns np.nan.
    """
    epochs = data_dict["epochs"]
    theta_over_epochs = data_dict["theta_over_epochs"]
    time_array = data_dict["time"]
    for ep, thetas in zip(epochs, theta_over_epochs):
        loss_val = compute_constant_loss(thetas, time_array, desired_theta)
        if loss_val <= threshold:
            return ep
    return np.nan
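# Convergence criterion used below: the first epoch whose constant-MSE loss is
# within convergence_threshold_percentage of the constant run's loss at
# convergence_ref_epoch (e.g., loss <= 1.10 * final_const_loss for 10%).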
# ---------------------------------------------------------------------
# Helper: Compute best-fit line (with R^2) for scatter plot data.
# ---------------------------------------------------------------------
def safe_compute_best_fit(x, y, log_x=False, log_y=True):
    """
    Computes the best-fit line using linear regression.
    Filters out NaN and non-positive values if log transforms are used.
    Returns:
        xs: sorted x values for plotting the line,
        y_fit: predicted y values,
        slope, intercept, and R^2.
    If insufficient valid data exist, returns (None, None, NaN, NaN, NaN).
    """
    x = np.array(x, dtype=float)
    y = np.array(y, dtype=float)
    valid_mask = ~np.isnan(x) & ~np.isnan(y)
    if log_x:
        valid_mask &= (x > 0)
    if log_y:
        valid_mask &= (y > 0)
    x = x[valid_mask]
    y = y[valid_mask]
    if len(x) < 2:
        return None, None, np.nan, np.nan, np.nan
    # Fit in (possibly log-transformed) space; R^2 is computed in that space.
    X = np.log(x) if log_x else x
    Y = np.log(y) if log_y else y
    slope, intercept = np.polyfit(X, Y, 1)
    Y_pred = slope * X + intercept
    ss_res = np.sum((Y - Y_pred) ** 2)
    ss_tot = np.sum((Y - np.mean(Y)) ** 2)
    R2 = 1 - ss_res / ss_tot if ss_tot != 0 else np.nan
    # Evaluate the fitted line on a dense grid for plotting.
    xs = np.linspace(np.min(x), np.max(x), 100)
    X_line = np.log(xs) if log_x else xs
    Y_line = slope * X_line + intercept
    y_fit = np.exp(Y_line) if log_y else Y_line
    return xs, y_fit, slope, intercept, R2
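# With log_y=True and log_x=False (the setting used below), the regression is
# log(y) = slope * x + intercept, i.e., an exponential trend
# y = exp(intercept) * exp(slope * x); R^2 is reported in that transformed space.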
# ---------------------------------------------------------------------
# Main loop: process each condition folder found under condition_base_dir
# (e.g., "small_perturbation")
# ---------------------------------------------------------------------
all_subdirs = sorted(os.listdir(condition_base_dir))
for cond_name in all_subdirs:
    cond_path = os.path.join(condition_base_dir, cond_name)
    if not os.path.isdir(cond_path):
        continue
    if cond_name not in analysis_conditions:
        continue
    print(f"\n=== Processing condition: {cond_name} ===")
    desired_theta = analysis_conditions[cond_name][-1]
    data_dir = os.path.join(cond_path, "data")
    if not os.path.isdir(data_dir):
        print(f"No 'data' folder for {cond_name}, skipping.")
        continue
    cond_data = {}
    for fname in os.listdir(data_dir):
        if fname.endswith(".json"):
            loss_fn_name = fname.replace(".json", "")
            with open(os.path.join(data_dir, fname), "r") as f:
                cond_data[loss_fn_name] = json.load(f)
    if "constant" not in cond_data:
        print(f"No constant.json found for {cond_name}, skipping.")
        continue
    const_epochs = cond_data["constant"]["epochs"]
    if convergence_ref_epoch not in const_epochs:
        print(f"Ref epoch {convergence_ref_epoch} not in constant data for {cond_name}, skipping.")
        continue
    ref_idx = const_epochs.index(convergence_ref_epoch)
    ref_theta = cond_data["constant"]["theta_over_epochs"][ref_idx]
    time_array = cond_data["constant"]["time"]
    final_const_loss = compute_constant_loss(ref_theta, time_array, desired_theta)
    threshold = (1 + convergence_threshold_percentage / 100) * final_const_loss
    print(f"Final constant loss at ref epoch ({convergence_ref_epoch}): {final_const_loss:.6f}, threshold = {threshold:.6f}")
    results = []
    for fn_name in ordered_functions:
        if fn_name not in cond_data:
            continue
        data_dict = cond_data[fn_name]
        final_loss, _ = get_final_loss_and_convergence(data_dict, desired_theta, final_epoch, convergence_ref_epoch)
        if final_loss is None:
            continue
        conv_epoch = find_convergence_epoch(data_dict, desired_theta, threshold)
        t_median = metrics_dict[fn_name]["t_median"]
        results.append({
            "Function": fn_name,
            "t_median": t_median,
            "final_loss": final_loss,
            "epochs_to_convergence": conv_epoch,
        })
    df = pd.DataFrame(results)
    if df.empty:
        print(f"No valid data found for {cond_name}.")
        continue
    print(df.to_string(index=False))
    # Create output folder for plots: <condition>/plots/centroid_convergence
    plot_folder = os.path.join(cond_path, "plots", "centroid_convergence")
    os.makedirs(plot_folder, exist_ok=True)
    # ----------------------------
    # Save composite plot using just t_median (vertical layout: 2 subplots)
    # Top: loss vs. t_median; bottom: epochs to convergence vs. t_median
    # ----------------------------
    fig, axes = plt.subplots(2, 1, figsize=(8, 12))
    # Subplot 1: final loss vs. t_median
    axes[0].scatter(df["t_median"], df["final_loss"], color="green")
    axes[0].set_xlabel(r"$t_{median}$")
    axes[0].set_ylabel(f"Loss at Epoch {final_epoch}")
    axes[0].set_yscale("log")
    axes[0].set_title(f"{cond_name}: Loss vs. $t_{{median}}$")
    xs, y_fit, slope, intercept, R2 = safe_compute_best_fit(df["t_median"], df["final_loss"], log_x=False, log_y=True)
    if xs is not None:
        axes[0].plot(xs, y_fit, "k--", label=f"Fit: slope={slope:.3f}, int={intercept:.3f}\n$R^2$={R2:.3f}")
        axes[0].legend(fontsize=8)
    texts = []
    for _, row in df.iterrows():
        if np.isfinite(row["t_median"]) and np.isfinite(row["final_loss"]):
            txt = axes[0].text(row["t_median"], row["final_loss"], row["Function"], fontsize=8)
            texts.append(txt)
    adjust_text(texts, ax=axes[0], arrowprops=dict(arrowstyle="->", color="gray", lw=0.5))
    # Subplot 2: epochs to convergence vs. t_median
    axes[1].scatter(df["t_median"], df["epochs_to_convergence"], color="blue")
    axes[1].set_xlabel(r"$t_{median}$")
    axes[1].set_ylabel("Epochs to Convergence")
    axes[1].set_yscale("log")
    axes[1].set_title(f"{cond_name}: Convergence vs. $t_{{median}}$")
    xs, y_fit, slope, intercept, R2 = safe_compute_best_fit(df["t_median"], df["epochs_to_convergence"], log_x=False, log_y=True)
    if xs is not None:
        axes[1].plot(xs, y_fit, "k--", label=f"Fit: slope={slope:.3f}, int={intercept:.3f}\n$R^2$={R2:.3f}")
        axes[1].legend(fontsize=8)
    texts = []
    for _, row in df.iterrows():
        if np.isfinite(row["t_median"]) and np.isfinite(row["epochs_to_convergence"]):
            txt = axes[1].text(row["t_median"], row["epochs_to_convergence"], row["Function"], fontsize=8)
            texts.append(txt)
    adjust_text(texts, ax=axes[1], arrowprops=dict(arrowstyle="->", color="gray", lw=0.5))
    plt.tight_layout()
    composite_plot_path = os.path.join(plot_folder, "t_median_composite.png")
    plt.savefig(composite_plot_path, dpi=300)
    plt.close(fig)
    print(f"Saved t_median composite plot to {composite_plot_path}")