# 136 lines, 4.4 KiB, Python
import os
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import matplotlib
|
|
matplotlib.use("Agg") # Use non-interactive backend
|
|
import matplotlib.pyplot as plt
|
|
from mpl_toolkits.mplot3d import Axes3D
|
|
import multiprocessing
|
|
|
|
# Define PendulumController class
|
|
class PendulumController(nn.Module):
    """Fully connected torque controller.

    Maps a 4-feature input row ``[theta, omega, alpha, desired_theta]`` to a
    single output (the raw torque command) through two hidden layers of width
    64 with ReLU activations.
    """

    def __init__(self):
        super().__init__()
        hidden = 64
        # Layers are kept inside one ``nn.Sequential`` named ``net`` so the
        # state_dict keys stay ``net.0.*``, ``net.2.*``, ``net.4.*`` and
        # previously saved checkpoints still load.
        layers = [
            nn.Linear(4, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the controller output for the input batch ``x``."""
        output = self.net(x)
        return output
|
|
|
|
# ODE solver (RK4 method)
|
|
def pendulum_ode_step(state, dt, desired_theta, controller, *,
                      gravity=None, length=None, mass=None, max_torque=250.0):
    """Advance the closed-loop pendulum state by one classical RK4 step.

    The controller is re-evaluated at each of the four RK4 stages, so the torque
    follows the intermediate states instead of being held constant over the step.

    Args:
        state: Sequence ``(theta, omega, alpha)``.
        dt: Integration step size.
        desired_theta: Target angle, passed to the controller as its 4th input.
        controller: Callable taking a ``(1, 4)`` float32 tensor
            ``[[theta, omega, alpha, desired_theta]]`` and returning a
            single-element tensor (the torque command).
        gravity, length, mass: Physical parameters. ``None`` (the default) falls
            back to the module-level ``g``, ``R`` and ``m`` respectively.
        max_torque: The controller output is clamped to
            ``[-max_torque, max_torque]`` before it enters the dynamics.

    Returns:
        The new state as a float ``np.ndarray`` of shape ``(3,)``.
    """
    # Resolve lazily: the module-level constants are defined after this
    # function, and explicit arguments avoid touching them at all.
    gravity = g if gravity is None else gravity
    length = R if length is None else length
    mass = m if mass is None else mass

    state = np.asarray(state, dtype=float)

    def compute_torque(th, om, al):
        """Query the controller at one state and return the clamped torque."""
        inp = torch.tensor([[th, om, al, desired_theta]], dtype=torch.float32)
        with torch.no_grad():
            torque = controller(inp)
        torque = torch.clamp(torque, -max_torque, max_torque)
        return torque.item()

    def derivatives(s):
        """Return d/dt of ``(theta, omega, alpha)`` under closed-loop control."""
        th, om, al = s
        torque = compute_torque(th, om, al)
        # theta'' = (g/R) sin(theta) + torque / (m R^2); the positive gravity
        # term makes theta = 0 an unstable equilibrium.
        domega = (gravity / length) * np.sin(th) + torque / (mass * length**2)
        # alpha has zero derivative, so it stays at its initial value.
        # NOTE(review): confirm this matches the dynamics used in training.
        return np.array([om, domega, 0.0])

    k1 = dt * derivatives(state)
    k2 = dt * derivatives(state + 0.5 * k1)
    k3 = dt * derivatives(state + 0.5 * k2)
    k4 = dt * derivatives(state + k3)
    return state + (k1 + 2 * k2 + 2 * k3 + k4) / 6.0
|
|
|
|
# Constants
|
|
g = 9.81 # Gravity
|
|
R = 1.0 # Length of the pendulum
|
|
m = 1.0 # Mass
|
|
dt = 0.02 # Time step
|
|
num_steps = 500 # Simulation time steps
|
|
|
|
# Directory containing controller files
|
|
# Directory holding the saved controller checkpoints. Set the CONTROLLER_DIR
# environment variable to point elsewhere without editing this file.
controller_dir = os.environ.get(
    "CONTROLLER_DIR",
    "/home/judson/Neural-Networks-in-GNC/inverted_pendulum/training/no_time_weight/controllers",
)


def _controller_epoch(filename):
    """Return the epoch parsed from ``controller_<epoch>...pth``, or None if not an integer."""
    try:
        return int(filename.split('_')[1].split('.')[0])
    except ValueError:
        return None


_candidate_files = [
    f for f in os.listdir(controller_dir)
    if f.startswith("controller_") and f.endswith(".pth")
]
# Files like "controller_final.pth" carry no integer epoch; previously they
# crashed the script with ValueError, now they are reported and skipped.
_skipped_files = sorted(f for f in _candidate_files if _controller_epoch(f) is None)
if _skipped_files:
    print(f"Skipping controller files without an integer epoch: {_skipped_files}")

controller_files = sorted(f for f in _candidate_files if _controller_epoch(f) is not None)

# Sort controllers numerically by epoch (the lexicographic sort above would put
# "controller_10" before "controller_2").
controller_epochs = [_controller_epoch(f) for f in controller_files]
sorted_controllers = [x for _, x in sorted(zip(controller_epochs, controller_files))]

# Granularity control: keep every Nth controller in epoch order to limit the
# number of simulations; N = 1 uses all of them.
N = 5
selected_controllers = sorted_controllers[::N]
|
|
|
|
# Initial condition
|
|
# Example initial condition. desired_theta is the target angle handed to the
# controller as its 4th input at every step.
theta0 = -np.pi
omega0 = -np.pi
alpha0 = 0.0
desired_theta = np.pi / 6
|
|
|
|
# Function to run a single controller simulation (for multiprocessing)
|
|
def run_simulation(controller_file):
    """Load one saved controller and simulate the pendulum under its control.

    Intended to be mapped over file names by a ``multiprocessing.Pool``.

    Args:
        controller_file: File name (not a path) inside ``controller_dir``. The
            epoch is the second ``_``-separated field, up to its first ``.``
            (e.g. ``controller_120.pth`` -> 120).

    Returns:
        ``(epoch, theta_vals)``. ``theta_vals`` is a list of ``num_steps``
        angles, each recorded *before* its RK4 step, so index ``k`` corresponds
        to time ``k * dt``.
    """
    epoch = int(controller_file.split('_')[1].split('.')[0])

    controller = PendulumController()
    # map_location="cpu": checkpoints saved from GPU tensors would otherwise
    # fail to deserialize on CPU-only machines. The simulation feeds CPU
    # tensors to the network anyway.
    checkpoint_path = os.path.join(controller_dir, controller_file)
    controller.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
    controller.eval()

    state = np.array([theta0, omega0, alpha0])
    theta_vals = []
    for _ in range(num_steps):
        theta_vals.append(state[0])
        state = pendulum_ode_step(state, dt, desired_theta, controller)

    return epoch, theta_vals
|
|
|
|
# Parallel processing
|
|
def main():
    """Simulate all selected controllers in parallel and save a 3D theta-vs-epoch-vs-time plot."""
    if not selected_controllers:
        # Without this guard, zip(*results) below fails with an opaque
        # "not enough values to unpack" error.
        raise SystemExit(f"No controller files selected from {controller_dir!r}; nothing to plot.")

    # Cap at 16 workers, and never start more workers than there are simulations.
    num_workers = min(multiprocessing.cpu_count(), 16, len(selected_controllers))
    print(f"Using {num_workers} parallel workers...")
    print(f"Processing every {N}th controller, total controllers used: {len(selected_controllers)}")

    with multiprocessing.Pool(processes=num_workers) as pool:
        results = pool.map(run_simulation, selected_controllers)

    # Pool.map keeps input order; sorting explicitly keeps the plot independent of it.
    results.sort(key=lambda x: x[0])
    epochs, theta_over_epochs = zip(*results)

    # One row per controller, one column per time step.
    theta_over_epochs = np.array(theta_over_epochs)

    fig = plt.figure(figsize=(10, 7))
    ax = fig.add_subplot(111, projection='3d')

    # Grid of (epoch, time); both arrays are (num_steps, n_controllers), which
    # is why the theta matrix is transposed for plot_surface.
    E, T = np.meshgrid(epochs, np.arange(num_steps) * dt)
    ax.plot_surface(E, T, theta_over_epochs.T, cmap="viridis")

    ax.set_xlabel("Epoch")
    ax.set_ylabel("Time (s)")
    ax.set_zlabel("Theta (rad)")
    ax.set_title(f"Pendulum Angle Evolution Over Training Epochs (Granularity N={N})")

    # NOTE(review): dpi=1000 on a 10x7-inch figure yields a very large
    # (roughly 10000x7000 px) image; lower it if file size or render time matters.
    plt.savefig("pendulum_plot.png", dpi=1000, bbox_inches="tight")
    plt.close(fig)
    print("Saved plot as 'pendulum_plot.png'.")


if __name__ == "__main__":
    main()
|