diff --git a/analysis/controller_across_epochs.py b/analysis/controller_across_epochs.py index 553198b..22e5223 100644 --- a/analysis/controller_across_epochs.py +++ b/analysis/controller_across_epochs.py @@ -2,11 +2,9 @@ import os import numpy as np import torch import torch.nn as nn -import matplotlib -matplotlib.use("Agg") # Use non-interactive backend import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D -import multiprocessing +from multiprocessing import Pool, cpu_count # Define PendulumController class class PendulumController(nn.Module): @@ -61,26 +59,36 @@ def pendulum_ode_step(state, dt, desired_theta, controller): # Constants g = 9.81 # Gravity R = 1.0 # Length of the pendulum -m = 1.0 # Mass +m = 10.0 # Mass dt = 0.02 # Time step num_steps = 500 # Simulation time steps # Directory containing controller files -controller_dir = "/home/judson/Neural-Networks-in-GNC/inverted_pendulum/training/no_time_weight/controllers" +loss_function = "cubic_time_weight" +#controller_dir = f"/home/judson/Neural-Networks-in-GNC/inverted_pendulum/training/{loss_function}/controllers" +controller_dir = f"C:/Users/Judson/Desktop/New Gitea/Neural-Networks-in-GNC/inverted_pendulum/training/{loss_function}/controllers" controller_files = sorted([f for f in os.listdir(controller_dir) if f.startswith("controller_") and f.endswith(".pth")]) # Sorting controllers by epoch controller_epochs = [int(f.split('_')[1].split('.')[0]) for f in controller_files] sorted_controllers = [x for _, x in sorted(zip(controller_epochs, controller_files))] +# **Epoch Range Selection** +epoch_range = (0, 100) # Set your desired range (e.g., (0, 5000) or (0, 100)) + +filtered_controllers = [ + f for f in sorted_controllers + if epoch_range[0] <= int(f.split('_')[1].split('.')[0]) <= epoch_range[1] +] + # **Granularity Control: Select every Nth controller** -N = 5 # Change this value to adjust granularity (e.g., every 5th controller) -selected_controllers = sorted_controllers[::N] # Take every Nth controller +N = 1 # Change this value to adjust granularity (e.g., every 5th controller) +selected_controllers = filtered_controllers[::N] # Take every Nth controller within the range # Initial condition -theta0, omega0, alpha0, desired_theta = (-np.pi, -np.pi, 0.0, np.pi / 6) # Example initial condition +theta0, omega0, alpha0, desired_theta = (-np.pi, 0, 0.0, 0.0) # Example initial condition -# Function to run a single controller simulation (for multiprocessing) +# Parallel function must return epoch explicitly def run_simulation(controller_file): epoch = int(controller_file.split('_')[1].split('.')[0]) @@ -97,39 +105,48 @@ def run_simulation(controller_file): theta_vals.append(state[0]) state = pendulum_ode_step(state, dt, desired_theta, controller) - return epoch, theta_vals + return epoch, theta_vals # Return epoch with data # Parallel processing if __name__ == "__main__": - num_workers = min(multiprocessing.cpu_count(), 16) # Limit to 16 workers max + num_workers = min(cpu_count(), 16) # Limit to 16 workers max print(f"Using {num_workers} parallel workers...") - print(f"Processing every {N}th controller, total controllers used: {len(selected_controllers)}") - - with multiprocessing.Pool(processes=num_workers) as pool: + + with Pool(processes=num_workers) as pool: results = pool.map(run_simulation, selected_controllers) - # Sort results by epoch - results.sort(key=lambda x: x[0]) - epochs, theta_over_epochs = zip(*results) + # **Sort results by epoch to ensure correct order** + results.sort(key=lambda x: x[0]) + epochs, theta_over_epochs = zip(*results) # Unzip sorted results # Convert results to NumPy arrays theta_over_epochs = np.array(theta_over_epochs) - # Create 3D plot + + # Create 3D line plot fig = plt.figure(figsize=(10, 7)) ax = fig.add_subplot(111, projection='3d') - # Meshgrid for 3D plotting - E, T = np.meshgrid(epochs, np.arange(num_steps) * dt) + time_steps = np.arange(num_steps) * dt # X-axis (time) - # Plot surface - ax.plot_surface(E, T, theta_over_epochs.T, cmap="viridis") + # Plot each controller as a separate line + for epoch, theta_vals in zip(epochs, theta_over_epochs): + ax.plot( + [epoch] * len(time_steps), # Y-axis (epoch stays constant for each line) + time_steps, # X-axis (time) + theta_vals, # Z-axis (theta evolution) + label=f"Epoch {epoch}" if epoch % (N * 10) == 0 else "", # Label some lines for clarity + ) # Labels ax.set_xlabel("Epoch") ax.set_ylabel("Time (s)") ax.set_zlabel("Theta (rad)") - ax.set_title(f"Pendulum Angle Evolution Over Training Epochs (Granularity N={N})") + ax.set_title(f"Pendulum Angle Evolution for {loss_function}") - plt.savefig("pendulum_plot.png", dpi=1000, bbox_inches="tight") - print("Saved plot as 'pendulum_plot.png'.") + # Improve visibility + ax.view_init(elev=20, azim=-135) # Adjust 3D perspective + + plt.savefig(f"{loss_function}.png", dpi=600) + #plt.show() + print(f"Saved plot as '{loss_function}.png'.") diff --git a/analysis/cubic_time_weight.png b/analysis/cubic_time_weight.png new file mode 100644 index 0000000..3b3df25 Binary files /dev/null and b/analysis/cubic_time_weight.png differ diff --git a/analysis/exponential_time_weight.png b/analysis/exponential_time_weight.png new file mode 100644 index 0000000..ef3ccb9 Binary files /dev/null and b/analysis/exponential_time_weight.png differ diff --git a/analysis/inverse_time_weight.png b/analysis/inverse_time_weight.png new file mode 100644 index 0000000..aea8b2d Binary files /dev/null and b/analysis/inverse_time_weight.png differ diff --git a/analysis/linear_time_weight.png b/analysis/linear_time_weight.png new file mode 100644 index 0000000..15a44dc Binary files /dev/null and b/analysis/linear_time_weight.png differ diff --git a/analysis/no_time_weight.png b/analysis/no_time_weight.png new file mode 100644 index 0000000..f9053d6 Binary files /dev/null and b/analysis/no_time_weight.png differ diff --git a/analysis/pendulum_plot.png b/analysis/pendulum_plot.png deleted file mode 100644 index 62b8cf5..0000000 Binary files a/analysis/pendulum_plot.png and /dev/null differ diff --git a/analysis/quadratic_time_weight.png b/analysis/quadratic_time_weight.png new file mode 100644 index 0000000..8c26b67 Binary files /dev/null and b/analysis/quadratic_time_weight.png differ