import os
import sys
import json

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from adjustText import adjust_text

# ---------------------------------------------------------------------
# Access the project's custom modules
# ---------------------------------------------------------------------
sys.path.append("/home/judson/Neural-Networks-in-GNC/inverted_pendulum")
from training.base_loss_functions import base_loss_functions
from analysis.analysis_conditions import analysis_conditions

# ---------------------------------------------------------------------
# User parameters
# ---------------------------------------------------------------------
condition_base_dir = "/home/judson/Neural-Networks-in-GNC/inverted_pendulum/analysis/base_loss"
final_epoch = 25                       # Epoch at which to measure final loss
convergence_ref_epoch = 100            # Epoch whose "one" loss defines the converged reference value
convergence_threshold_percentage = 10  # Allowed margin above the reference loss, in percent

# ---------------------------------------------------------------------
# Helper: replicate the base loss for a single epoch
# ---------------------------------------------------------------------
def replicate_base_loss(theta_array, desired_theta, base_key):
    """
    Given:
      - theta_array: array of theta values for one epoch
      - desired_theta: target angle (scalar)
      - base_key: key in base_loss_functions (e.g., "one", "one_fifth")

    Computes the mean of the base loss function over all theta values.
    """
    theta_tensor = torch.tensor(theta_array, dtype=torch.float32)
    # Each base_loss_functions entry is (exponent, callable).
    loss_fn = base_loss_functions[base_key][1]
    loss_val = torch.mean(
        loss_fn(theta_tensor, torch.tensor(desired_theta, dtype=torch.float32), min_val=0.01)
    )
    return loss_val.item()

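# A minimal usage sketch (hypothetical values; assumes base_loss_functions
# maps "one" to an (exponent, callable) pair, as used above):
#   thetas = np.linspace(-0.1, 0.1, 50)               # one epoch's theta trace
#   mean_loss = replicate_base_loss(thetas, 0.0, "one")
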
# ---------------------------------------------------------------------
# Helper: safe regression function (handles log transforms, filters invalid data)
# ---------------------------------------------------------------------
def safe_compute_best_fit(x, y, log_x=False, log_y=True):
    """
    Computes a best-fit line using linear regression, optionally on log(x) and/or log(y).

    Returns:
        xs, y_fit, slope, intercept, R2

    If there is insufficient valid data, returns (None, None, np.nan, np.nan, np.nan).
    """
    x = np.array(x, dtype=float)
    y = np.array(y, dtype=float)

    # Filter out NaNs, and non-positive values wherever a log transform applies
    valid_mask = ~np.isnan(x) & ~np.isnan(y)
    if log_x:
        valid_mask &= (x > 0)
    if log_y:
        valid_mask &= (y > 0)
    x = x[valid_mask]
    y = y[valid_mask]

    if len(x) < 2:
        return None, None, np.nan, np.nan, np.nan

    # Fit in the (optionally) log-transformed coordinates
    X = np.log(x) if log_x else x
    Y = np.log(y) if log_y else y

    slope, intercept = np.polyfit(X, Y, 1)
    Y_pred = slope * X + intercept

    # R² is computed in the transformed space, where the fit is linear
    ss_res = np.sum((Y - Y_pred) ** 2)
    ss_tot = np.sum((Y - np.mean(Y)) ** 2)
    R2 = 1 - ss_res / ss_tot if ss_tot != 0 else np.nan

    # Build the fitted line for plotting, mapped back to the original scale
    xs = np.linspace(np.min(x), np.max(x), 100)
    X_line = np.log(xs) if log_x else xs
    Y_line = slope * X_line + intercept
    y_fit = np.exp(Y_line) if log_y else Y_line

    return xs, y_fit, slope, intercept, R2

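# A minimal usage sketch (hypothetical data). With log_y=True the fit is
# log(y) = slope * x + intercept, i.e. an exponential trend in y:
#   xs, y_fit, slope, intercept, R2 = safe_compute_best_fit(
#       [1, 2, 3, 4], [0.9, 0.45, 0.2, 0.11], log_x=False, log_y=True)
#   # xs/y_fit trace the fitted curve; slope < 0 indicates exponential decay.
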
# ---------------------------------------------------------------------
# Main function to process a single condition
# ---------------------------------------------------------------------
def process_condition(condition_name):
    """
    Loads data for each base loss function (e.g., one_fifth.json, three.json) from
    the condition's 'data' folder, computes the final loss at final_epoch, and finds
    the convergence epoch using a threshold based on the 'one' base function. Then
    plots exponent vs. final_loss and exponent vs. epochs_to_convergence.
    """
    # Condition folder
    cond_path = os.path.join(condition_base_dir, condition_name)
    data_dir = os.path.join(cond_path, "data")
    if not os.path.isdir(data_dir):
        print(f"No 'data' folder for {condition_name}, skipping.")
        return

    if condition_name not in analysis_conditions:
        print(f"Condition '{condition_name}' not in analysis_conditions, skipping.")
        return

    # The last element of each analysis_conditions entry is the target angle
    desired_theta = analysis_conditions[condition_name][-1]

    # Load JSON data: one file per base loss function
    cond_data = {}
    for fname in os.listdir(data_dir):
        if fname.endswith(".json"):
            base_key = fname.replace(".json", "")  # e.g., "one_fifth"
            with open(os.path.join(data_dir, fname), "r") as f:
                cond_data[base_key] = json.load(f)

    # Make sure the reference function ("one") exists
    if "one" not in cond_data:
        print(f"No 'one.json' found for {condition_name}, skipping.")
        return

    # Compute the 'one' loss at convergence_ref_epoch; this reference value
    # defines the convergence threshold
    ref_epochs = cond_data["one"]["epochs"]
    if convergence_ref_epoch not in ref_epochs:
        print(f"Ref epoch {convergence_ref_epoch} not in 'one' data, skipping.")
        return

    ref_idx = ref_epochs.index(convergence_ref_epoch)
    ref_theta = cond_data["one"]["theta_over_epochs"][ref_idx]
    final_one_loss = replicate_base_loss(ref_theta, desired_theta, "one")
    threshold = (1 + convergence_threshold_percentage / 100) * final_one_loss

    print(f"\n=== Condition: {condition_name} ===")
    print(f"Final 'one' loss at epoch {convergence_ref_epoch} = {final_one_loss:.6f}, threshold = {threshold:.6f}")

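    # Worked example (hypothetical numbers): if the 'one' run reaches a mean
    # loss of 0.020 at epoch 100 and convergence_threshold_percentage = 10,
    # then threshold = (1 + 10/100) * 0.020 = 0.022, and a run counts as
    # converged at the first epoch whose mean loss is <= 0.022.
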
    # We'll store per-function results for a DataFrame
    results = []
    # Dictionary order isn't meaningful here, so sort the base functions by exponent
    sorted_keys = sorted(base_loss_functions.keys(), key=lambda k: base_loss_functions[k][0])

    # Helper: find the first epoch where the base loss drops to the threshold
    def find_convergence_epoch(base_key, data_dict, desired_theta, threshold):
        epochs = data_dict["epochs"]
        theta_over_epochs = data_dict["theta_over_epochs"]
        for ep, thetas in zip(epochs, theta_over_epochs):
            val = replicate_base_loss(thetas, desired_theta, base_key)
            if val <= threshold:
                return ep
        return np.nan

    for base_key in sorted_keys:
        if base_key not in cond_data:
            # No JSON file was found for this base function
            continue

        data_dict = cond_data[base_key]
        if final_epoch not in data_dict["epochs"]:
            continue
        final_idx = data_dict["epochs"].index(final_epoch)
        final_theta = data_dict["theta_over_epochs"][final_idx]
        # Evaluate every run under the common 'one' loss so the final-loss
        # values are directly comparable across exponents
        final_loss_val = replicate_base_loss(final_theta, desired_theta, "one")

        # Convergence is likewise measured against the 'one' loss threshold
        conv_epoch = find_convergence_epoch("one", data_dict, desired_theta, threshold)
        exponent_val = base_loss_functions[base_key][0]  # e.g., 1/5, 3, 4

        results.append({
            "Function": base_key,
            "Exponent": exponent_val,
            "final_loss": final_loss_val,
            "epochs_to_convergence": conv_epoch,
        })

    df = pd.DataFrame(results)
    if df.empty:
        print("No valid data found.")
        return

    print(df.to_string(index=False))

    # Create the output folder for plots
    plot_folder = os.path.join(cond_path, "plots", "degree_convergence")
    os.makedirs(plot_folder, exist_ok=True)

    # -----------------------------------------------------------------
    # Figure: final_loss vs. exponent (top), convergence vs. exponent (bottom)
    # -----------------------------------------------------------------
    fig, axes = plt.subplots(2, 1, figsize=(8, 12))

    # Top subplot: final_loss vs. exponent
    ax_top = axes[0]
    ax_top.scatter(df["Exponent"], df["final_loss"], color="blue")
    ax_top.set_xscale("linear")
    ax_top.set_yscale("log")
    ax_top.set_xlabel("Exponent")
    ax_top.set_ylabel(f"Final Loss @ epoch {final_epoch}")
    ax_top.set_title(f"{condition_name}: Final Loss vs. Exponent")

    # Best-fit line (log_y=True matches the log-scaled y-axis)
    xs, y_fit, slope, intercept, R2 = safe_compute_best_fit(
        df["Exponent"], df["final_loss"], log_x=False, log_y=True)
    if xs is not None:
        ax_top.plot(xs, y_fit, "k--", label=f"slope={slope:.3f}, int={intercept:.3f}, R²={R2:.3f}")
        ax_top.legend(fontsize=8)

    # Label points, using adjustText to avoid overlaps
    texts = []
    for _, row in df.iterrows():
        if np.isfinite(row["Exponent"]) and np.isfinite(row["final_loss"]):
            texts.append(ax_top.text(row["Exponent"], row["final_loss"], row["Function"], fontsize=8))
    adjust_text(texts, ax=ax_top, arrowprops=dict(arrowstyle="->", color="gray", lw=0.5))

    # Bottom subplot: epochs_to_convergence vs. exponent
    ax_bot = axes[1]
    ax_bot.scatter(df["Exponent"], df["epochs_to_convergence"], color="green")
    ax_bot.set_xscale("linear")
    ax_bot.set_yscale("log")
    ax_bot.set_xlabel("Exponent")
    ax_bot.set_ylabel("Epochs to Convergence")
    ax_bot.set_title(f"{condition_name}: Convergence vs. Exponent")

    # Best-fit line
    xs, y_fit, slope, intercept, R2 = safe_compute_best_fit(
        df["Exponent"], df["epochs_to_convergence"], log_x=False, log_y=True)
    if xs is not None:
        ax_bot.plot(xs, y_fit, "k--", label=f"slope={slope:.3f}, int={intercept:.3f}, R²={R2:.3f}")
        ax_bot.legend(fontsize=8)

    # Label points, using adjustText to avoid overlaps
    texts = []
    for _, row in df.iterrows():
        if np.isfinite(row["Exponent"]) and np.isfinite(row["epochs_to_convergence"]):
            texts.append(ax_bot.text(row["Exponent"], row["epochs_to_convergence"], row["Function"], fontsize=8))
    adjust_text(texts, ax=ax_bot, arrowprops=dict(arrowstyle="->", color="gray", lw=0.5))

    plt.tight_layout()
    save_path = os.path.join(plot_folder, "exponent_vs_loss_and_convergence.png")
    plt.savefig(save_path, dpi=300)
    plt.close()
    print(f"Saved base loss exponent analysis to {save_path}")

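# A single condition can also be processed directly (hypothetical name;
# assumes a matching analysis_conditions entry and a populated data/ folder):
#   process_condition("baseline")
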
# ---------------------------------------------------------------------
# Loop over every condition directory found under condition_base_dir
# ---------------------------------------------------------------------
if __name__ == "__main__":
    all_subdirs = sorted(os.listdir(condition_base_dir))
    for cond_name in all_subdirs:
        cond_path = os.path.join(condition_base_dir, cond_name)
        if not os.path.isdir(cond_path):
            continue
        if cond_name not in analysis_conditions:
            continue
        process_condition(cond_name)