# Inverted-Pendulum-Neural-Ne.../analysis/base_loss/centroid_convergence_plotter.py

import os
import sys
import json
import math
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from adjustText import adjust_text
# ---------------------------------------------------------------------
# Access your custom modules
# ---------------------------------------------------------------------
sys.path.append("/home/judson/Neural-Networks-in-GNC/inverted_pendulum")
from training.base_loss_functions import base_loss_functions
from analysis.analysis_conditions import analysis_conditions
# ---------------------------------------------------------------------
# User parameters
# ---------------------------------------------------------------------
condition_base_dir = "/home/judson/Neural-Networks-in-GNC/inverted_pendulum/analysis/base_loss"
final_epoch = 25                      # epoch at which the final loss is measured
convergence_ref_epoch = 100           # epoch whose 'one' loss defines the converged reference loss
convergence_threshold_percentage = 10 # e.g. 10% above the reference loss counts as converged
# ---------------------------------------------------------------------
# Helper: replicate the base loss for a single epoch
# ---------------------------------------------------------------------
def replicate_base_loss(theta_array, desired_theta, base_key):
    """
    Given:
      - theta_array: array of theta values for one epoch
      - desired_theta: target angle (scalar)
      - base_key: key in base_loss_functions (e.g., "one", "one_fifth", etc.)
    Computes the mean of the base loss function over all theta values.
    """
    theta_tensor = torch.tensor(theta_array, dtype=torch.float32)
    loss_fn = base_loss_functions[base_key][1]
    loss_val = torch.mean(
        loss_fn(theta_tensor, torch.tensor(desired_theta, dtype=torch.float32), min_val=0.01)
    )
    return loss_val.item()
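# A minimal usage sketch (hypothetical values). This assumes each entry of
# base_loss_functions is an (exponent, callable) tuple, as indexed above:
#
#     thetas = [0.10, 0.05, -0.02]                    # one epoch's theta trace
#     loss = replicate_base_loss(thetas, 0.0, "one")  # mean "one" base loss vs. target 0.0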
# ---------------------------------------------------------------------
# Helper: safe regression function (handles log transforms, filters invalid data)
# ---------------------------------------------------------------------
def safe_compute_best_fit(x, y, log_x=False, log_y=False):
    """
    Computes a best-fit line using linear regression, optionally on log(x) and/or log(y).
    Returns:
      xs, y_fit, slope, intercept, R2
    If there is insufficient valid data, returns (None, None, np.nan, np.nan, np.nan).
    """
    x = np.array(x, dtype=float)
    y = np.array(y, dtype=float)
    # Filter out NaNs (and non-positive values when a log transform is requested)
    valid_mask = ~np.isnan(x) & ~np.isnan(y)
    if log_x:
        valid_mask &= (x > 0)
    if log_y:
        valid_mask &= (y > 0)
    x = x[valid_mask]
    y = y[valid_mask]
    if len(x) < 2:
        return None, None, np.nan, np.nan, np.nan
    X = np.log(x) if log_x else x
    Y = np.log(y) if log_y else y
    slope, intercept = np.polyfit(X, Y, 1)
    Y_pred = slope * X + intercept
    # R^2 is computed in the (possibly log-transformed) fitting space
    ss_res = np.sum((Y - Y_pred) ** 2)
    ss_tot = np.sum((Y - np.mean(Y)) ** 2)
    R2 = 1 - ss_res / ss_tot if ss_tot != 0 else np.nan
    # Build the fitted line for plotting, mapped back to the original y-scale if log_y
    xs = np.linspace(np.min(x), np.max(x), 100)
    X_line = np.log(xs) if log_x else xs
    Y_line = slope * X_line + intercept
    y_fit = np.exp(Y_line) if log_y else Y_line
    return xs, y_fit, slope, intercept, R2
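# A minimal sketch of the helper on synthetic (hypothetical) data:
#
#     xs, y_fit, slope, intercept, R2 = safe_compute_best_fit(
#         [1, 2, 3, 4], [2.1, 3.9, 6.2, 7.8]
#     )
#     # slope ≈ 1.94 and R² close to 1 for this nearly linear data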
# ---------------------------------------------------------------------
# Main function to process a single condition
# ---------------------------------------------------------------------
def process_condition(condition_name):
    """
    Loads data for each base loss function (e.g., one_fifth.json, three.json) from
    the condition's 'data' folder, computes the final loss at final_epoch, and finds
    the convergence epoch using a threshold based on the 'one' base function. Then
    plots exponent vs. final_loss and exponent vs. epochs_to_convergence.
    """
    # Condition folder
    cond_path = os.path.join(condition_base_dir, condition_name)
    data_dir = os.path.join(cond_path, "data")
    if not os.path.isdir(data_dir):
        print(f"No 'data' folder for {condition_name}, skipping.")
        return
    if condition_name not in analysis_conditions:
        print(f"Condition '{condition_name}' not in analysis_conditions, skipping.")
        return
    desired_theta = analysis_conditions[condition_name][-1]
    # Load JSON data
    cond_data = {}
    for fname in os.listdir(data_dir):
        if fname.endswith(".json"):
            base_key = fname.replace(".json", "")  # e.g. "one_fifth"
            with open(os.path.join(data_dir, fname), "r") as f:
                cond_data[base_key] = json.load(f)
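    # Each JSON file is expected to have the following shape (inferred from the
    # accesses below; "theta_over_epochs" holds one list of theta values per epoch):
    #
    #     {"epochs": [0, 5, 10, ...],
    #      "theta_over_epochs": [[...], [...], ...]}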
    # Make sure the reference function ("one") exists
    if "one" not in cond_data:
        print(f"No 'one.json' found for {condition_name}, skipping.")
        return
    # Compute the 'one' loss at convergence_ref_epoch; this is the reference
    # value from which the convergence threshold is derived.
    ref_epochs = cond_data["one"]["epochs"]
    if convergence_ref_epoch not in ref_epochs:
        print(f"Ref epoch {convergence_ref_epoch} not in 'one' data, skipping.")
        return
    ref_idx = ref_epochs.index(convergence_ref_epoch)
    ref_theta = cond_data["one"]["theta_over_epochs"][ref_idx]
    # Replicate the base loss for "one"
    final_one_loss = replicate_base_loss(ref_theta, desired_theta, "one")
    # e.g. with a 10% threshold, threshold = 1.10 * final_one_loss
    threshold = (1 + convergence_threshold_percentage / 100) * final_one_loss
    print(f"\n=== Condition: {condition_name} ===")
    print(f"Final 'one' loss at epoch {convergence_ref_epoch} = {final_one_loss:.6f}, threshold = {threshold:.6f}")
    # We'll store results in a DataFrame
    results = []
    # The order of the base_loss_functions dictionary isn't guaranteed, so sort
    # the keys by exponent
    sorted_keys = sorted(base_loss_functions.keys(), key=lambda k: base_loss_functions[k][0])
    # Helper: find the first epoch where the base loss is <= threshold
    def find_convergence_epoch(base_key, data_dict, desired_theta, threshold):
        epochs = data_dict["epochs"]
        theta_over_epochs = data_dict["theta_over_epochs"]
        for ep, thetas in zip(epochs, theta_over_epochs):
            val = replicate_base_loss(thetas, desired_theta, base_key)
            if val <= threshold:
                return ep
        return np.nan
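    # Worked example (hypothetical numbers): with threshold = 0.055 and per-epoch
    # losses [0.20, 0.09, 0.05, 0.04] at epochs [0, 5, 10, 15], the first loss at
    # or below the threshold is 0.05, so the helper returns epoch 10.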
    for base_key in sorted_keys:
        if base_key not in cond_data:
            # No JSON file found for this base function
            continue
        data_dict = cond_data[base_key]
        # Skip functions whose data doesn't include final_epoch
        if final_epoch not in data_dict["epochs"]:
            continue
        final_idx = data_dict["epochs"].index(final_epoch)
        final_theta = data_dict["theta_over_epochs"][final_idx]
        final_loss_val = replicate_base_loss(final_theta, desired_theta, base_key)
        conv_epoch = find_convergence_epoch(base_key, data_dict, desired_theta, threshold)
        exponent_val = base_loss_functions[base_key][0]  # e.g. 1/5, 3, 4, etc.
        results.append({
            "Function": base_key,
            "Exponent": exponent_val,
            "final_loss": final_loss_val,
            "epochs_to_convergence": conv_epoch,
        })
    df = pd.DataFrame(results)
    if df.empty:
        print("No valid data found.")
        return
    print(df.to_string(index=False))
    # Create output folder
    plot_folder = os.path.join(cond_path, "plots", "degree_convergence")
    os.makedirs(plot_folder, exist_ok=True)
    # -----------------------------------------------------------------
    # Figure: final loss and convergence epoch vs. exponent (two subplots)
    # -----------------------------------------------------------------
    fig, axes = plt.subplots(2, 1, figsize=(8, 12))
    # Top subplot: final_loss vs. exponent
    ax_top = axes[0]
    ax_top.scatter(df["Exponent"], df["final_loss"], color="blue")
    ax_top.set_xscale("linear")
    ax_top.set_yscale("log")
    ax_top.set_xlabel("Exponent")
    ax_top.set_ylabel(f"Final Loss @ epoch {final_epoch}")
    ax_top.set_title(f"{condition_name}: Final Loss vs. Exponent")
    # Best-fit line
    xs, y_fit, slope, intercept, R2 = safe_compute_best_fit(df["Exponent"], df["final_loss"], log_x=False, log_y=False)
    if xs is not None:
        ax_top.plot(xs, y_fit, "k--", label=f"slope={slope:.3f}, int={intercept:.3f}, R²={R2:.3f}")
        ax_top.legend(fontsize=8)
    # Label points with adjustText
    texts = []
    for _, row in df.iterrows():
        if np.isfinite(row["Exponent"]) and np.isfinite(row["final_loss"]):
            texts.append(ax_top.text(row["Exponent"], row["final_loss"], row["Function"], fontsize=8))
    adjust_text(texts, ax=ax_top, arrowprops=dict(arrowstyle="->", color="gray", lw=0.5))
    # Bottom subplot: epochs_to_convergence vs. exponent
    ax_bot = axes[1]
    ax_bot.scatter(df["Exponent"], df["epochs_to_convergence"], color="green")
    ax_bot.set_xscale("linear")
    ax_bot.set_yscale("log")
    ax_bot.set_xlabel("Exponent")
    ax_bot.set_ylabel("Epochs to Convergence")
    ax_bot.set_title(f"{condition_name}: Convergence vs. Exponent")
    # Best-fit line
    xs, y_fit, slope, intercept, R2 = safe_compute_best_fit(df["Exponent"], df["epochs_to_convergence"], log_x=False, log_y=False)
    if xs is not None:
        ax_bot.plot(xs, y_fit, "k--", label=f"slope={slope:.3f}, int={intercept:.3f}, R²={R2:.3f}")
        ax_bot.legend(fontsize=8)
    # Label points with adjustText
    texts = []
    for _, row in df.iterrows():
        if np.isfinite(row["Exponent"]) and np.isfinite(row["epochs_to_convergence"]):
            texts.append(ax_bot.text(row["Exponent"], row["epochs_to_convergence"], row["Function"], fontsize=8))
    adjust_text(texts, ax=ax_bot, arrowprops=dict(arrowstyle="->", color="gray", lw=0.5))
    plt.tight_layout()
    save_path = os.path.join(plot_folder, "exponent_vs_loss_and_convergence.png")
    plt.savefig(save_path, dpi=300)
    plt.close()
    print(f"Saved base loss exponent analysis to {save_path}")
# ---------------------------------------------------------------------
# Loop over all condition folders found under condition_base_dir
# ---------------------------------------------------------------------
if __name__ == "__main__":
    all_subdirs = sorted(os.listdir(condition_base_dir))
    for cond_name in all_subdirs:
        cond_path = os.path.join(condition_base_dir, cond_name)
        if not os.path.isdir(cond_path):
            continue
        if cond_name not in analysis_conditions:
            continue
        process_condition(cond_name)