# analysis/base_loss/generate_convergence_plots.py

import os
import json
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import sys
sys.path.append("/home/judson/Neural-Networks-in-GNC/inverted_pendulum")
from training.base_loss_functions import base_loss_functions
from analysis.analysis_conditions import analysis_conditions
###############################################################################
# Helper: Replicate base loss for one epoch.
# Loss = mean( f(theta) ) using the specified base loss function.
###############################################################################
def replicate_base_loss(theta_array, time_array, desired_theta, base_key):
    """
    Given:
      - theta_array: array of theta values for one epoch.
      - time_array: array of time values (same length as theta_array);
        accepted for interface symmetry but not used in this computation.
      - desired_theta: target angle (scalar).
      - base_key: key in base_loss_functions (e.g., "one", "one_fifth", etc.).
    Computes the loss for that epoch by applying the corresponding base loss
    function and returns the mean loss over time.
    """
    theta_tensor = torch.tensor(theta_array, dtype=torch.float32)
    loss_fn = base_loss_functions[base_key][1]
    # The loss function returns a tensor; take the mean and convert to a scalar.
    loss_val = torch.mean(loss_fn(theta_tensor,
                                  torch.tensor(desired_theta, dtype=torch.float32),
                                  min_val=0.01))
    return loss_val.item()
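# For orientation, a table entry compatible with the lookup above might look like
# the following (hypothetical sketch; the real definitions live in
# training.base_loss_functions, and the exact functional form and the meaning of
# element [0] may differ there):
#
#     base_loss_functions["two"] = (
#         "two",  # assumed display name; only element [1] is used here
#         lambda theta, target, min_val=0.01:
#             torch.clamp(torch.abs(theta - target), min=min_val) ** 2,
#     )
#
# i.e. key -> (name, callable(theta, target, min_val)) returning a per-sample
# loss tensor, which replicate_base_loss then averages into a scalar.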
###############################################################################
# Helper: Filter epoch data by range and sampling step.
###############################################################################
def filter_epoch_data(epochs, theta_over_epochs, epoch_range, epoch_step):
    """Keep only epochs inside epoch_range (inclusive), then subsample by epoch_step."""
    filtered_epochs = []
    filtered_theta_over_epochs = []
    for i, ep in enumerate(epochs):
        if epoch_range[0] <= ep <= epoch_range[1]:
            filtered_epochs.append(ep)
            filtered_theta_over_epochs.append(theta_over_epochs[i])
    if epoch_step > 1:
        filtered_epochs = filtered_epochs[::epoch_step]
        filtered_theta_over_epochs = filtered_theta_over_epochs[::epoch_step]
    return filtered_epochs, filtered_theta_over_epochs
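# Worked example (hypothetical data): with epochs = [0, 1, 2, 3, 4, 5],
# epoch_range = (1, 5) and epoch_step = 2, the range filter keeps [1, 2, 3, 4, 5]
# and the subsampling then returns [1, 3, 5], along with the matching theta arrays.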
###############################################################################
# Composite plotting function (renamed to plot_base_sweep_convergence)
###############################################################################
def plot_base_sweep_convergence(cond_results, condition_name, desired_theta,
                                left_list, right_list, save_path,
                                plot_epoch_range, plot_epoch_step,
                                normalization_mode="raw"):
    """
    Creates a composite two-column loss-convergence plot on a semilogy scale
    using the base loss functions.
    Layout:
      - Top row (spanning both columns): the loss computed on the "one" (base) data.
      - Each subsequent row pairs left_list[i] (left column) with right_list[i]
        (right column).
    In each subplot, two curves are plotted:
      * the loss computed using the "one" (base) function applied to that
        network's trajectory;
      * the loss computed using that subplot's own base loss function.
    Normalization modes:
      - "raw": plot raw loss values.
      - "norm_const": plot only the base ("one") curve, normalized by the final
        base loss from the top row.
      - "norm_both": plot both curves, each normalized by its own final value,
        so every curve ends at 1.0.
    The x-axis is linear (epochs) and the y-axis is logarithmic.
    """
    total_rows = 1 + len(left_list)
    fig = plt.figure(figsize=(12, 3 * total_rows))
    gs = gridspec.GridSpec(total_rows, 2)

    # --- Top row: "one" (base) data ---
    ax_top = fig.add_subplot(gs[0, :])
    epochs_ones = cond_results["one"]["epochs"]
    theta_ones = cond_results["one"]["theta_over_epochs"]
    time_ones = cond_results["one"]["time"]
    epochs_ones, theta_ones = filter_epoch_data(epochs_ones, theta_ones,
                                                plot_epoch_range, plot_epoch_step)
    ones_losses = [replicate_base_loss(theta_arr, time_ones, desired_theta, "one")
                   for theta_arr in theta_ones]
    epochs_ones_arr = np.array(epochs_ones)
    final_base_loss = ones_losses[-1] if ones_losses and ones_losses[-1] != 0 else 1.0
    if normalization_mode == "raw":
        y_ones = ones_losses
        top_title = "Unnormalized Loss"
    elif normalization_mode == "norm_const":
        y_ones = [val / final_base_loss for val in ones_losses]
        top_title = "Loss Normalized by Final Base Loss"
    elif normalization_mode == "norm_both":
        y_ones = [val / final_base_loss for val in ones_losses]
        top_title = "Loss Normalized by Final Loss"
    else:
        y_ones = ones_losses
        top_title = "Loss Convergence"
    ax_top.semilogy(epochs_ones_arr, y_ones, label="ones Degree Polynomial",
                    color="black", linestyle="--")
    ax_top.set_title(f"Initial Condition: {condition_name} {top_title}")
    ax_top.set_ylabel("Loss")
    ax_top.legend(fontsize=10)
    # --- Subsequent rows: pair left_list[i] with right_list[i] ---
    for i in range(len(left_list)):
        left_key = left_list[i]
        right_key = right_list[i]

        # Left subplot:
        ax_left = fig.add_subplot(gs[i + 1, 0])
        if left_key in cond_results:
            epochs_left = cond_results[left_key]["epochs"]
            theta_left = cond_results[left_key]["theta_over_epochs"]
            time_left = cond_results[left_key]["time"]
            epochs_left, theta_left = filter_epoch_data(epochs_left, theta_left,
                                                        plot_epoch_range, plot_epoch_step)
            left_base_losses = [replicate_base_loss(theta_arr, time_left, desired_theta, "one")
                                for theta_arr in theta_left]
            left_method_losses = [replicate_base_loss(theta_arr, time_left, desired_theta, left_key)
                                  for theta_arr in theta_left]
            epochs_left_arr = np.array(epochs_left)
            if normalization_mode == "raw":
                y_left_base = left_base_losses
                y_left_method = left_method_losses
                label_left_base = "ones Degree Polynomial"
                label_left_method = f"{left_key} Degree Polynomial"
            elif normalization_mode == "norm_const":
                y_left_base = [val / final_base_loss for val in left_base_losses]
                y_left_method = None
                label_left_base = "ones Degree Polynomial (Norm)"
            elif normalization_mode == "norm_both":
                norm_left_base = left_base_losses[-1] if left_base_losses[-1] != 0 else 1.0
                norm_left_method = left_method_losses[-1] if left_method_losses[-1] != 0 else 1.0
                y_left_base = [val / norm_left_base for val in left_base_losses]
                y_left_method = [val / norm_left_method for val in left_method_losses]
                label_left_base = "ones Degree Polynomial (Norm)"
                label_left_method = f"{left_key} Degree Polynomial (Norm)"
            else:
                # Unknown mode: fall back to raw values so plotting never hits
                # an undefined variable.
                y_left_base = left_base_losses
                y_left_method = left_method_losses
                label_left_base = "ones Degree Polynomial"
                label_left_method = f"{left_key} Degree Polynomial"
            ax_left.semilogy(epochs_left_arr, y_left_base, label=label_left_base,
                             color="black", linestyle="--")
            if y_left_method is not None:
                ax_left.semilogy(epochs_left_arr, y_left_method,
                                 label=label_left_method, color="blue")
            ax_left.set_ylabel("Loss")
            ax_left.set_title(f"{left_key} Loss Convergence")
            ax_left.legend(fontsize=10)
        else:
            ax_left.set_title(f"No Data for {left_key}")
        # Right subplot:
        ax_right = fig.add_subplot(gs[i + 1, 1])
        if right_key in cond_results:
            epochs_right = cond_results[right_key]["epochs"]
            theta_right = cond_results[right_key]["theta_over_epochs"]
            time_right = cond_results[right_key]["time"]
            epochs_right, theta_right = filter_epoch_data(epochs_right, theta_right,
                                                          plot_epoch_range, plot_epoch_step)
            right_base_losses = [replicate_base_loss(theta_arr, time_right, desired_theta, "one")
                                 for theta_arr in theta_right]
            right_method_losses = [replicate_base_loss(theta_arr, time_right, desired_theta, right_key)
                                   for theta_arr in theta_right]
            epochs_right_arr = np.array(epochs_right)
            if normalization_mode == "raw":
                y_right_base = right_base_losses
                y_right_method = right_method_losses
                label_right_base = "ones Degree Polynomial"
                label_right_method = f"{right_key} Degree Polynomial"
            elif normalization_mode == "norm_const":
                y_right_base = [val / final_base_loss for val in right_base_losses]
                y_right_method = None
                label_right_base = "ones Degree Polynomial (Norm)"
            elif normalization_mode == "norm_both":
                norm_right_base = right_base_losses[-1] if right_base_losses[-1] != 0 else 1.0
                norm_right_method = right_method_losses[-1] if right_method_losses[-1] != 0 else 1.0
                y_right_base = [val / norm_right_base for val in right_base_losses]
                y_right_method = [val / norm_right_method for val in right_method_losses]
                label_right_base = "ones Degree Polynomial (Norm)"
                label_right_method = f"{right_key} Degree Polynomial (Norm)"
            else:
                # Unknown mode: fall back to raw values (mirrors the left subplot).
                y_right_base = right_base_losses
                y_right_method = right_method_losses
                label_right_base = "ones Degree Polynomial"
                label_right_method = f"{right_key} Degree Polynomial"
            ax_right.semilogy(epochs_right_arr, y_right_base, label=label_right_base,
                              color="black", linestyle="--")
            if y_right_method is not None:
                ax_right.semilogy(epochs_right_arr, y_right_method,
                                  label=label_right_method, color="green")
            ax_right.set_title(f"{right_key} Loss Convergence")
            ax_right.legend(fontsize=10)
        else:
            ax_right.set_title(f"No Data for {right_key}")
        # Force the two subplots in this row to share the same y-axis limits.
        left_ylim = ax_left.get_ylim()
        right_ylim = ax_right.get_ylim()
        common_ylim = (min(left_ylim[0], right_ylim[0]),
                       max(left_ylim[1], right_ylim[1]))
        ax_left.set_ylim(common_ylim)
        ax_right.set_ylim(common_ylim)

    for ax in fig.get_axes():
        ax.set_xlabel("Epoch")
    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    plt.close()
    print(f"Saved base-sweep composite plot to {save_path}")
###############################################################################
# Main plotting loop
###############################################################################
if __name__ == "__main__":
# Directory where each condition folder resides
output_dir = "/home/judson/Neural-Networks-in-GNC/inverted_pendulum/analysis/base_loss"
# Settings for convergence plots
plot_epoch_range = (0, 25)
plot_epoch_step = 1
# Define left and right lists.
# Left: fractional loss functions, e.g. one_fifth, one_fourth, one_third, one_half
# Right: multiple loss functions, e.g. two, three, four, five
left_list = ["one_fifth", "one_fourth", "one_third", "one_half"]
right_list = ["two", "three", "four", "five"]
from analysis.analysis_conditions import analysis_conditions
condition_names = [name for name in os.listdir(output_dir)
if os.path.isdir(os.path.join(output_dir, name)) and name in analysis_conditions]
for condition in condition_names:
print(f"Processing condition: {condition}")
desired_theta = analysis_conditions[condition][-1]
data_dir = os.path.join(output_dir, condition, "data")
cond_results = {}
for file in os.listdir(data_dir):
if file.endswith(".json"):
loss_fn = file.replace(".json", "")
file_path = os.path.join(data_dir, file)
with open(file_path, "r") as f:
cond_results[loss_fn] = json.load(f)
plots_dir = os.path.join(output_dir, condition, "plots", "loss_convergence", f"{plot_epoch_range[1]}_epochs")
os.makedirs(plots_dir, exist_ok=True)
# Produce 3 versions: raw, norm_const, norm_both in two-column layout.
for mode in ["raw", "norm_const", "norm_both"]:
save_path = os.path.join(plots_dir, f"base_sweep_{mode}.png")
plot_base_sweep_convergence(cond_results, condition, desired_theta,
left_list, right_list, save_path,
plot_epoch_range, plot_epoch_step,
normalization_mode=mode)
print(f"Completed base-sweep plotting for condition: {condition}")