# Inverted-Pendulum-Neural-Ne.../training/base_loss_learning_rate_sweep_training.py
#
# 132 lines
# 5.4 KiB
# Python

import torch
import torch.optim as optim
from torchdiffeq import odeint
import numpy as np
import os
import shutil
import csv
import inspect
import math
from PendulumController import PendulumController
from PendulumDynamics import PendulumDynamics
from initial_conditions import initial_conditions
from base_loss_functions import base_loss_functions, normalized_loss
# --- Device and starting checkpoint ---
# All tensors and modules in this script are placed on the CPU.
device = torch.device("cpu")
# Checkpoint whose weights every sweep configuration starts from.
base_controller_path = "/home/judson/Neural-Networks-in-GNC/inverted_pendulum/training/controller_base.pth"

# --- Training cases ---
# One row per case: (theta0, omega0, alpha0, desired_theta).
state_0 = torch.tensor(initial_conditions, dtype=torch.float32, device=device)

# --- Pendulum constants (passed to PendulumDynamics further down) ---
m = 10.0
g = 9.81
R = 1.0

# --- Time grid: t_points evenly spaced samples from t_start to t_end ---
t_start = 0
t_end = 10
t_points = 1000
t_span = torch.linspace(t_start, t_end, t_points, device=device)

# --- Output location; created up front so per-run folders can nest in it ---
base_output_dir = "base_loss_learning_rate_sweep"
os.makedirs(base_output_dir, exist_ok=True)

# --- Optimizer / schedule settings shared by every run ---
weight_decay = 0
num_epochs = 200

# Learning rates visited by the sweep, from largest to smallest.
learning_rates = [
    1, 0.9, 0.8, 0.7, 0.6,
    0.5, 0.4, 0.3, 0.25, 0.2,
    0.16, 0.125, 0.1, 0.08, 0.05,
    0.04, 0.02, 0.01, 0.005, 0.0025,
]
# Define a loss function wrapper for the base loss functions.
def make_loss_fn(base_loss_fn):
    """Build a scalar trajectory loss from a base loss function.

    Args:
        base_loss_fn: Callable invoked as ``base_loss_fn(theta, desired_theta)``.
            Whatever tensor it returns is reduced with ``torch.mean``.

    Returns:
        A function ``loss_fn(state_traj)`` that extracts theta and
        desired_theta from the trajectory and returns the mean base loss as
        a scalar tensor.
    """
    def loss_fn(state_traj):
        # The state vector sits on the last axis: component 0 is theta and
        # component 3 is desired_theta. Indexing with "..." keeps every
        # leading axis, so both batched and unbatched trajectories work.
        # For trajectories produced by torchdiffeq's odeint the leading axes
        # are time-major, i.e. the slices below are [t_points, batch_size].
        theta = state_traj[..., 0]
        desired_theta = state_traj[..., 3]
        return torch.mean(base_loss_fn(theta, desired_theta))

    return loss_fn
# Training loop for each base loss function and each learning rate.
# Every (loss function, learning rate) pair starts from the same base
# checkpoint and writes its own config file, CSV log and per-epoch checkpoints.
for name, (exponent, base_loss_fn) in base_loss_functions.items():
    for lr in learning_rates:
        # Per-run folder. ":.3f" rounds the rate to three decimals in the
        # directory name only; the exact value is written to the config file.
        output_dir = os.path.join(base_output_dir, f"{name}/lr_{lr:.3f}")
        controllers_dir = os.path.join(output_dir, "controllers")
        # Drop checkpoints left over from an earlier run of this configuration.
        if os.path.exists(controllers_dir):
            shutil.rmtree(controllers_dir)
        # Also creates output_dir, which the config and log files rely on.
        os.makedirs(controllers_dir, exist_ok=True)

        # Load controller, set up dynamics and optimizer.
        controller = PendulumController().to(device)
        # map_location keeps loading independent of the device the checkpoint
        # was saved from.
        controller.load_state_dict(torch.load(base_controller_path, map_location=device))
        pendulum_dynamics = PendulumDynamics(controller, m, R, g).to(device)
        optimizer = optim.Adam(controller.parameters(), lr=lr, weight_decay=weight_decay)
        loss_fn = make_loss_fn(base_loss_fn)

        # Configuration and log files.
        config_file = os.path.join(output_dir, "training_config.txt")
        log_file = os.path.join(output_dir, "training_log.csv")

        # Record everything needed to reproduce this run, including the source
        # code of the loss functions in use.
        with open(config_file, "w") as f:
            f.write(f"Base controller path: {base_controller_path}\n")
            f.write(f"Time Span: {t_start} to {t_end}, Points: {t_points}\n")
            f.write(f"Learning Rate: {lr}\n")
            f.write(f"Weight Decay: {weight_decay}\n")
            f.write(f"\nLoss Function Name: {name}\n")
            f.write(f"Loss Function Exponent: {exponent}\n")
            f.write("\nCurrent Loss Function (wrapper) Source Code:\n")
            f.write(inspect.getsource(loss_fn))
            f.write("\nSpecific Base Loss Function Source Code:\n")
            f.write(inspect.getsource(base_loss_fn))
            f.write("\nNormalized Loss Function Source Code:\n")
            f.write(inspect.getsource(normalized_loss))
            f.write("\nTraining Cases:\n")
            f.write("[theta0, omega0, alpha0, desired_theta]\n")
            for case in state_0.cpu().numpy():
                f.write(f"{case.tolist()}\n")

        # Start a fresh log with just the header; rows are appended per epoch.
        with open(log_file, "w", newline="") as csvfile:
            csv_writer = csv.writer(csvfile)
            csv_writer.writerow(["Epoch", "Loss"])

        # Variables to track loss changes and detect NaN.
        previous_losses = []
        nan_counter = 0
        stop_count = 5  # Number of epochs to check for unchanged loss or consecutive NaN.

        # Training loop.
        for epoch in range(num_epochs + 1):
            optimizer.zero_grad()
            state_traj = odeint(pendulum_dynamics, state_0, t_span, method='rk4')
            loss = loss_fn(state_traj)
            # Read the scalar once and reuse it for checks, printing and logging.
            loss_value = loss.item()

            if math.isnan(loss_value):
                nan_counter += 1
                if nan_counter >= stop_count:
                    print(f"Consecutive NaN detected for {stop_count} epochs at epoch {epoch}. Terminating training for {name} with learning rate {lr}.")
                    break
            else:
                nan_counter = 0  # Reset if no NaN detected.

            # NOTE(review): the update below also runs when the loss is NaN,
            # which presumably propagates NaN into the weights; confirm whether
            # the step should be skipped in that case.
            loss.backward()
            optimizer.step()

            # Save the model. The checkpoint holds the weights after this
            # epoch's step, while the logged loss was computed before it.
            model_file = os.path.join(controllers_dir, f"controller_{epoch}.pth")
            torch.save(controller.state_dict(), model_file)
            print(f"{model_file} saved with loss: {loss_value}")

            # Append immediately so the log survives an interrupted run.
            with open(log_file, "a", newline="") as csvfile:
                csv_writer = csv.writer(csvfile)
                csv_writer.writerow([epoch, loss_value])

            # Early stopping if loss does not change significantly: the current
            # loss must be within 1e-6 of each of the last stop_count losses.
            if len(previous_losses) >= stop_count:
                if all(abs(prev_loss - loss_value) < 1e-6 for prev_loss in previous_losses[-stop_count:]):
                    print(f"Loss unchanged for {stop_count} epochs at epoch {epoch}. Terminating training for {name} with learning rate {lr}.")
                    break
            previous_losses.append(loss_value)

print("Training complete. Models and logs are saved under respective directories for each loss function.")