import os
import sys
import json
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from adjustText import adjust_text  # For automatic label adjustment

# ---------------------------------------------------------------------
# Access your custom modules
# ---------------------------------------------------------------------
sys.path.append("/home/judson/Neural-Networks-in-GNC/inverted_pendulum")
from training.time_weighting_functions import weight_functions
from analysis.analysis_conditions import analysis_conditions

# ---------------------------------------------------------------------
# User parameters
# ---------------------------------------------------------------------
condition_base_dir = "/home/judson/Neural-Networks-in-GNC/inverted_pendulum/analysis/time_weighting"
final_epoch = 25                       # The epoch at which to measure final loss
convergence_ref_epoch = 100            # The epoch used to define "final constant loss"
convergence_threshold_percentage = 10  # e.g., 10%
time_start, time_end, time_points = 0, 10, 1000
t_span = torch.linspace(time_start, time_end, time_points)
dt = t_span[1] - t_span[0]
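# Note: torch.linspace includes both endpoints, so the grid spacing is
# dt = (time_end - time_start) / (time_points - 1) = 10 / 999 ≈ 0.010,
# not 10 / 1000.
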
# ---------------------------------------------------------------------
# 1) Compute weighting metrics (t_mean, t_median, R) for each function
# ---------------------------------------------------------------------
def compute_metrics(weight_fn, t_span, t_end, min_val=0.01):
    """
    Returns (t_mean, t_median, R) for a given weighting function.
    Here, R is computed as:
        R = (integral from t_end/2 to t_end) / (integral from 0 to t_end/2),
    so that a smaller R indicates more emphasis on early times.
    """
    weights = weight_fn(t_span, t_max=t_end, min_val=min_val)
    total_weight = torch.trapz(weights, t_span)

    # Weighted mean
    t_mean = torch.trapz(t_span * weights, t_span) / total_weight

    # Weighted median: first time at which the cumulative weight reaches half
    # of the total. Use the local grid spacing rather than the module-level dt
    # so the function also works with other time grids.
    dt_local = t_span[1] - t_span[0]
    cum_integral = torch.cumsum(weights, dim=0) * dt_local
    half_total = total_weight / 2.0
    idx = (cum_integral >= half_total).nonzero()[0, 0]
    t_median = t_span[idx].item()

    # Late-vs-early ratio, so that a lower value means more early weighting
    early_mask = t_span <= (t_end / 2)
    late_mask = t_span > (t_end / 2)
    early_integral = torch.trapz(weights[early_mask], t_span[early_mask])
    late_integral = torch.trapz(weights[late_mask], t_span[late_mask])
    R = (late_integral / early_integral).item() if early_integral > 0 else np.nan

    return t_mean.item(), t_median, R

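# Illustrative only: the entries of weight_functions are defined in
# training.time_weighting_functions (not shown here) and, per the call above,
# are callables of the form fn(t, t_max=..., min_val=...) returning a weight
# tensor over t. A hypothetical linear ramp would look like:
#
#     def linear(t, t_max, min_val=0.01):
#         return min_val + (1.0 - min_val) * (t / t_max)
#
# For such an increasing ramp, compute_metrics should report t_mean and
# t_median above t_end / 2 and R > 1 (late emphasis); a mirrored, decreasing
# ramp should give R < 1.
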
# Ordered weighting functions
ordered_functions = [
    "constant",
    "linear", "linear_mirrored",
    "quadratic", "quadratic_mirrored",
    "cubic", "cubic_mirrored",
    "square_root", "square_root_mirrored",
    "cubic_root", "cubic_root_mirrored",
    "inverse", "inverse_mirrored",
    "inverse_squared", "inverse_squared_mirrored",
    "inverse_cubed", "inverse_cubed_mirrored",
]

metrics_dict = {}
for fn_name in ordered_functions:
    fn = weight_functions[fn_name]
    t_mean, t_median, R = compute_metrics(fn, t_span, time_end, min_val=0.01)
    metrics_dict[fn_name] = {"t_mean": t_mean, "t_median": t_median, "R": R}

# ---------------------------------------------------------------------
# 2) Helper to compute loss using constant weighting
# ---------------------------------------------------------------------
def compute_constant_loss(theta_array, time_array, desired_theta):
    """
    Computes the mean squared error loss: mean((theta - desired_theta)^2).
    time_array is accepted for interface symmetry but is unused by this
    simplified, unweighted loss; adjust if you wish to include weighting.
    """
    theta_tensor = torch.tensor(theta_array, dtype=torch.float32)
    desired_tensor = torch.full_like(theta_tensor, desired_theta)
    loss_val = torch.mean((theta_tensor - desired_tensor) ** 2)
    return loss_val.item()

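# A time-weighted variant, shown for reference only (a sketch, not called
# anywhere below; it assumes the same fn(t, t_max=..., min_val=...) signature
# as the weight_functions entries):
def compute_weighted_loss(theta_array, time_array, desired_theta,
                          weight_fn, t_max, min_val=0.01):
    theta = torch.tensor(theta_array, dtype=torch.float32)
    t = torch.tensor(time_array, dtype=torch.float32)
    w = weight_fn(t, t_max=t_max, min_val=min_val)
    # Normalize by the total weight so times with large w(t) dominate the loss.
    return (torch.sum(w * (theta - desired_theta) ** 2) / torch.sum(w)).item()
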
# ---------------------------------------------------------------------
# 3) Parse each condition folder and gather final loss & convergence data
# ---------------------------------------------------------------------
def get_final_loss_and_convergence(data_dict, desired_theta, final_epoch, ref_epoch):
    """
    data_dict: loaded JSON for a weighting function with keys
        "epochs", "theta_over_epochs" (list of arrays), "time" (array).
    Returns the loss at final_epoch and a placeholder None; ref_epoch is
    currently unused and kept only for interface compatibility.
    """
    epochs = data_dict["epochs"]
    theta_over_epochs = data_dict["theta_over_epochs"]
    time_array = data_dict["time"]

    if final_epoch not in epochs:
        return None, None

    final_idx = epochs.index(final_epoch)
    final_theta = theta_over_epochs[final_idx]
    final_loss = compute_constant_loss(final_theta, time_array, desired_theta)
    return final_loss, None

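# Expected contents of each <condition>/data/<function>.json, inferred from
# the accesses above and below (the writer of these files is not shown here):
#
#     {
#       "epochs": [0, 1, 2, ...],                 # epochs at which theta was logged
#       "theta_over_epochs": [[...], [...], ...], # one theta(t) array per logged epoch
#       "time": [0.0, 0.01, ...]                  # common time grid for those arrays
#     }
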
def find_convergence_epoch(data_dict, desired_theta, threshold):
    """
    Returns the first epoch whose loss (computed with constant MSE) is <= threshold.
    If none is found, returns np.nan.
    """
    epochs = data_dict["epochs"]
    theta_over_epochs = data_dict["theta_over_epochs"]
    time_array = data_dict["time"]

    for ep, thetas in zip(epochs, theta_over_epochs):
        loss_val = compute_constant_loss(thetas, time_array, desired_theta)
        if loss_val <= threshold:
            return ep
    return np.nan

# ---------------------------------------------------------------------
# Helper: Compute best-fit line (with R^2) for scatter plot data.
# ---------------------------------------------------------------------
def safe_compute_best_fit(x, y, log_x=False, log_y=True):
    """
    Computes the best-fit line using linear regression, optionally in log space.
    Filters out NaN and non-positive values if log transforms are used.
    Returns:
        xs: sorted x values for plotting the line,
        y_fit: predicted y values along xs,
        slope, intercept, and R^2 (all in the transformed fitting space).
    If insufficient valid data exist, returns (None, None, NaN, NaN, NaN).
    """
    x = np.array(x, dtype=float)
    y = np.array(y, dtype=float)
    valid_mask = ~np.isnan(x) & ~np.isnan(y)
    if log_x:
        valid_mask &= (x > 0)
    if log_y:
        valid_mask &= (y > 0)
    x = x[valid_mask]
    y = y[valid_mask]
    if len(x) < 2:
        return None, None, np.nan, np.nan, np.nan

    X = np.log(x) if log_x else x
    Y = np.log(y) if log_y else y
    slope, intercept = np.polyfit(X, Y, 1)

    # R^2 is evaluated in the (possibly log-transformed) space the line was fit in.
    Y_pred = slope * X + intercept
    ss_res = np.sum((Y - Y_pred) ** 2)
    ss_tot = np.sum((Y - np.mean(Y)) ** 2)
    R2 = 1 - ss_res / ss_tot if ss_tot != 0 else np.nan

    # Evaluate the fitted line on a dense grid for plotting.
    xs = np.linspace(np.min(x), np.max(x), 100)
    X_line = np.log(xs) if log_x else xs
    Y_line = slope * X_line + intercept
    y_fit = np.exp(Y_line) if log_y else Y_line
    return xs, y_fit, slope, intercept, R2

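# Usage note: with log_y=True (as in the plots below) the regression is linear
# in log space, so the fitted model is y ≈ exp(intercept) * exp(slope * x) and
# the reported R^2 describes the fit to log(y), not to y itself, e.g.:
#
#     xs, y_fit, slope, intercept, R2 = safe_compute_best_fit(
#         df["t_median"], df["final_loss"], log_x=False, log_y=True)
#     ax.plot(xs, y_fit, "k--")
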
# ---------------------------------------------------------------------
# Main loop: Process each condition folder under condition_base_dir
# (e.g., "small_perturbation")
# ---------------------------------------------------------------------
all_subdirs = sorted(os.listdir(condition_base_dir))
for cond_name in all_subdirs:
    cond_path = os.path.join(condition_base_dir, cond_name)
    if not os.path.isdir(cond_path):
        continue
    if cond_name not in analysis_conditions:
        continue

    print(f"\n=== Processing condition: {cond_name} ===")
    desired_theta = analysis_conditions[cond_name][-1]

    data_dir = os.path.join(cond_path, "data")
    if not os.path.isdir(data_dir):
        print(f"No 'data' folder for {cond_name}, skipping.")
        continue

    cond_data = {}
    for fname in os.listdir(data_dir):
        if fname.endswith(".json"):
            loss_fn_name = fname.replace(".json", "")
            with open(os.path.join(data_dir, fname), "r") as f:
                cond_data[loss_fn_name] = json.load(f)

    if "constant" not in cond_data:
        print(f"No constant.json found for {cond_name}, skipping.")
        continue

    const_epochs = cond_data["constant"]["epochs"]
    if convergence_ref_epoch not in const_epochs:
        print(f"Ref epoch {convergence_ref_epoch} not in constant data for {cond_name}, skipping.")
        continue

    ref_idx = const_epochs.index(convergence_ref_epoch)
    ref_theta = cond_data["constant"]["theta_over_epochs"][ref_idx]
    time_array = cond_data["constant"]["time"]
    final_const_loss = compute_constant_loss(ref_theta, time_array, desired_theta)
    threshold = (1 + convergence_threshold_percentage / 100) * final_const_loss
    print(f"Final constant loss at ref epoch ({convergence_ref_epoch}): {final_const_loss:.6f}, threshold = {threshold:.6f}")
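    # Worked example: with convergence_threshold_percentage = 10 and, say, a
    # reference loss of 0.020, the threshold is (1 + 10/100) * 0.020 = 0.022,
    # i.e. "converged" means within 10% of the final constant-weighting loss.
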
    results = []
    for fn_name in ordered_functions:
        if fn_name not in cond_data:
            continue
        data_dict = cond_data[fn_name]
        final_loss, _ = get_final_loss_and_convergence(data_dict, desired_theta, final_epoch, convergence_ref_epoch)
        if final_loss is None:
            continue
        conv_epoch = find_convergence_epoch(data_dict, desired_theta, threshold)
        t_median = metrics_dict[fn_name]["t_median"]
        results.append({
            "Function": fn_name,
            "t_median": t_median,
            "final_loss": final_loss,
            "epochs_to_convergence": conv_epoch,
        })

    df = pd.DataFrame(results)
    if df.empty:
        print(f"No valid data found for {cond_name}.")
        continue
    print(df.to_string(index=False))

    # Create output folder for plots: <condition>/plots/centroid_convergence
    plot_folder = os.path.join(cond_path, "plots", "centroid_convergence")
    os.makedirs(plot_folder, exist_ok=True)

    # ----------------------------
    # Save composite plot using just t_median (vertical layout: 2 subplots)
    # Top: Loss vs. t_median; Bottom: Epochs to Convergence vs. t_median
    # ----------------------------
    fig, axes = plt.subplots(2, 1, figsize=(8, 12))

    # Subplot 1: t_median vs. final_loss
    axes[0].scatter(df["t_median"], df["final_loss"], color="green")
    axes[0].set_xlabel(r"$t_{median}$")
    axes[0].set_ylabel(f"Loss at Epoch {final_epoch}")
    axes[0].set_yscale("log")
    axes[0].set_title(f"{cond_name}: Loss vs. $t_{{median}}$")
    xs, y_fit, slope, intercept, R2 = safe_compute_best_fit(df["t_median"], df["final_loss"], log_x=False, log_y=True)
    if xs is not None:
        axes[0].plot(xs, y_fit, "k--", label=f"Fit: slope={slope:.3f}, int={intercept:.3f}\n$R^2$={R2:.3f}")
        axes[0].legend(fontsize=8)
    texts = []
    for _, row in df.iterrows():
        if np.isfinite(row["t_median"]) and np.isfinite(row["final_loss"]):
            texts.append(axes[0].text(row["t_median"], row["final_loss"], row["Function"], fontsize=8))
    adjust_text(texts, ax=axes[0], arrowprops=dict(arrowstyle="->", color="gray", lw=0.5))

    # Subplot 2: t_median vs. epochs_to_convergence
    axes[1].scatter(df["t_median"], df["epochs_to_convergence"], color="blue")
    axes[1].set_xlabel(r"$t_{median}$")
    axes[1].set_ylabel("Epochs to Convergence")
    axes[1].set_yscale("log")
    axes[1].set_title(f"{cond_name}: Convergence vs. $t_{{median}}$")
    xs, y_fit, slope, intercept, R2 = safe_compute_best_fit(df["t_median"], df["epochs_to_convergence"], log_x=False, log_y=True)
    if xs is not None:
        axes[1].plot(xs, y_fit, "k--", label=f"Fit: slope={slope:.3f}, int={intercept:.3f}\n$R^2$={R2:.3f}")
        axes[1].legend(fontsize=8)
    texts = []
    for _, row in df.iterrows():
        if np.isfinite(row["t_median"]) and np.isfinite(row["epochs_to_convergence"]):
            texts.append(axes[1].text(row["t_median"], row["epochs_to_convergence"], row["Function"], fontsize=8))
    adjust_text(texts, ax=axes[1], arrowprops=dict(arrowstyle="->", color="gray", lw=0.5))

    plt.tight_layout()
    composite_plot_path = os.path.join(plot_folder, "t_median_composite.png")
    plt.savefig(composite_plot_path, dpi=300)
    plt.close(fig)
    print(f"Saved t_median composite plot to {composite_plot_path}")