import os import sys import json import torch import numpy as np import pandas as pd import matplotlib.pyplot as plt from adjustText import adjust_text # For automatic label adjustment # --------------------------------------------------------------------- # Access your custom modules # --------------------------------------------------------------------- sys.path.append("/home/judson/Neural-Networks-in-GNC/inverted_pendulum") from training.time_weighting_functions import weight_functions from analysis.analysis_conditions import analysis_conditions from numpy import pi # --------------------------------------------------------------------- # User parameters # --------------------------------------------------------------------- condition_base_dir = "/home/judson/Neural-Networks-in-GNC/inverted_pendulum/analysis/time_weighting" final_epoch = 25 # The epoch at which to measure final loss convergence_ref_epoch = 100 # The epoch used to define "final constant loss" convergence_threshold_percentage = 10 # e.g., 10% time_start, time_end, time_points = 0, 10, 1000 t_span = torch.linspace(time_start, time_end, time_points) dt = t_span[1] - t_span[0] # --------------------------------------------------------------------- # 1) Compute weighting metrics (t_mean, t_median, R) for each function # --------------------------------------------------------------------- def compute_metrics(weight_fn, t_span, t_end, min_val=0.01): """ Returns (t_mean, t_median, R) for a given weighting function. Here, R is computed as: R = (integral from t_end/2 to t_end) / (integral from 0 to t_end/2), so that a smaller R indicates more emphasis on early times. """ weights = weight_fn(t_span, t_max=t_end, min_val=min_val) total_weight = torch.trapz(weights, t_span) # Weighted mean t_mean = torch.trapz(t_span * weights, t_span) / total_weight # Weighted median cum_integral = torch.cumsum(weights, dim=0) * dt half_total = total_weight / 2.0 idx = (cum_integral >= half_total).nonzero()[0, 0] t_median = t_span[idx].item() # Inverse ratio: Late vs Early so that a lower value means more early weighting early_mask = t_span <= (t_end / 2) late_mask = t_span > (t_end / 2) early_integral = torch.trapz(weights[early_mask], t_span[early_mask]) late_integral = torch.trapz(weights[late_mask], t_span[late_mask]) R = (late_integral / early_integral).item() if early_integral > 0 else np.nan return t_mean.item(), t_median, R # Ordered weighting functions ordered_functions = [ "constant", "linear", "linear_mirrored", "quadratic", "quadratic_mirrored", "cubic", "cubic_mirrored", "square_root", "square_root_mirrored", "cubic_root", "cubic_root_mirrored", "inverse", "inverse_mirrored", "inverse_squared", "inverse_squared_mirrored", "inverse_cubed", "inverse_cubed_mirrored" ] metrics_dict = {} for fn_name in ordered_functions: fn = weight_functions[fn_name] t_mean, t_median, R = compute_metrics(fn, t_span, time_end, min_val=0.01) metrics_dict[fn_name] = {"t_mean": t_mean, "t_median": t_median, "R": R} # --------------------------------------------------------------------- # 2) Helper to compute loss using constant weighting # --------------------------------------------------------------------- def compute_constant_loss(theta_array, time_array, desired_theta): """ Computes the mean squared error loss: mean((theta - desired_theta)^2). (This is a simplified loss; adjust if you wish to include weighting.) """ theta_tensor = torch.tensor(theta_array, dtype=torch.float32) desired_tensor = torch.full_like(theta_tensor, desired_theta) loss_val = torch.mean((theta_tensor - desired_tensor)**2) return loss_val.item() # --------------------------------------------------------------------- # 3) Parse each condition folder and gather final loss & convergence data # --------------------------------------------------------------------- def get_final_loss_and_convergence(data_dict, desired_theta, final_epoch, ref_epoch): """ data_dict: loaded JSON for a weighting function with keys: "epochs", "theta_over_epochs" (list of arrays), "time" (array) Returns final loss (computed at final_epoch) and placeholder None. """ epochs = data_dict["epochs"] theta_over_epochs = data_dict["theta_over_epochs"] time_array = data_dict["time"] if final_epoch not in epochs: return None, None final_idx = epochs.index(final_epoch) final_theta = theta_over_epochs[final_idx] final_loss = compute_constant_loss(final_theta, time_array, desired_theta) return final_loss, None def find_convergence_epoch(data_dict, desired_theta, threshold): """ Returns the first epoch whose loss (computed with constant MSE) is <= threshold. If none is found, returns np.nan. """ epochs = data_dict["epochs"] theta_over_epochs = data_dict["theta_over_epochs"] time_array = data_dict["time"] for ep, thetas in zip(epochs, theta_over_epochs): loss_val = compute_constant_loss(thetas, time_array, desired_theta) if loss_val <= threshold: return ep return np.nan # --------------------------------------------------------------------- # Helper: Compute best-fit line (with R^2) for scatter plot data. # --------------------------------------------------------------------- def safe_compute_best_fit(x, y, log_x=False, log_y=True): """ Computes the best-fit line using linear regression. Filters out NaN and non-positive values if log transforms are used. Returns: xs: sorted x values for plotting the line, y_fit: predicted y values, slope, intercept, and R^2. If insufficient valid data exist, returns (None, None, NaN, NaN, NaN). """ x = np.array(x, dtype=float) y = np.array(y, dtype=float) valid_mask = ~np.isnan(x) & ~np.isnan(y) if log_x: valid_mask &= (x > 0) if log_y: valid_mask &= (y > 0) x = x[valid_mask] y = y[valid_mask] if len(x) < 2: return None, None, np.nan, np.nan, np.nan if log_x: X = np.log(x) else: X = x if log_y: Y = np.log(y) else: Y = y slope, intercept = np.polyfit(X, Y, 1) Y_pred = slope * X + intercept if log_y: y_pred = np.exp(Y_pred) else: y_pred = Y_pred ss_res = np.sum((Y - Y_pred) ** 2) ss_tot = np.sum((Y - np.mean(Y)) ** 2) R2 = 1 - ss_res / ss_tot if ss_tot != 0 else np.nan xs = np.linspace(np.min(x), np.max(x), 100) if log_x: X_line = np.log(xs) else: X_line = xs Y_line = slope * X_line + intercept if log_y: y_fit = np.exp(Y_line) else: y_fit = Y_line return xs, y_fit, slope, intercept, R2 # --------------------------------------------------------------------- # Main loop: Process a condition folder (here, "small_perturbation") # --------------------------------------------------------------------- all_subdirs = sorted(os.listdir(condition_base_dir)) for cond_name in all_subdirs: cond_path = os.path.join(condition_base_dir, cond_name) if not os.path.isdir(cond_path): continue if cond_name not in analysis_conditions: continue print(f"\n=== Processing condition: {cond_name} ===") desired_theta = analysis_conditions[cond_name][-1] data_dir = os.path.join(cond_path, "data") if not os.path.isdir(data_dir): print(f"No 'data' folder for {cond_name}, skipping.") continue cond_data = {} for fname in os.listdir(data_dir): if fname.endswith(".json"): loss_fn_name = fname.replace(".json", "") with open(os.path.join(data_dir, fname), "r") as f: cond_data[loss_fn_name] = json.load(f) if "constant" not in cond_data: print(f"No constant.json found for {cond_name}, skipping.") continue const_epochs = cond_data["constant"]["epochs"] if convergence_ref_epoch not in const_epochs: print(f"Ref epoch {convergence_ref_epoch} not in constant data for {cond_name}, skipping.") continue ref_idx = const_epochs.index(convergence_ref_epoch) ref_theta = cond_data["constant"]["theta_over_epochs"][ref_idx] time_array = cond_data["constant"]["time"] final_const_loss = compute_constant_loss(ref_theta, time_array, desired_theta) threshold = (1 + convergence_threshold_percentage/100) * final_const_loss print(f"Final constant loss at ref epoch ({convergence_ref_epoch}): {final_const_loss:.6f}, threshold = {threshold:.6f}") results = [] for fn_name in ordered_functions: if fn_name not in cond_data: continue data_dict = cond_data[fn_name] final_loss, _ = get_final_loss_and_convergence(data_dict, desired_theta, final_epoch, convergence_ref_epoch) if final_loss is None: continue conv_epoch = find_convergence_epoch(data_dict, desired_theta, threshold) t_median = metrics_dict[fn_name]["t_median"] results.append({ "Function": fn_name, "t_median": t_median, "final_loss": final_loss, "epochs_to_convergence": conv_epoch }) df = pd.DataFrame(results) if df.empty: print(f"No valid data found for {cond_name}.") continue print(df.to_string(index=False)) # Create output folder for plots: condition/plots/centroid_convergence plot_folder = os.path.join(cond_path, "plots", "centroid_convergence") os.makedirs(plot_folder, exist_ok=True) # ---------------------------- # Save composite plot using just t_median (Vertical Layout: 2 subplots) # Top: Loss vs. t_median, Bottom: Epochs to Convergence vs. t_median # ---------------------------- fig, axes = plt.subplots(2, 1, figsize=(8, 12)) # Subplot 1: t_median vs. final_loss axes[0].scatter(df["t_median"], df["final_loss"], color="green") axes[0].set_xlabel(r"$t_{median}$") axes[0].set_ylabel(f"Loss at Epoch {final_epoch}") axes[0].set_yscale("log") axes[0].set_title(f"{cond_name}: Loss vs. $t_{{median}}$") xs, y_fit, slope, intercept, R2 = safe_compute_best_fit(df["t_median"], df["final_loss"], log_x=False, log_y=True) if xs is not None: axes[0].plot(xs, y_fit, "k--", label=f"Fit: slope={slope:.3f}, int={intercept:.3f}\n$R^2$={R2:.3f}") axes[0].legend(fontsize=8) texts = [] for i, row in df.iterrows(): if np.isfinite(row["t_median"]) and np.isfinite(row["final_loss"]): txt = axes[0].text(row["t_median"], row["final_loss"], row["Function"], fontsize=8) texts.append(txt) adjust_text(texts, ax=axes[0], arrowprops=dict(arrowstyle="->", color="gray", lw=0.5)) # Subplot 2: t_median vs. epochs_to_convergence axes[1].scatter(df["t_median"], df["epochs_to_convergence"], color="blue") axes[1].set_xlabel(r"$t_{median}$") axes[1].set_ylabel("Epochs to Convergence") axes[1].set_yscale("log") axes[1].set_title(f"{cond_name}: Convergence vs. $t_{{median}}$") xs, y_fit, slope, intercept, R2 = safe_compute_best_fit(df["t_median"], df["epochs_to_convergence"], log_x=False, log_y=True) if xs is not None: axes[1].plot(xs, y_fit, "k--", label=f"Fit: slope={slope:.3f}, int={intercept:.3f}\n$R^2$={R2:.3f}") axes[1].legend(fontsize=8) texts = [] for i, row in df.iterrows(): if np.isfinite(row["t_median"]) and np.isfinite(row["epochs_to_convergence"]): txt = axes[1].text(row["t_median"], row["epochs_to_convergence"], row["Function"], fontsize=8) texts.append(txt) adjust_text(texts, ax=axes[1], arrowprops=dict(arrowstyle="->", color="gray", lw=0.5)) plt.tight_layout() composite_plot_path = os.path.join(plot_folder, f"t_median_composite.png") plt.savefig(composite_plot_path, dpi=300) plt.close() print(f"Saved t_median composite plot to {composite_plot_path}")