181 lines
6.3 KiB
Python
181 lines
6.3 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torchdiffeq import odeint
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
# ----------------------------------------------------------------
|
|
# 1) 3D Controller: [theta, omega, alpha] -> torque
|
|
# ----------------------------------------------------------------
|
|
class PendulumController3D(nn.Module):
    """Neural controller mapping [theta, omega, alpha] to a bounded torque.

    A small MLP (3 -> 64 -> 64 -> 1) whose scalar output is saturated to
    [-250, 250] before being applied to the pendulum.
    """

    def __init__(self):
        super().__init__()
        hidden = 64
        layers = [
            nn.Linear(3, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x_3d):
        """Map a batch of states to torques.

        Args:
            x_3d: tensor of shape (batch_size, 3) => [theta, omega, alpha].

        Returns:
            Tensor of shape (batch_size, 1): torque clamped to [-250, 250].
        """
        # Saturate the raw network output; note gradients are zero wherever
        # the clamp is active.
        return self.net(x_3d).clamp(min=-250, max=250)
|
|
|
|
|
|
# ----------------------------------------------------------------
|
|
# 2) Define ODE System Using `odeint`
|
|
# ----------------------------------------------------------------
|
|
# Physical parameters of the pendulum (SI units assumed — TODO confirm).
m = 10.0  # pendulum mass (kg)
g = 9.81  # gravitational acceleration (m/s^2)
R = 1.0  # pendulum rod length (m)
|
|
|
|
class PendulumDynamics3D(nn.Module):
    """ODE right-hand side for the state [theta, omega, alpha, tau_prev].

    The controller maps [theta, omega, alpha] to a torque; alpha relaxes
    toward the torque-driven target acceleration, and the fourth state slot
    tracks the applied torque so the trajectory exposes it to the loss.
    """

    def __init__(self, controller, m=10.0, g=9.81, R=1.0):
        """
        Args:
            controller: callable mapping a (batch, 3) state tensor to a
                (batch, 1) torque tensor.
            m: pendulum mass. Defaults preserve the original module-level
                constants, so existing callers are unaffected.
            g: gravitational acceleration.
            R: pendulum rod length.
        """
        super().__init__()
        self.controller = controller
        # Parameterized so the dynamics no longer silently depend on
        # module-level globals; defaults reproduce the original behavior.
        self.m = m
        self.g = g
        self.R = R

    def forward(self, t, state):
        """Evaluate the ODE right-hand side.

        Args:
            t: current time (unused here; required by odeint's signature).
            state: (batch_size, 4) => [theta, omega, alpha, tau_prev].

        Returns:
            (batch_size, 4) => [dtheta/dt, domega/dt, dalpha/dt, dtau/dt].
        """
        theta = state[:, 0]
        omega = state[:, 1]
        alpha = state[:, 2]
        tau_prev = state[:, 3]

        # Controller input: [theta, omega, alpha] -> torque, shape (batch,)
        input_3d = torch.stack([theta, omega, alpha], dim=1)
        tau = self.controller(input_3d).squeeze(-1)

        # Target angular acceleration from the pendulum equation plus the
        # applied torque. NOTE(review): the +(g/R)*sin(theta) sign suggests
        # theta is measured from the upright position — confirm.
        alpha_desired = (self.g / self.R) * torch.sin(theta) + tau / (self.m * self.R**2)

        dtheta = omega
        domega = alpha
        dalpha = alpha_desired - alpha  # first-order relaxation toward target
        dtau = tau - tau_prev           # tau_prev relaxes toward the current torque

        return torch.stack([dtheta, domega, dalpha, dtau], dim=1)  # (batch_size, 4)
|
|
|
|
# ----------------------------------------------------------------
|
|
# 3) Loss Function
|
|
# ----------------------------------------------------------------
|
|
def loss_fn(state_traj, t_span):
    """Trajectory loss with a quadratically time-weighted upright penalty.

    Args:
        state_traj: Tensor of shape (time_steps, batch_size, 4), the
            integrated states [theta, omega, alpha, tau_prev].
        t_span: Tensor of time steps, shape (time_steps,).

    Returns:
        total_loss: scalar tensor — currently only the theta term is active.
        (loss_theta, loss_omega, loss_torque, loss_final_theta): all four
            individual terms; the last three are diagnostic only and are NOT
            included in total_loss. (The original docstring claimed a
            3-tuple; the function has always returned four terms.)
    """
    theta = state_traj[:, :, 0]   # (time_steps, batch_size)
    omega = state_traj[:, :, 1]
    torque = state_traj[:, :, 3]  # tau_prev slot tracks the applied torque

    # Quadratic time weighting lambda * t**2: late-time deviation from
    # upright is punished much harder than early transients.
    lambda_factor = 0.5  # Increase for stronger late-time punishment
    time_weights = (lambda_factor * t_span**2).unsqueeze(1)  # (time_steps, 1)

    # (cos(theta) - 1)**2 vanishes at theta = 0 mod 2*pi, so any full
    # rotation of the upright position is an acceptable rest point.
    loss_theta = 1e2 * torch.mean(time_weights * (torch.cos(theta) - 1)**2)
    loss_omega = 1e-1 * torch.mean(omega**2)
    loss_torque = 1e-5 * torch.mean(torque**2)

    # Terminal penalty on the final angle (diagnostic only).
    final_theta = state_traj[-1, :, 0]  # (batch_size,)
    loss_final_theta = torch.mean(final_theta ** 2)

    # Only the time-weighted theta term drives training; the remaining
    # terms are returned for logging.
    total_loss = loss_theta
    return total_loss, (loss_theta, loss_omega, loss_torque, loss_final_theta)
|
|
|
|
|
|
# ----------------------------------------------------------------
|
|
# 4) Training Setup
|
|
# ----------------------------------------------------------------
|
|
# Device selection. NOTE(review): the original expression was
# torch.device("cpu" if torch.cuda.is_available() else "cpu") — a no-op
# ternary that always picked the CPU. Made explicit here; replace with
# "cuda" if torch.cuda.is_available() else "cpu" to actually train on GPU.
device = torch.device("cpu")

# Create the controller and pendulum dynamics model
controller = PendulumController3D().to(device)
pendulum_dynamics = PendulumDynamics3D(controller).to(device)

# Adam on the controller parameters only; the dynamics module itself has
# no trainable state.
optimizer = optim.Adam(controller.parameters(), lr=1e-2)

# Initial conditions: [theta, omega, alpha, tau_prev]
initial_conditions = [
    [0.1, 0.0, 0.0, 0.0],    # small perturbation
    [-0.5, 0.0, 0.0, 0.0],
    [6.28, 6.28, 0.0, 0.0],  # ~full rotation with fast spin
    [1.57, 0.5, 0.0, 0.0],   # ~quarter turn, slow spin
    [0.0, -6.28, 0.0, 0.0],
    [1.57, -6.28, 0.0, 0.0],
]

# Convert to torch tensor (batch_size, 4)
state_0 = torch.tensor(initial_conditions, dtype=torch.float32, device=device)

# Time grid: 10 seconds, 500 steps
t_span = torch.linspace(0, 10, 500, device=device)

num_epochs = 100_000
print_every = 25
|
|
|
|
# ----------------------------------------------------------------
|
|
# 5) Training Loop
|
|
# ----------------------------------------------------------------
|
|
for epoch in range(num_epochs):
    optimizer.zero_grad()

    # Roll the batch of initial states forward through the learned dynamics.
    # trajectory shape: (time_steps, batch_size, 4)
    trajectory = odeint(pendulum_dynamics, state_0, t_span, method='rk4')

    total_loss, (l_theta, l_omega, l_torque, l_final_theta) = loss_fn(trajectory, t_span)

    # Guard: skip the update entirely if the integration diverged.
    if torch.isnan(total_loss):
        print(f"NaN detected at epoch {epoch}. Skipping step.")
        optimizer.zero_grad()
        continue

    total_loss.backward()
    #torch.nn.utils.clip_grad_norm_(controller.parameters(), max_norm=1.0) # Fix NaNs
    optimizer.step()

    # Periodic progress report with every loss component.
    if epoch % print_every == 0:
        report = (
            f"Epoch {epoch:4d}/{num_epochs} | "
            f"Total: {total_loss.item():.6f} | "
            f"Theta: {l_theta.item():.6f} | "
            f"Omega: {l_omega.item():.6f} | "
            f"Torque: {l_torque.item():.6f} | "
            f"Final Theta: {l_final_theta.item():.6f}"
        )
        print(report)

torch.save(controller.state_dict(), "controller_cpu_clamped_quadratic_time_punish.pth")
print("Model saved as 'controller_cpu_clamped_quadratic_time_punish.pth'.")
|