Plotted controller max-normalized across epochs. Also added average-normalized training.

This commit is contained in:
judsonupchurch 2025-02-18 00:40:29 +00:00
parent 071669696b
commit 28c5d14fe8
96 changed files with 18284 additions and 25091 deletions

Binary image files changed (4 added, 2 removed); previews not shown.

View File

@ -9,6 +9,13 @@ from multiprocessing import Pool, cpu_count
# Define PendulumController class
from PendulumController import PendulumController
# Constants
g = 9.81 # Gravity
R = 1.0 # Length of the pendulum
m = 10.0 # Mass
dt = 0.02 # Time step
num_steps = 500 # Simulation time steps
# ODE solver (RK4 method)
def pendulum_ode_step(state, dt, desired_theta, controller):
theta, omega, alpha = state
@ -44,41 +51,9 @@ def pendulum_ode_step(state, dt, desired_theta, controller):
new_state = state + (k1 + 2*k2 + 2*k3 + k4) / 6.0
return new_state
# Constants
g = 9.81 # Gravity
R = 1.0 # Length of the pendulum
m = 10.0 # Mass
dt = 0.02 # Time step
num_steps = 500 # Simulation time steps
# Directory containing controller files
loss_function = "quadratic"
controller_dir = f"/home/judson/Neural-Networks-in-GNC/inverted_pendulum/training/normalized/training/{loss_function}/controllers"
#controller_dir = f"C:/Users/Judson/Desktop/New Gitea/Neural-Networks-in-GNC/inverted_pendulum/training/{loss_function}/controllers"
controller_files = sorted([f for f in os.listdir(controller_dir) if f.startswith("controller_") and f.endswith(".pth")])
# Sorting controllers by epoch
controller_epochs = [int(f.split('_')[1].split('.')[0]) for f in controller_files]
sorted_controllers = [x for _, x in sorted(zip(controller_epochs, controller_files))]
# **Epoch Range Selection**
epoch_range = (0, 1000) # Set your desired range (e.g., (0, 5000) or (0, 100))
filtered_controllers = [
f for f in sorted_controllers
if epoch_range[0] <= int(f.split('_')[1].split('.')[0]) <= epoch_range[1]
]
# **Granularity Control: Select every Nth controller**
N = 1 # Change this value to adjust granularity (e.g., every 5th controller)
selected_controllers = filtered_controllers[::N] # Take every Nth controller within the range
# Initial condition
# theta0, omega0, alpha0, desired_theta = (-np.pi, -2*np.pi, 0.0, -1.3*np.pi) # Example initial condition
theta0, omega0, alpha0, desired_theta = (-np.pi, 0.0, 0.0, 0.0) # Example initial condition
# Parallel function must return epoch explicitly
def run_simulation(controller_file):
def run_simulation(params):
controller_file, initial_condition = params
theta0, omega0, alpha0, desired_theta = initial_condition
epoch = int(controller_file.split('_')[1].split('.')[0])
# Load controller
@ -96,54 +71,100 @@ def run_simulation(controller_file):
return epoch, theta_vals # Return epoch with data
# Parallel processing
# Named initial conditions
initial_conditions = {
"small_perturbation": (0.1*np.pi, 0.0, 0.0, 0.0),
"large_perturbation": (-np.pi, 0.0, 0.0, 0),
"overshoot_vertical_test": (-0.1*np.pi, 2*np.pi, 0.0, 0.0),
"overshoot_angle_test": (0.2*np.pi, 2*np.pi, 0.0, 0.3*np.pi),
"extreme_perturbation": (4*np.pi, 0.0, 0.0, 0),
}
# Loss functions to iterate over
loss_functions = ["constant", "linear", "quadratic", "exponential", "inverse", "inverse_squared"]
epoch_start = 0 # Start of the epoch range
epoch_end = 500 # End of the epoch range
epoch_step = 5 # Interval between epochs
if __name__ == "__main__":
num_workers = min(cpu_count(), 16) # Limit to 16 workers max
print(f"Using {num_workers} parallel workers...")
with Pool(processes=num_workers) as pool:
results = pool.map(run_simulation, selected_controllers)
for condition_name, initial_condition in initial_conditions.items():
full_path = f"/home/judson/Neural-Networks-in-GNC/inverted_pendulum/analysis/max_normalized/{condition_name}"
os.makedirs(full_path, exist_ok=True) # Create directory if it does not exist
for loss_function in loss_functions:
controller_dir = f"/home/judson/Neural-Networks-in-GNC/inverted_pendulum/training/normalized/max_normalized/{loss_function}/controllers"
controller_files = sorted([f for f in os.listdir(controller_dir) if f.startswith("controller_") and f.endswith(".pth")])
# Extract epoch numbers and filter based on the defined range and interval
epoch_numbers = [int(f.split('_')[1].split('.')[0]) for f in controller_files]
selected_epochs = [e for e in epoch_numbers if epoch_start <= e <= epoch_end and (e - epoch_start) % epoch_step == 0]
# Sort results by epoch to ensure correct order
results.sort(key=lambda x: x[0])
epochs, theta_over_epochs = zip(*results) # Unzip sorted results
# Filter the controller files to include only those within the selected epochs
selected_controllers = [f for f in controller_files if int(f.split('_')[1].split('.')[0]) in selected_epochs]
selected_controllers.sort(key=lambda f: int(f.split('_')[1].split('.')[0]))
# Convert results to NumPy arrays
theta_over_epochs = np.array(theta_over_epochs)
# Setup parallel processing
num_workers = min(cpu_count(), 16) # Limit to 16 workers max
print(f"Using {num_workers} parallel workers for {loss_function} with initial condition {condition_name}...")
# Create 3D line plot
fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(111, projection='3d')
with Pool(processes=num_workers) as pool:
params = [(controller_file, initial_condition) for controller_file in selected_controllers]
results = pool.map(run_simulation, params)
time_steps = np.arange(num_steps) * dt # time values (Y-axis)
results.sort(key=lambda x: x[0])
epochs, theta_over_epochs = zip(*results)
# Plot each controller as a separate line
for epoch, theta_vals in zip(epochs, theta_over_epochs):
ax.plot(
[epoch] * len(time_steps), # X-axis (epoch stays constant for each line)
time_steps, # Y-axis (time)
theta_vals, # Z-axis (theta evolution)
label=f"Epoch {epoch}" if epoch % (N * 10) == 0 else "", # Label some lines for clarity
)
fig = plt.figure(figsize=(7, 5))
ax = fig.add_subplot(111, projection='3d')
time_steps = np.arange(num_steps) * dt
# Labels
ax.set_xlabel("Epoch")
ax.set_ylabel("Time (s)")
ax.set_zlabel("Theta (rad)")
ax.set_title(f"Pendulum Angle Evolution for {loss_function}")
# Plot the epochs in reverse order because we view it where epoch 0 is in front
for epoch, theta_vals in reversed(list(zip(epochs, theta_over_epochs))):
ax.plot([epoch] * len(time_steps), time_steps, theta_vals)
# Add a horizontal line at desired_theta across all epochs and time steps
epochs_array = np.array([epoch for epoch, _ in zip(epochs, theta_over_epochs)])
ax.plot(
epochs_array, # X-axis spanning all epochs
[time_steps.max()] * len(epochs_array), # Y-axis at the maximum time step
[desired_theta] * len(epochs_array), # Constant Z-axis value of desired_theta
color='r', linestyle='--', linewidth=2, label='Desired Theta at End Time'
)
# Improve visibility
ax.view_init(elev=20, azim=-135) # Adjust 3D perspective
# Add a horizontal line at desired_theta across all epochs and time steps
epochs_array = np.array([epoch for epoch, _ in zip(epochs, theta_over_epochs)])
desired_theta = initial_condition[-1]
ax.plot(
epochs_array, # X-axis spanning all epochs
[time_steps.max()] * len(epochs_array), # Y-axis at the maximum time step
[desired_theta] * len(epochs_array), # Constant Z-axis value of desired_theta
color='r', linestyle='--', linewidth=2, label='Desired Theta at End Time'
)
plt.savefig(f"{loss_function}.png", dpi=600)
#plt.show()
print(f"Saved plot as '{loss_function}.png'.")
ax.set_xlabel("Epoch")
ax.set_ylabel("Time (s)")
ax.set_zlabel("Theta (rad)")
condition_text = f"IC_{'_'.join(map(lambda x: str(round(x, 2)), initial_condition))}"
ax.set_title(f"Pendulum Angle Evolution for {loss_function} and {condition_text}")
# Calculate the range of theta values across all epochs
theta_values = np.concatenate(theta_over_epochs)
theta_min = np.min(theta_values)
theta_max = np.max(theta_values)
# Limit the z-axis to at most pi on either side of desired_theta, shrinking to the actual data range when it is narrower
desired_range_min = max(theta_min, desired_theta - np.pi)
desired_range_max = min(theta_max, desired_theta + np.pi)
ax.set_zlim(desired_range_min, desired_range_max)
ax.view_init(elev=20, azim=-135) # Adjust 3D perspective
plot_filename = os.path.join(full_path, f"{loss_function}.png")
plt.savefig(plot_filename, dpi=300)
plt.close()
print(f"Saved plot as '{plot_filename}'.")

View File

@ -0,0 +1,149 @@
import os
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from multiprocessing import Pool, cpu_count
# Define PendulumController class
from PendulumController import PendulumController
# ODE solver (RK4 method)
def pendulum_ode_step(state, dt, desired_theta, controller):
theta, omega, alpha = state
def compute_torque(th, om, al):
inp = torch.tensor([[th, om, al, desired_theta]], dtype=torch.float32)
with torch.no_grad():
torque = controller(inp)
torque = torch.clamp(torque, -250, 250)
return float(torque)
def derivatives(state, torque):
th, om, al = state
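# Pendulum equation of motion: angular acceleration a = (g/R)*sin(th) + torque/(m*R^2);
# the stored alpha component is left unchanged here (its derivative is returned as 0)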
a = (g / R) * np.sin(th) + torque / (m * R**2)
return np.array([om, a, 0]) # dtheta, domega, dalpha
# Compute RK4 steps
torque1 = compute_torque(theta, omega, alpha)
k1 = dt * derivatives(state, torque1)
state_k2 = state + 0.5 * k1
torque2 = compute_torque(state_k2[0], state_k2[1], state_k2[2])
k2 = dt * derivatives(state_k2, torque2)
state_k3 = state + 0.5 * k2
torque3 = compute_torque(state_k3[0], state_k3[1], state_k3[2])
k3 = dt * derivatives(state_k3, torque3)
state_k4 = state + k3
torque4 = compute_torque(state_k4[0], state_k4[1], state_k4[2])
k4 = dt * derivatives(state_k4, torque4)
new_state = state + (k1 + 2*k2 + 2*k3 + k4) / 6.0
return new_state
# Constants
g = 9.81 # Gravity
R = 1.0 # Length of the pendulum
m = 10.0 # Mass
dt = 0.02 # Time step
num_steps = 500 # Simulation time steps
# Directory containing controller files
loss_function = "quadratic"
controller_dir = f"/home/judson/Neural-Networks-in-GNC/inverted_pendulum/training/normalized/training/{loss_function}/controllers"
#controller_dir = f"C:/Users/Judson/Desktop/New Gitea/Neural-Networks-in-GNC/inverted_pendulum/training/{loss_function}/controllers"
controller_files = sorted([f for f in os.listdir(controller_dir) if f.startswith("controller_") and f.endswith(".pth")])
# Sorting controllers by epoch
controller_epochs = [int(f.split('_')[1].split('.')[0]) for f in controller_files]
sorted_controllers = [x for _, x in sorted(zip(controller_epochs, controller_files))]
# **Epoch Range Selection**
epoch_range = (0, 1000) # Set your desired range (e.g., (0, 5000) or (0, 100))
filtered_controllers = [
f for f in sorted_controllers
if epoch_range[0] <= int(f.split('_')[1].split('.')[0]) <= epoch_range[1]
]
# **Granularity Control: Select every Nth controller**
N = 1 # Change this value to adjust granularity (e.g., every 5th controller)
selected_controllers = filtered_controllers[::N] # Take every Nth controller within the range
# Initial condition
# theta0, omega0, alpha0, desired_theta = (-np.pi, -2*np.pi, 0.0, -1.3*np.pi) # Example initial condition
theta0, omega0, alpha0, desired_theta = (-np.pi, 0.0, 0.0, 0.0) # Example initial condition
# Parallel function must return epoch explicitly
def run_simulation(controller_file):
epoch = int(controller_file.split('_')[1].split('.')[0])
# Load controller
controller = PendulumController()
controller.load_state_dict(torch.load(os.path.join(controller_dir, controller_file)))
controller.eval()
# Run simulation
state = np.array([theta0, omega0, alpha0])
theta_vals = []
for _ in range(num_steps):
theta_vals.append(state[0])
state = pendulum_ode_step(state, dt, desired_theta, controller)
return epoch, theta_vals # Return epoch with data
# Parallel processing
if __name__ == "__main__":
num_workers = min(cpu_count(), 16) # Limit to 16 workers max
print(f"Using {num_workers} parallel workers...")
with Pool(processes=num_workers) as pool:
results = pool.map(run_simulation, selected_controllers)
# Sort results by epoch to ensure correct order
results.sort(key=lambda x: x[0])
epochs, theta_over_epochs = zip(*results) # Unzip sorted results
# Convert results to NumPy arrays
theta_over_epochs = np.array(theta_over_epochs)
# Create 3D line plot
fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(111, projection='3d')
time_steps = np.arange(num_steps) * dt # time values (Y-axis)
# Plot each controller as a separate line
for epoch, theta_vals in zip(epochs, theta_over_epochs):
ax.plot(
[epoch] * len(time_steps), # X-axis (epoch stays constant for each line)
time_steps, # Y-axis (time)
theta_vals, # Z-axis (theta evolution)
label=f"Epoch {epoch}" if epoch % (N * 10) == 0 else "", # Label some lines for clarity
)
# Labels
ax.set_xlabel("Epoch")
ax.set_ylabel("Time (s)")
ax.set_zlabel("Theta (rad)")
ax.set_title(f"Pendulum Angle Evolution for {loss_function}")
# Add a horizontal line at desired_theta across all epochs and time steps
epochs_array = np.array([epoch for epoch, _ in zip(epochs, theta_over_epochs)])
ax.plot(
epochs_array, # X-axis spanning all epochs
[time_steps.max()] * len(epochs_array), # Y-axis at the maximum time step
[desired_theta] * len(epochs_array), # Constant Z-axis value of desired_theta
color='r', linestyle='--', linewidth=2, label='Desired Theta at End Time'
)
# Improve visibility
ax.view_init(elev=20, azim=-135) # Adjust 3D perspective
plt.savefig(f"{loss_function}.png", dpi=600)
#plt.show()
print(f"Saved plot as '{loss_function}.png'.")

Binary image files changed (30 added, 5 removed); previews not shown.

nohup.out: 25015 changed lines; file diff suppressed because it is too large.

View File

@ -0,0 +1,17 @@
import torch
import torch.nn as nn
class PendulumController(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Sequential(
nn.Linear(4, 64),
nn.ReLU(),
nn.Linear(64, 64),
nn.ReLU(),
nn.Linear(64, 1)
)
def forward(self, x):
raw_torque = self.net(x)
return torch.clamp(raw_torque, -250, 250)
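A minimal usage sketch (assumption: training checkpoints are plain state_dicts saved as controller_{epoch}.pth, the naming the analysis scripts above glob for):

import torch
from PendulumController import PendulumController

controller = PendulumController()
torch.save(controller.state_dict(), "controller_0.pth")  # what the training loop is assumed to write

loaded = PendulumController()
loaded.load_state_dict(torch.load("controller_0.pth"))
loaded.eval()
with torch.no_grad():
    # input layout matches the 4-wide first layer: [theta, omega, alpha, desired_theta]
    torque = loaded(torch.tensor([[0.1, 0.0, 0.0, 0.0]]))
print(float(torque))  # forward() already clamps the torque to [-250, 250]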

View File

@ -0,0 +1,26 @@
import torch
import torch.nn as nn
class PendulumDynamics(nn.Module):
def __init__(self, controller, m:'float'=1, R:'float'=1, g:'float'=9.81):
super().__init__()
self.controller = controller
self.m: 'float' = m
self.R: 'float' = R
self.g: 'float' = g
def forward(self, t, state):
# Get the current values from the state
theta, omega, alpha, desired_theta = state[:, 0], state[:, 1], state[:, 2], state[:, 3]
# Make the input stack for the controller
input = torch.stack([theta, omega, alpha, desired_theta], dim=1)
# Get the torque (the output of the neural network)
tau = self.controller(input).squeeze(-1)
# Relax alpha
alpha_desired = (self.g / self.R) * torch.sin(theta) + tau / (self.m * self.R**2)
dalpha = alpha_desired - alpha
return torch.stack([omega, alpha, dalpha, torch.zeros_like(desired_theta)], dim=1)
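A minimal rollout sketch (assumption: the dynamics are integrated with torchdiffeq as in the training script further below; odeint returns a trajectory of shape [t_points, batch, 4]):

import torch
from torchdiffeq import odeint
from PendulumController import PendulumController
from PendulumDynamics import PendulumDynamics

controller = PendulumController()
dynamics = PendulumDynamics(controller, m=10.0, R=1.0, g=9.81)
state_0 = torch.tensor([[0.5, 0.0, 0.0, 0.0]])  # [theta, omega, alpha, desired_theta]
t_span = torch.linspace(0.0, 10.0, 1000)
state_traj = odeint(dynamics, state_0, t_span)  # shape [1000, 1, 4]
theta = state_traj[:, :, 0]                     # pendulum angle over time for each case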

View File

@ -0,0 +1,41 @@
Random Seed: 4529
Time Span: 0 to 10, Points: 1000
Learning Rate: 0.1
Weight Decay: 0.0001
Loss Function:
def loss_fn(state_traj, t_span):
theta = state_traj[:, :, 0] # Size: [batch_size, t_points]
desired_theta = state_traj[:, :, 3] # Size: [batch_size, t_points]
weights = weight_fn(t_span) # Initially Size: [t_points]
# Reshape or expand weights to match theta dimensions
weights = weights.view(-1, 1) # Now Size: [batch_size, t_points]
# Calculate the weighted loss
return torch.mean(weights * (theta - desired_theta) ** 2)
Training Cases:
[theta0, omega0, alpha0, desired_theta]
[0.5235987901687622, 0.0, 0.0, 0.0]
[-0.5235987901687622, 0.0, 0.0, 0.0]
[2.094395160675049, 0.0, 0.0, 0.0]
[-2.094395160675049, 0.0, 0.0, 0.0]
[0.0, 1.0471975803375244, 0.0, 0.0]
[0.0, -1.0471975803375244, 0.0, 0.0]
[0.0, 6.2831854820251465, 0.0, 0.0]
[0.0, -6.2831854820251465, 0.0, 0.0]
[0.0, 0.0, 0.0, 6.2831854820251465]
[0.0, 0.0, 0.0, -6.2831854820251465]
[0.0, 0.0, 0.0, 1.5707963705062866]
[0.0, 0.0, 0.0, -1.5707963705062866]
[0.0, 0.0, 0.0, 1.0471975803375244]
[0.0, 0.0, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 0.0]
[-0.7853981852531433, -3.1415927410125732, 0.0, 0.0]
[1.5707963705062866, -3.1415927410125732, 0.0, 1.0471975803375244]
[-1.5707963705062866, 3.1415927410125732, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 6.2831854820251465]
[-0.7853981852531433, -3.1415927410125732, 0.0, 6.2831854820251465]
[1.5707963705062866, -3.1415927410125732, 0.0, 12.566370964050293]
[-1.5707963705062866, 3.1415927410125732, 0.0, -12.566370964050293]

File diff suppressed because it is too large.

View File

@ -0,0 +1,41 @@
Random Seed: 4529
Time Span: 0 to 10, Points: 1000
Learning Rate: 0.1
Weight Decay: 0.0001
Loss Function:
def loss_fn(state_traj, t_span):
theta = state_traj[:, :, 0] # Size: [batch_size, t_points]
desired_theta = state_traj[:, :, 3] # Size: [batch_size, t_points]
weights = weight_fn(t_span) # Initially Size: [t_points]
# Reshape or expand weights to match theta dimensions
weights = weights.view(-1, 1) # Now Size: [batch_size, t_points]
# Calculate the weighted loss
return torch.mean(weights * (theta - desired_theta) ** 2)
Training Cases:
[theta0, omega0, alpha0, desired_theta]
[0.5235987901687622, 0.0, 0.0, 0.0]
[-0.5235987901687622, 0.0, 0.0, 0.0]
[2.094395160675049, 0.0, 0.0, 0.0]
[-2.094395160675049, 0.0, 0.0, 0.0]
[0.0, 1.0471975803375244, 0.0, 0.0]
[0.0, -1.0471975803375244, 0.0, 0.0]
[0.0, 6.2831854820251465, 0.0, 0.0]
[0.0, -6.2831854820251465, 0.0, 0.0]
[0.0, 0.0, 0.0, 6.2831854820251465]
[0.0, 0.0, 0.0, -6.2831854820251465]
[0.0, 0.0, 0.0, 1.5707963705062866]
[0.0, 0.0, 0.0, -1.5707963705062866]
[0.0, 0.0, 0.0, 1.0471975803375244]
[0.0, 0.0, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 0.0]
[-0.7853981852531433, -3.1415927410125732, 0.0, 0.0]
[1.5707963705062866, -3.1415927410125732, 0.0, 1.0471975803375244]
[-1.5707963705062866, 3.1415927410125732, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 6.2831854820251465]
[-0.7853981852531433, -3.1415927410125732, 0.0, 6.2831854820251465]
[1.5707963705062866, -3.1415927410125732, 0.0, 12.566370964050293]
[-1.5707963705062866, 3.1415927410125732, 0.0, -12.566370964050293]

File diff suppressed because it is too large.

View File

@ -0,0 +1,41 @@
Random Seed: 4529
Time Span: 0 to 10, Points: 1000
Learning Rate: 0.1
Weight Decay: 0.0001
Loss Function:
def loss_fn(state_traj, t_span):
theta = state_traj[:, :, 0] # Size: [batch_size, t_points]
desired_theta = state_traj[:, :, 3] # Size: [batch_size, t_points]
weights = weight_fn(t_span) # Initially Size: [t_points]
# Reshape or expand weights to match theta dimensions
weights = weights.view(-1, 1) # Now Size: [batch_size, t_points]
# Calculate the weighted loss
return torch.mean(weights * (theta - desired_theta) ** 2)
Training Cases:
[theta0, omega0, alpha0, desired_theta]
[0.5235987901687622, 0.0, 0.0, 0.0]
[-0.5235987901687622, 0.0, 0.0, 0.0]
[2.094395160675049, 0.0, 0.0, 0.0]
[-2.094395160675049, 0.0, 0.0, 0.0]
[0.0, 1.0471975803375244, 0.0, 0.0]
[0.0, -1.0471975803375244, 0.0, 0.0]
[0.0, 6.2831854820251465, 0.0, 0.0]
[0.0, -6.2831854820251465, 0.0, 0.0]
[0.0, 0.0, 0.0, 6.2831854820251465]
[0.0, 0.0, 0.0, -6.2831854820251465]
[0.0, 0.0, 0.0, 1.5707963705062866]
[0.0, 0.0, 0.0, -1.5707963705062866]
[0.0, 0.0, 0.0, 1.0471975803375244]
[0.0, 0.0, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 0.0]
[-0.7853981852531433, -3.1415927410125732, 0.0, 0.0]
[1.5707963705062866, -3.1415927410125732, 0.0, 1.0471975803375244]
[-1.5707963705062866, 3.1415927410125732, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 6.2831854820251465]
[-0.7853981852531433, -3.1415927410125732, 0.0, 6.2831854820251465]
[1.5707963705062866, -3.1415927410125732, 0.0, 12.566370964050293]
[-1.5707963705062866, 3.1415927410125732, 0.0, -12.566370964050293]

File diff suppressed because it is too large.

View File

@ -0,0 +1,41 @@
Random Seed: 4529
Time Span: 0 to 10, Points: 1000
Learning Rate: 0.1
Weight Decay: 0.0001
Loss Function:
def loss_fn(state_traj, t_span):
theta = state_traj[:, :, 0] # Size: [batch_size, t_points]
desired_theta = state_traj[:, :, 3] # Size: [batch_size, t_points]
weights = weight_fn(t_span) # Initially Size: [t_points]
# Reshape or expand weights to match theta dimensions
weights = weights.view(-1, 1) # Now Size: [batch_size, t_points]
# Calculate the weighted loss
return torch.mean(weights * (theta - desired_theta) ** 2)
Training Cases:
[theta0, omega0, alpha0, desired_theta]
[0.5235987901687622, 0.0, 0.0, 0.0]
[-0.5235987901687622, 0.0, 0.0, 0.0]
[2.094395160675049, 0.0, 0.0, 0.0]
[-2.094395160675049, 0.0, 0.0, 0.0]
[0.0, 1.0471975803375244, 0.0, 0.0]
[0.0, -1.0471975803375244, 0.0, 0.0]
[0.0, 6.2831854820251465, 0.0, 0.0]
[0.0, -6.2831854820251465, 0.0, 0.0]
[0.0, 0.0, 0.0, 6.2831854820251465]
[0.0, 0.0, 0.0, -6.2831854820251465]
[0.0, 0.0, 0.0, 1.5707963705062866]
[0.0, 0.0, 0.0, -1.5707963705062866]
[0.0, 0.0, 0.0, 1.0471975803375244]
[0.0, 0.0, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 0.0]
[-0.7853981852531433, -3.1415927410125732, 0.0, 0.0]
[1.5707963705062866, -3.1415927410125732, 0.0, 1.0471975803375244]
[-1.5707963705062866, 3.1415927410125732, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 6.2831854820251465]
[-0.7853981852531433, -3.1415927410125732, 0.0, 6.2831854820251465]
[1.5707963705062866, -3.1415927410125732, 0.0, 12.566370964050293]
[-1.5707963705062866, 3.1415927410125732, 0.0, -12.566370964050293]

View File

@ -0,0 +1,608 @@
Epoch,Loss
0,733.2171630859375
1,747.2186279296875
2,173.32571411132812
3,113.02052307128906
4,52.3023796081543
5,29.964460372924805
6,22.173477172851562
7,17.20752716064453
8,14.393138885498047
9,13.495450019836426
10,12.72439193725586
11,12.004809379577637
12,11.283649444580078
13,10.60710620880127
14,9.974536895751953
15,9.39864730834961
16,8.96867847442627
17,8.593585968017578
18,8.431589126586914
19,8.36359977722168
20,8.257978439331055
21,8.131260871887207
22,7.987368583679199
23,7.828018665313721
24,7.654465198516846
25,7.468804836273193
26,7.282363414764404
27,7.1038055419921875
28,6.926774024963379
29,6.756581783294678
30,6.605012893676758
31,6.480014324188232
32,6.387184143066406
33,6.323246955871582
34,6.280386924743652
35,6.2463226318359375
36,6.210468769073486
37,6.165651798248291
38,6.10935640335083
39,6.044148921966553
40,5.975587844848633
41,5.909937381744385
42,5.85064697265625
43,5.797428131103516
44,5.749576091766357
45,5.7068305015563965
46,5.668132305145264
47,5.631911754608154
48,5.5969085693359375
49,5.562211513519287
50,5.527288436889648
51,5.491824626922607
52,5.455735206604004
53,5.4193949699401855
54,5.383573055267334
55,5.350064754486084
56,5.319507598876953
57,5.289746284484863
58,5.259618759155273
59,5.228979110717773
60,5.198094844818115
61,5.167768955230713
62,5.1394476890563965
63,5.113860607147217
64,5.0905303955078125
65,5.070118427276611
66,5.05070686340332
67,5.031530857086182
68,5.012385845184326
69,4.993386268615723
70,4.974751949310303
71,4.956557273864746
72,4.939281940460205
73,4.923177719116211
74,4.9080328941345215
75,4.893461227416992
76,4.879059791564941
77,4.86448860168457
78,4.8493499755859375
79,4.833917617797852
80,4.818392753601074
81,4.80300760269165
82,4.788031101226807
83,4.773463249206543
84,4.759511947631836
85,4.746511459350586
86,4.734636306762695
87,4.72379207611084
88,4.713471412658691
89,4.7037153244018555
90,4.694935321807861
91,4.687223434448242
92,4.680534362792969
93,4.674868106842041
94,4.670145511627197
95,4.666247367858887
96,4.6630682945251465
97,4.660345554351807
98,4.6577630043029785
99,4.655178070068359
100,4.65253210067749
101,4.649806499481201
102,4.64708137512207
103,4.644402980804443
104,4.641839504241943
105,4.639404296875
106,4.637064456939697
107,4.634809970855713
108,4.632608890533447
109,4.63045597076416
110,4.6283416748046875
111,4.626272201538086
112,4.624259948730469
113,4.622305870056152
114,4.620428085327148
115,4.618651866912842
116,4.616915702819824
117,4.615218162536621
118,4.613545894622803
119,4.6118950843811035
120,4.610260963439941
121,4.608641147613525
122,4.607039928436279
123,4.6054606437683105
124,4.603922367095947
125,4.602435111999512
126,4.601007461547852
127,4.599658966064453
128,4.598391532897949
129,4.59716796875
130,4.595986843109131
131,4.594844341278076
132,4.593727111816406
133,4.592639446258545
134,4.591578960418701
135,4.590541839599609
136,4.589527130126953
137,4.588528633117676
138,4.587544918060303
139,4.586572647094727
140,4.585610389709473
141,4.58465576171875
142,4.583714008331299
143,4.582784652709961
144,4.58186149597168
145,4.580949783325195
146,4.58004903793335
147,4.579157829284668
148,4.578279972076416
149,4.5774149894714355
150,4.5765604972839355
151,4.575720310211182
152,4.574892044067383
153,4.574074745178223
154,4.573271751403809
155,4.57248067855835
156,4.571700096130371
157,4.570932388305664
158,4.570174694061279
159,4.569428443908691
160,4.568695545196533
161,4.567972183227539
162,4.567262649536133
163,4.566561698913574
164,4.565873622894287
165,4.565197944641113
166,4.564536094665527
167,4.563884258270264
168,4.563241958618164
169,4.562610626220703
170,4.561987400054932
171,4.561375141143799
172,4.56077241897583
173,4.560176849365234
174,4.559585094451904
175,4.558999061584473
176,4.558421611785889
177,4.557851791381836
178,4.55728816986084
179,4.556730270385742
180,4.556179523468018
181,4.555635452270508
182,4.555098533630371
183,4.554568767547607
184,4.554043769836426
185,4.553525447845459
186,4.553014755249023
187,4.552509307861328
188,4.552009105682373
189,4.551515102386475
190,4.551027297973633
191,4.550544738769531
192,4.55006742477417
193,4.549595832824707
194,4.549131393432617
195,4.5486741065979
196,4.548222064971924
197,4.547779083251953
198,4.547341823577881
199,4.546910762786865
200,4.546485424041748
201,4.546065330505371
202,4.545650959014893
203,4.545242786407471
204,4.544842720031738
205,4.5444488525390625
206,4.544060230255127
207,4.54367733001709
208,4.543309211730957
209,4.5429511070251465
210,4.542597770690918
211,4.542244911193848
212,4.541895389556885
213,4.541547775268555
214,4.541203498840332
215,4.540863513946533
216,4.540529727935791
217,4.5402045249938965
218,4.539884090423584
219,4.539568901062012
220,4.539254665374756
221,4.538944721221924
222,4.538638591766357
223,4.538336277008057
224,4.538040637969971
225,4.537751197814941
226,4.537464618682861
227,4.537179470062256
228,4.536898612976074
229,4.536619186401367
230,4.536343574523926
231,4.536070823669434
232,4.535802364349365
233,4.535538196563721
234,4.535276889801025
235,4.535017967224121
236,4.534761428833008
237,4.534510135650635
238,4.5342607498168945
239,4.534013271331787
240,4.533768653869629
241,4.533526420593262
242,4.533287525177002
243,4.533048629760742
244,4.532813549041748
245,4.532580375671387
246,4.532350063323975
247,4.5321245193481445
248,4.531898021697998
249,4.531674861907959
250,4.531455039978027
251,4.531236171722412
252,4.531019687652588
253,4.530806064605713
254,4.530594348907471
255,4.530384540557861
256,4.530177593231201
257,4.529972076416016
258,4.529767990112305
259,4.529566287994385
260,4.529365539550781
261,4.529166221618652
262,4.5289692878723145
263,4.528773784637451
264,4.5285797119140625
265,4.52838659286499
266,4.528195858001709
267,4.528006553649902
268,4.52781867980957
269,4.527633190155029
270,4.527448654174805
271,4.527266502380371
272,4.527083396911621
273,4.52690315246582
274,4.526723861694336
275,4.526546001434326
276,4.526369571685791
277,4.526192665100098
278,4.526019096374512
279,4.525848388671875
280,4.525679588317871
281,4.5255126953125
282,4.5253472328186035
283,4.525182723999023
284,4.525018692016602
285,4.524855613708496
286,4.524693489074707
287,4.524532794952393
288,4.524373531341553
289,4.524214267730713
290,4.5240559577941895
291,4.523898601531982
292,4.523742198944092
293,4.523586273193359
294,4.523431301116943
295,4.52327823638916
296,4.523126602172852
297,4.522974491119385
298,4.522823333740234
299,4.5226731300354
300,4.522523403167725
301,4.522375106811523
302,4.522228240966797
303,4.522083282470703
304,4.521938800811768
305,4.521796703338623
306,4.521655559539795
307,4.521514415740967
308,4.521374225616455
309,4.521234035491943
310,4.521094799041748
311,4.520956993103027
312,4.520818710327148
313,4.520682334899902
314,4.520545959472656
315,4.520410060882568
316,4.520275115966797
317,4.5201416015625
318,4.520009517669678
319,4.519878387451172
320,4.519747734069824
321,4.519618034362793
322,4.51948881149292
323,4.519359588623047
324,4.519231796264648
325,4.519104957580566
326,4.518978118896484
327,4.518850803375244
328,4.5187249183654785
329,4.518599510192871
330,4.518474578857422
331,4.518350601196289
332,4.518225193023682
333,4.518101215362549
334,4.517977237701416
335,4.517852783203125
336,4.51772928237915
337,4.517606258392334
338,4.517481803894043
339,4.517359733581543
340,4.517236709594727
341,4.517114162445068
342,4.51699161529541
343,4.516868591308594
344,4.516744613647461
345,4.516623020172119
346,4.516499996185303
347,4.5163774490356445
348,4.516254901885986
349,4.5161333084106445
350,4.5160112380981445
351,4.5158891677856445
352,4.515769004821777
353,4.5156474113464355
354,4.51552677154541
355,4.515407085418701
356,4.515288829803467
357,4.515170574188232
358,4.515053749084473
359,4.5149383544921875
360,4.514822483062744
361,4.514708042144775
362,4.5145955085754395
363,4.514482498168945
364,4.514371395111084
365,4.5142598152160645
366,4.5141496658325195
367,4.514039993286133
368,4.513929843902588
369,4.513820171356201
370,4.513709545135498
371,4.513599872589111
372,4.51348876953125
373,4.5133771896362305
374,4.513265132904053
375,4.513152122497559
376,4.51303768157959
377,4.512922763824463
378,4.512807846069336
379,4.512691974639893
380,4.512577056884766
381,4.512462139129639
382,4.512347221374512
383,4.512232780456543
384,4.512118339538574
385,4.5120038986206055
386,4.511888027191162
387,4.511774063110352
388,4.511661052703857
389,4.5115485191345215
390,4.511435508728027
391,4.511322975158691
392,4.5112104415893555
393,4.511097431182861
394,4.510985851287842
395,4.510873317718506
396,4.510761260986328
397,4.510649681091309
398,4.510537624359131
399,4.510426998138428
400,4.510315895080566
401,4.510205268859863
402,4.510096073150635
403,4.509986400604248
404,4.5098772048950195
405,4.509767532348633
406,4.509658336639404
407,4.509550094604492
408,4.509440898895264
409,4.509332180023193
410,4.509223937988281
411,4.509115219116211
412,4.509006500244141
413,4.508897304534912
414,4.508789539337158
415,4.508681774139404
416,4.508574485778809
417,4.508468151092529
418,4.508361339569092
419,4.5082550048828125
420,4.508149147033691
421,4.5080437660217285
422,4.507938385009766
423,4.5078325271606445
424,4.507728099822998
425,4.507623672485352
426,4.507519245147705
427,4.507416725158691
428,4.5073137283325195
429,4.507210731506348
430,4.50710916519165
431,4.507007122039795
432,4.506906032562256
433,4.506804466247559
434,4.506704330444336
435,4.50660514831543
436,4.506505966186523
437,4.506407260894775
438,4.5063090324401855
439,4.50621223449707
440,4.506114959716797
441,4.50601863861084
442,4.505922794342041
443,4.5058274269104
444,4.50573205947876
445,4.505638122558594
446,4.505544185638428
447,4.505451679229736
448,4.505359172821045
449,4.505267143249512
450,4.505176067352295
451,4.505085468292236
452,4.504995822906494
453,4.504907131195068
454,4.504818439483643
455,4.504730701446533
456,4.504644393920898
457,4.504558086395264
458,4.504471778869629
459,4.504387378692627
460,4.504302024841309
461,4.5042195320129395
462,4.504136085510254
463,4.504052639007568
464,4.503970623016357
465,4.503889083862305
466,4.503807544708252
467,4.503726005554199
468,4.503644943237305
469,4.503564357757568
470,4.503483295440674
471,4.503403186798096
472,4.503323078155518
473,4.503242015838623
474,4.5031633377075195
475,4.5030837059021
476,4.50300407409668
477,4.502924919128418
478,4.502845764160156
479,4.5027666091918945
480,4.502687931060791
481,4.5026092529296875
482,4.502531051635742
483,4.502452373504639
484,4.502374172210693
485,4.502295970916748
486,4.502218246459961
487,4.502140045166016
488,4.50206184387207
489,4.501983642578125
490,4.501905918121338
491,4.501828193664551
492,4.5017499923706055
493,4.501672744750977
494,4.5015950202941895
495,4.501518249511719
496,4.501441955566406
497,4.501365661621094
498,4.5012898445129395
499,4.501214981079102
500,4.501140117645264
501,4.501065254211426
502,4.500990867614746
503,4.500916004180908
504,4.5008416175842285
505,4.500767707824707
506,4.5006937980651855
507,4.500619411468506
508,4.500545978546143
509,4.500472068786621
510,4.5003981590271
511,4.500324249267578
512,4.500250816345215
513,4.500176906585693
514,4.50010347366333
515,4.500030517578125
516,4.4999566078186035
517,4.499882698059082
518,4.4998087882995605
519,4.499735355377197
520,4.499661445617676
521,4.499587535858154
522,4.499513149261475
523,4.499438762664795
524,4.499364852905273
525,4.499290466308594
526,4.499216079711914
527,4.499140739440918
528,4.499065399169922
529,4.498990058898926
530,4.498913764953613
531,4.498837947845459
532,4.498762607574463
533,4.49868631362915
534,4.4986114501953125
535,4.498536109924316
536,4.498461723327637
537,4.498386859893799
538,4.498311996459961
539,4.4982380867004395
540,4.498164653778076
541,4.498090744018555
542,4.498019218444824
543,4.497945308685303
544,4.497872829437256
545,4.497801303863525
546,4.497729301452637
547,4.49765682220459
548,4.497586250305176
549,4.49751615524292
550,4.497445106506348
551,4.49737548828125
552,4.497305870056152
553,4.49723482131958
554,4.497164726257324
555,4.497094631195068
556,4.497025012969971
557,4.496954441070557
558,4.496883392333984
559,4.496815204620361
560,4.4967451095581055
561,4.496673107147217
562,4.496603488922119
563,4.4965338706970215
564,4.496462821960449
565,4.496394157409668
566,4.496323585510254
567,4.496251583099365
568,4.496183395385742
569,4.49611234664917
570,4.496041774749756
571,4.4959716796875
572,4.495903491973877
573,4.4958343505859375
574,4.49576473236084
575,4.495695114135742
576,4.4956278800964355
577,4.4955573081970215
578,4.495490074157715
579,4.495420932769775
580,4.495352745056152
581,4.495282173156738
582,4.495216369628906
583,4.495147228240967
584,4.495075225830078
585,4.4950079917907715
586,4.494940757751465
587,4.494872093200684
588,4.494802951812744
589,4.494732856750488
590,4.494668483734131
591,4.494600296020508
592,4.494530200958252
593,4.494461536407471
594,4.494396209716797
595,4.49432897567749
596,4.494260787963867
597,4.494192123413086
598,4.4941229820251465
599,4.494059085845947
600,4.493993759155273
601,4.493923664093018
602,4.493853569030762
603,4.493789196014404
604,4.493722915649414
605,4.493655681610107
606,4.493587970733643

View File

@ -0,0 +1,41 @@
Random Seed: 4529
Time Span: 0 to 10, Points: 1000
Learning Rate: 0.1
Weight Decay: 0.0001
Loss Function:
def loss_fn(state_traj, t_span):
theta = state_traj[:, :, 0] # Size: [batch_size, t_points]
desired_theta = state_traj[:, :, 3] # Size: [batch_size, t_points]
weights = weight_fn(t_span) # Initially Size: [t_points]
# Reshape or expand weights to match theta dimensions
weights = weights.view(-1, 1) # Now Size: [batch_size, t_points]
# Calculate the weighted loss
return torch.mean(weights * (theta - desired_theta) ** 2)
Training Cases:
[theta0, omega0, alpha0, desired_theta]
[0.5235987901687622, 0.0, 0.0, 0.0]
[-0.5235987901687622, 0.0, 0.0, 0.0]
[2.094395160675049, 0.0, 0.0, 0.0]
[-2.094395160675049, 0.0, 0.0, 0.0]
[0.0, 1.0471975803375244, 0.0, 0.0]
[0.0, -1.0471975803375244, 0.0, 0.0]
[0.0, 6.2831854820251465, 0.0, 0.0]
[0.0, -6.2831854820251465, 0.0, 0.0]
[0.0, 0.0, 0.0, 6.2831854820251465]
[0.0, 0.0, 0.0, -6.2831854820251465]
[0.0, 0.0, 0.0, 1.5707963705062866]
[0.0, 0.0, 0.0, -1.5707963705062866]
[0.0, 0.0, 0.0, 1.0471975803375244]
[0.0, 0.0, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 0.0]
[-0.7853981852531433, -3.1415927410125732, 0.0, 0.0]
[1.5707963705062866, -3.1415927410125732, 0.0, 1.0471975803375244]
[-1.5707963705062866, 3.1415927410125732, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 6.2831854820251465]
[-0.7853981852531433, -3.1415927410125732, 0.0, 6.2831854820251465]
[1.5707963705062866, -3.1415927410125732, 0.0, 12.566370964050293]
[-1.5707963705062866, 3.1415927410125732, 0.0, -12.566370964050293]

File diff suppressed because it is too large.

View File

@ -0,0 +1,41 @@
Random Seed: 4529
Time Span: 0 to 10, Points: 1000
Learning Rate: 0.1
Weight Decay: 0.0001
Loss Function:
def loss_fn(state_traj, t_span):
theta = state_traj[:, :, 0] # Size: [batch_size, t_points]
desired_theta = state_traj[:, :, 3] # Size: [batch_size, t_points]
weights = weight_fn(t_span) # Initially Size: [t_points]
# Reshape or expand weights to match theta dimensions
weights = weights.view(-1, 1) # Now Size: [batch_size, t_points]
# Calculate the weighted loss
return torch.mean(weights * (theta - desired_theta) ** 2)
Training Cases:
[theta0, omega0, alpha0, desired_theta]
[0.5235987901687622, 0.0, 0.0, 0.0]
[-0.5235987901687622, 0.0, 0.0, 0.0]
[2.094395160675049, 0.0, 0.0, 0.0]
[-2.094395160675049, 0.0, 0.0, 0.0]
[0.0, 1.0471975803375244, 0.0, 0.0]
[0.0, -1.0471975803375244, 0.0, 0.0]
[0.0, 6.2831854820251465, 0.0, 0.0]
[0.0, -6.2831854820251465, 0.0, 0.0]
[0.0, 0.0, 0.0, 6.2831854820251465]
[0.0, 0.0, 0.0, -6.2831854820251465]
[0.0, 0.0, 0.0, 1.5707963705062866]
[0.0, 0.0, 0.0, -1.5707963705062866]
[0.0, 0.0, 0.0, 1.0471975803375244]
[0.0, 0.0, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 0.0]
[-0.7853981852531433, -3.1415927410125732, 0.0, 0.0]
[1.5707963705062866, -3.1415927410125732, 0.0, 1.0471975803375244]
[-1.5707963705062866, 3.1415927410125732, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 6.2831854820251465]
[-0.7853981852531433, -3.1415927410125732, 0.0, 6.2831854820251465]
[1.5707963705062866, -3.1415927410125732, 0.0, 12.566370964050293]
[-1.5707963705062866, 3.1415927410125732, 0.0, -12.566370964050293]

File diff suppressed because it is too large.

View File

@ -0,0 +1,155 @@
import torch
import torch.optim as optim
from torchdiffeq import odeint
import numpy as np
import os
import shutil
import csv
import inspect
from PendulumController import PendulumController
from PendulumDynamics import PendulumDynamics
# Device setup
device = torch.device("cpu")

# Initial conditions (theta0, omega0, alpha0, desired_theta)
from initial_conditions import initial_conditions
state_0 = torch.tensor(initial_conditions, dtype=torch.float32, device=device)
# Constants
m = 10.0
g = 9.81
R = 1.0
# Time grid
t_start, t_end, t_points = 0, 10, 1000
t_span = torch.linspace(t_start, t_end, t_points, device=device)
# Specify directory for storing results
output_dir = "average_normalized"
os.makedirs(output_dir, exist_ok=True)
# Use a previously generated random seed
random_seed = 4529
# Set the seeds for reproducibility
torch.manual_seed(random_seed)
np.random.seed(random_seed)
# Print the chosen random seed
print(f"Random seed for torch and numpy: {random_seed}")
# Initialize controller and dynamics
controller = PendulumController().to(device)
pendulum_dynamics = PendulumDynamics(controller, m, R, g).to(device)
# Optimizer setup
learning_rate = 1e-1
weight_decay = 1e-4
optimizer = optim.Adam(controller.parameters(), lr=learning_rate, weight_decay=weight_decay)
# Training parameters
num_epochs = 1001
# Loss-function factory: builds a time-weighted MSE loss from a weight profile
def make_loss_fn(weight_fn):
    def loss_fn(state_traj, t_span):
        # odeint returns the trajectory with time as the leading dimension
        theta = state_traj[:, :, 0]          # Size: [t_points, batch_size]
        desired_theta = state_traj[:, :, 3]  # Size: [t_points, batch_size]
        weights = weight_fn(t_span)          # Size: [t_points]
        # Reshape weights so they broadcast across the batch dimension
        weights = weights.view(-1, 1)        # Now Size: [t_points, 1]
        # Weighted mean-squared tracking error
        return torch.mean(weights * (theta - desired_theta) ** 2)
    return loss_fn
# Define and store weight functions with descriptions, normalized by average weight
weight_functions = {
    'constant': {
        'function': lambda t: torch.ones_like(t) / torch.ones_like(t).mean(),
        'description': 'Constant weight: All weights are 1, normalized by the average (remains 1)'
    },
    'linear': {
        'function': lambda t: (t / t.max()) / (t / t.max()).mean(),
        'description': 'Linear weight: Weights increase linearly from 0 to 1, normalized by the average weight'
    },
    'quadratic': {
        'function': lambda t: ((t / t.max()) ** 2) / ((t / t.max()) ** 2).mean(),
        'description': 'Quadratic weight: Weights increase quadratically from 0 to 1, normalized by the average weight'
    },
    'exponential': {
        'function': lambda t: (torch.exp(t / t.max() * 2)) / (torch.exp(t / t.max() * 2)).mean(),
        'description': 'Exponential weight: Weights increase exponentially, normalized by the average weight'
    },
    'inverse': {
        'function': lambda t: (1 / (t / t.max() + 1)) / (1 / (t / t.max() + 1)).mean(),
        'description': 'Inverse weight: Weights decrease inversely, normalized by the average weight'
    },
    'inverse_squared': {
        'function': lambda t: (1 / ((t / t.max() + 1) ** 2)) / (1 / ((t / t.max() + 1) ** 2)).mean(),
        'description': 'Inverse squared weight: Weights decrease inversely squared, normalized by the average weight'
    }
}
# Training loop for each weight function
for name, weight_info in weight_functions.items():
    controller = PendulumController().to(device)
    pendulum_dynamics = PendulumDynamics(controller, m, R, g).to(device)
    optimizer = optim.Adam(controller.parameters(), lr=learning_rate, weight_decay=weight_decay)
    loss_fn = make_loss_fn(weight_info['function'])

    # File paths
    function_output_dir = os.path.join(output_dir, name)
    controllers_dir = os.path.join(function_output_dir, "controllers")

    # Check if controllers directory exists and remove it
    if os.path.exists(controllers_dir):
        shutil.rmtree(controllers_dir)
    os.makedirs(controllers_dir, exist_ok=True)

    config_file = os.path.join(function_output_dir, "training_config.txt")
    log_file = os.path.join(function_output_dir, "training_log.csv")

    # Overwrite configuration and log files
    with open(config_file, "w") as f:
        f.write(f"Random Seed: {random_seed}\n")
        f.write(f"Time Span: {t_start} to {t_end}, Points: {t_points}\n")
        f.write(f"Learning Rate: {learning_rate}\n")
        f.write(f"Weight Decay: {weight_decay}\n")
        f.write("\nLoss Function:\n")
        f.write(inspect.getsource(loss_fn))
        f.write("\nTraining Cases:\n")
        f.write("[theta0, omega0, alpha0, desired_theta]\n")
        for case in state_0.cpu().numpy():
            f.write(f"{case.tolist()}\n")

    with open(log_file, "w", newline="") as csvfile:
        csv_writer = csv.writer(csvfile)
        csv_writer.writerow(["Epoch", "Loss"])

    # Training loop
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        state_traj = odeint(pendulum_dynamics, state_0, t_span, method='rk4')
        loss = loss_fn(state_traj, t_span)
        loss.backward()
        optimizer.step()

        # Logging
        with open(log_file, "a", newline="") as csvfile:
            csv_writer = csv.writer(csvfile)
            csv_writer.writerow([epoch, loss.item()])

        # Save the model
        model_file = os.path.join(controllers_dir, f"controller_{epoch}.pth")
        torch.save(controller.state_dict(), model_file)
        print(f"{model_file} saved with loss: {loss}")

print("Training complete. Models and logs are saved under respective directories for each loss function.")

View File

@ -0,0 +1,27 @@
import torch
from torch import pi
initial_conditions = [
[1/6 * pi, 0.0, 0.0, 0.0],
[-1/6 * pi, 0.0, 0.0, 0.0],
[2/3 * pi, 0.0, 0.0, 0.0],
[-2/3 * pi, 0.0, 0.0, 0.0],
[0.0, 1/3 * pi, 0.0, 0.0],
[0.0, -1/3 * pi, 0.0, 0.0],
[0.0, 2 * pi, 0.0, 0.0],
[0.0, -2 * pi, 0.0, 0.0],
[0.0, 0.0, 0.0, 2 * pi],
[0.0, 0.0, 0.0, -2 * pi],
[0.0, 0.0, 0.0, 1/2 * pi],
[0.0, 0.0, 0.0, -1/2 * pi],
[0.0, 0.0, 0.0, 1/3 * pi],
[0.0, 0.0, 0.0, -1/3 * pi],
[1/4 * pi, 1 * pi, 0.0, 0.0],
[-1/4 * pi, -1 * pi, 0.0, 0.0],
[1/2 * pi, -1 * pi, 0.0, 1/3 * pi],
[-1/2 * pi, 1 * pi, 0.0, -1/3 * pi],
[1/4 * pi, 1 * pi, 0.0, 2 * pi],
[-1/4 * pi, -1 * pi, 0.0, 2 * pi],
[1/2 * pi, -1 * pi, 0.0, 4 * pi],
[-1/2 * pi, 1 * pi, 0.0, -4 * pi],
]
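These pi fractions are the source of the decimal values logged in the training_config.txt files above and below; cast to float32 they print as exactly those digits. A quick check, assuming only this initial_conditions module:

import torch
from initial_conditions import initial_conditions

state_0 = torch.tensor(initial_conditions, dtype=torch.float32)
print(state_0.shape)        # torch.Size([22, 4]): 22 cases of [theta0, omega0, alpha0, desired_theta]
print(state_0[0].tolist())  # [0.5235987901687622, 0.0, 0.0, 0.0], i.e. pi/6 in float32, matching the logs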

View File

@ -0,0 +1,41 @@
Random Seed: 4529
Time Span: 0 to 10, Points: 1000
Learning Rate: 0.1
Weight Decay: 0.0001
Loss Function:
def loss_fn(state_traj, t_span):
theta = state_traj[:, :, 0] # Size: [batch_size, t_points]
desired_theta = state_traj[:, :, 3] # Size: [batch_size, t_points]
weights = weight_fn(t_span) # Initially Size: [t_points]
# Reshape or expand weights to match theta dimensions
weights = weights.view(-1, 1) # Now Size: [batch_size, t_points]
# Calculate the weighted loss
return torch.mean(weights * (theta - desired_theta) ** 2)
Training Cases:
[theta0, omega0, alpha0, desired_theta]
[0.5235987901687622, 0.0, 0.0, 0.0]
[-0.5235987901687622, 0.0, 0.0, 0.0]
[2.094395160675049, 0.0, 0.0, 0.0]
[-2.094395160675049, 0.0, 0.0, 0.0]
[0.0, 1.0471975803375244, 0.0, 0.0]
[0.0, -1.0471975803375244, 0.0, 0.0]
[0.0, 6.2831854820251465, 0.0, 0.0]
[0.0, -6.2831854820251465, 0.0, 0.0]
[0.0, 0.0, 0.0, 6.2831854820251465]
[0.0, 0.0, 0.0, -6.2831854820251465]
[0.0, 0.0, 0.0, 1.5707963705062866]
[0.0, 0.0, 0.0, -1.5707963705062866]
[0.0, 0.0, 0.0, 1.0471975803375244]
[0.0, 0.0, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 0.0]
[-0.7853981852531433, -3.1415927410125732, 0.0, 0.0]
[1.5707963705062866, -3.1415927410125732, 0.0, 1.0471975803375244]
[-1.5707963705062866, 3.1415927410125732, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 6.2831854820251465]
[-0.7853981852531433, -3.1415927410125732, 0.0, 6.2831854820251465]
[1.5707963705062866, -3.1415927410125732, 0.0, 12.566370964050293]
[-1.5707963705062866, 3.1415927410125732, 0.0, -12.566370964050293]

File diff suppressed because it is too large

View File

@ -0,0 +1,41 @@
Random Seed: 4529
Time Span: 0 to 10, Points: 1000
Learning Rate: 0.1
Weight Decay: 0.0001
Loss Function:
def loss_fn(state_traj, t_span):
theta = state_traj[:, :, 0] # Size: [batch_size, t_points]
desired_theta = state_traj[:, :, 3] # Size: [batch_size, t_points]
weights = weight_fn(t_span) # Initially Size: [t_points]
# Reshape or expand weights to match theta dimensions
weights = weights.view(-1, 1) # Now Size: [batch_size, t_points]
# Calculate the weighted loss
return torch.mean(weights * (theta - desired_theta) ** 2)
Training Cases:
[theta0, omega0, alpha0, desired_theta]
[0.5235987901687622, 0.0, 0.0, 0.0]
[-0.5235987901687622, 0.0, 0.0, 0.0]
[2.094395160675049, 0.0, 0.0, 0.0]
[-2.094395160675049, 0.0, 0.0, 0.0]
[0.0, 1.0471975803375244, 0.0, 0.0]
[0.0, -1.0471975803375244, 0.0, 0.0]
[0.0, 6.2831854820251465, 0.0, 0.0]
[0.0, -6.2831854820251465, 0.0, 0.0]
[0.0, 0.0, 0.0, 6.2831854820251465]
[0.0, 0.0, 0.0, -6.2831854820251465]
[0.0, 0.0, 0.0, 1.5707963705062866]
[0.0, 0.0, 0.0, -1.5707963705062866]
[0.0, 0.0, 0.0, 1.0471975803375244]
[0.0, 0.0, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 0.0]
[-0.7853981852531433, -3.1415927410125732, 0.0, 0.0]
[1.5707963705062866, -3.1415927410125732, 0.0, 1.0471975803375244]
[-1.5707963705062866, 3.1415927410125732, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 6.2831854820251465]
[-0.7853981852531433, -3.1415927410125732, 0.0, 6.2831854820251465]
[1.5707963705062866, -3.1415927410125732, 0.0, 12.566370964050293]
[-1.5707963705062866, 3.1415927410125732, 0.0, -12.566370964050293]

File diff suppressed because it is too large

View File

@ -0,0 +1,41 @@
Random Seed: 4529
Time Span: 0 to 10, Points: 1000
Learning Rate: 0.1
Weight Decay: 0.0001
Loss Function:
def loss_fn(state_traj, t_span):
theta = state_traj[:, :, 0] # Size: [batch_size, t_points]
desired_theta = state_traj[:, :, 3] # Size: [batch_size, t_points]
weights = weight_fn(t_span) # Initially Size: [t_points]
# Reshape or expand weights to match theta dimensions
weights = weights.view(-1, 1) # Now Size: [batch_size, t_points]
# Calculate the weighted loss
return torch.mean(weights * (theta - desired_theta) ** 2)
Training Cases:
[theta0, omega0, alpha0, desired_theta]
[0.5235987901687622, 0.0, 0.0, 0.0]
[-0.5235987901687622, 0.0, 0.0, 0.0]
[2.094395160675049, 0.0, 0.0, 0.0]
[-2.094395160675049, 0.0, 0.0, 0.0]
[0.0, 1.0471975803375244, 0.0, 0.0]
[0.0, -1.0471975803375244, 0.0, 0.0]
[0.0, 6.2831854820251465, 0.0, 0.0]
[0.0, -6.2831854820251465, 0.0, 0.0]
[0.0, 0.0, 0.0, 6.2831854820251465]
[0.0, 0.0, 0.0, -6.2831854820251465]
[0.0, 0.0, 0.0, 1.5707963705062866]
[0.0, 0.0, 0.0, -1.5707963705062866]
[0.0, 0.0, 0.0, 1.0471975803375244]
[0.0, 0.0, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 0.0]
[-0.7853981852531433, -3.1415927410125732, 0.0, 0.0]
[1.5707963705062866, -3.1415927410125732, 0.0, 1.0471975803375244]
[-1.5707963705062866, 3.1415927410125732, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 6.2831854820251465]
[-0.7853981852531433, -3.1415927410125732, 0.0, 6.2831854820251465]
[1.5707963705062866, -3.1415927410125732, 0.0, 12.566370964050293]
[-1.5707963705062866, 3.1415927410125732, 0.0, -12.566370964050293]

File diff suppressed because it is too large

View File

@ -0,0 +1,41 @@
Random Seed: 4529
Time Span: 0 to 10, Points: 1000
Learning Rate: 0.1
Weight Decay: 0.0001
Loss Function:
def loss_fn(state_traj, t_span):
theta = state_traj[:, :, 0] # Size: [batch_size, t_points]
desired_theta = state_traj[:, :, 3] # Size: [batch_size, t_points]
weights = weight_fn(t_span) # Initially Size: [t_points]
# Reshape or expand weights to match theta dimensions
weights = weights.view(-1, 1) # Now Size: [batch_size, t_points]
# Calculate the weighted loss
return torch.mean(weights * (theta - desired_theta) ** 2)
Training Cases:
[theta0, omega0, alpha0, desired_theta]
[0.5235987901687622, 0.0, 0.0, 0.0]
[-0.5235987901687622, 0.0, 0.0, 0.0]
[2.094395160675049, 0.0, 0.0, 0.0]
[-2.094395160675049, 0.0, 0.0, 0.0]
[0.0, 1.0471975803375244, 0.0, 0.0]
[0.0, -1.0471975803375244, 0.0, 0.0]
[0.0, 6.2831854820251465, 0.0, 0.0]
[0.0, -6.2831854820251465, 0.0, 0.0]
[0.0, 0.0, 0.0, 6.2831854820251465]
[0.0, 0.0, 0.0, -6.2831854820251465]
[0.0, 0.0, 0.0, 1.5707963705062866]
[0.0, 0.0, 0.0, -1.5707963705062866]
[0.0, 0.0, 0.0, 1.0471975803375244]
[0.0, 0.0, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 0.0]
[-0.7853981852531433, -3.1415927410125732, 0.0, 0.0]
[1.5707963705062866, -3.1415927410125732, 0.0, 1.0471975803375244]
[-1.5707963705062866, 3.1415927410125732, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 6.2831854820251465]
[-0.7853981852531433, -3.1415927410125732, 0.0, 6.2831854820251465]
[1.5707963705062866, -3.1415927410125732, 0.0, 12.566370964050293]
[-1.5707963705062866, 3.1415927410125732, 0.0, -12.566370964050293]

File diff suppressed because it is too large

View File

@ -0,0 +1,41 @@
Random Seed: 4529
Time Span: 0 to 10, Points: 1000
Learning Rate: 0.1
Weight Decay: 0.0001
Loss Function:
def loss_fn(state_traj, t_span):
theta = state_traj[:, :, 0] # Size: [batch_size, t_points]
desired_theta = state_traj[:, :, 3] # Size: [batch_size, t_points]
weights = weight_fn(t_span) # Initially Size: [t_points]
# Reshape or expand weights to match theta dimensions
weights = weights.view(-1, 1) # Now Size: [batch_size, t_points]
# Calculate the weighted loss
return torch.mean(weights * (theta - desired_theta) ** 2)
Training Cases:
[theta0, omega0, alpha0, desired_theta]
[0.5235987901687622, 0.0, 0.0, 0.0]
[-0.5235987901687622, 0.0, 0.0, 0.0]
[2.094395160675049, 0.0, 0.0, 0.0]
[-2.094395160675049, 0.0, 0.0, 0.0]
[0.0, 1.0471975803375244, 0.0, 0.0]
[0.0, -1.0471975803375244, 0.0, 0.0]
[0.0, 6.2831854820251465, 0.0, 0.0]
[0.0, -6.2831854820251465, 0.0, 0.0]
[0.0, 0.0, 0.0, 6.2831854820251465]
[0.0, 0.0, 0.0, -6.2831854820251465]
[0.0, 0.0, 0.0, 1.5707963705062866]
[0.0, 0.0, 0.0, -1.5707963705062866]
[0.0, 0.0, 0.0, 1.0471975803375244]
[0.0, 0.0, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 0.0]
[-0.7853981852531433, -3.1415927410125732, 0.0, 0.0]
[1.5707963705062866, -3.1415927410125732, 0.0, 1.0471975803375244]
[-1.5707963705062866, 3.1415927410125732, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 6.2831854820251465]
[-0.7853981852531433, -3.1415927410125732, 0.0, 6.2831854820251465]
[1.5707963705062866, -3.1415927410125732, 0.0, 12.566370964050293]
[-1.5707963705062866, 3.1415927410125732, 0.0, -12.566370964050293]

File diff suppressed because it is too large

View File

@ -0,0 +1,41 @@
Random Seed: 4529
Time Span: 0 to 10, Points: 1000
Learning Rate: 0.1
Weight Decay: 0.0001
Loss Function:
def loss_fn(state_traj, t_span):
theta = state_traj[:, :, 0] # Size: [batch_size, t_points]
desired_theta = state_traj[:, :, 3] # Size: [batch_size, t_points]
weights = weight_fn(t_span) # Initially Size: [t_points]
# Reshape or expand weights to match theta dimensions
weights = weights.view(-1, 1) # Now Size: [batch_size, t_points]
# Calculate the weighted loss
return torch.mean(weights * (theta - desired_theta) ** 2)
Training Cases:
[theta0, omega0, alpha0, desired_theta]
[0.5235987901687622, 0.0, 0.0, 0.0]
[-0.5235987901687622, 0.0, 0.0, 0.0]
[2.094395160675049, 0.0, 0.0, 0.0]
[-2.094395160675049, 0.0, 0.0, 0.0]
[0.0, 1.0471975803375244, 0.0, 0.0]
[0.0, -1.0471975803375244, 0.0, 0.0]
[0.0, 6.2831854820251465, 0.0, 0.0]
[0.0, -6.2831854820251465, 0.0, 0.0]
[0.0, 0.0, 0.0, 6.2831854820251465]
[0.0, 0.0, 0.0, -6.2831854820251465]
[0.0, 0.0, 0.0, 1.5707963705062866]
[0.0, 0.0, 0.0, -1.5707963705062866]
[0.0, 0.0, 0.0, 1.0471975803375244]
[0.0, 0.0, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 0.0]
[-0.7853981852531433, -3.1415927410125732, 0.0, 0.0]
[1.5707963705062866, -3.1415927410125732, 0.0, 1.0471975803375244]
[-1.5707963705062866, 3.1415927410125732, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 6.2831854820251465]
[-0.7853981852531433, -3.1415927410125732, 0.0, 6.2831854820251465]
[1.5707963705062866, -3.1415927410125732, 0.0, 12.566370964050293]
[-1.5707963705062866, 3.1415927410125732, 0.0, -12.566370964050293]

File diff suppressed because it is too large

View File

@ -0,0 +1,155 @@
import torch
import torch.optim as optim
from torchdiffeq import odeint
import numpy as np
import os
import shutil
import csv
import inspect
from PendulumController import PendulumController
from PendulumDynamics import PendulumDynamics
# Device setup
device = torch.device("cpu")

# Initial conditions (theta0, omega0, alpha0, desired_theta)
from initial_conditions import initial_conditions
state_0 = torch.tensor(initial_conditions, dtype=torch.float32, device=device)
# Constants
m = 10.0
g = 9.81
R = 1.0
# Time grid
t_start, t_end, t_points = 0, 10, 1000
t_span = torch.linspace(t_start, t_end, t_points, device=device)
# Specify directory for storing results
output_dir = "training"
os.makedirs(output_dir, exist_ok=True)
# Use a previously generated random seed
random_seed = 4529
# Set the seeds for reproducibility
torch.manual_seed(random_seed)
np.random.seed(random_seed)
# Print the chosen random seed
print(f"Random seed for torch and numpy: {random_seed}")
# Initialize controller and dynamics
controller = PendulumController().to(device)
pendulum_dynamics = PendulumDynamics(controller, m, R, g).to(device)
# Optimizer setup
learning_rate = 1e-1
weight_decay = 1e-4
optimizer = optim.Adam(controller.parameters(), lr=learning_rate, weight_decay=weight_decay)
# Training parameters
num_epochs = 1000
# Loss-function factory: builds a time-weighted MSE loss from a weight profile
def make_loss_fn(weight_fn):
    def loss_fn(state_traj, t_span):
        # odeint returns the trajectory with time as the leading dimension
        theta = state_traj[:, :, 0]          # Size: [t_points, batch_size]
        desired_theta = state_traj[:, :, 3]  # Size: [t_points, batch_size]
        weights = weight_fn(t_span)          # Size: [t_points]
        # Reshape weights so they broadcast across the batch dimension
        weights = weights.view(-1, 1)        # Now Size: [t_points, 1]
        # Weighted mean-squared tracking error
        return torch.mean(weights * (theta - desired_theta) ** 2)
    return loss_fn
# Define and store weight functions with descriptions
weight_functions = {
    'constant': {
        'function': lambda t: torch.ones_like(t),
        'description': 'Constant weight: All weights are 1'
    },
    'linear': {
        'function': lambda t: (t / t.max()) / (t / t.max()).max(),
        'description': 'Linear weight: Weights increase linearly from 0 to 1, normalized by max'
    },
    'quadratic': {
        'function': lambda t: ((t / t.max()) ** 2) / ((t / t.max()) ** 2).max(),
        'description': 'Quadratic weight: Weights increase quadratically from 0 to 1, normalized by max'
    },
    'exponential': {
        'function': lambda t: (torch.exp(t / t.max() * 2)) / (torch.exp(t / t.max() * 2)).max(),
        'description': 'Exponential weight: Weights increase exponentially, normalized by max'
    },
    'inverse': {
        'function': lambda t: (1 / (t / t.max() + 1)) / (1 / (t / t.max() + 1)).max(),
        'description': 'Inverse weight: Weights decrease inversely, normalized by max'
    },
    'inverse_squared': {
        'function': lambda t: (1 / ((t / t.max() + 1) ** 2)) / (1 / ((t / t.max() + 1) ** 2)).max(),
        'description': 'Inverse squared weight: Weights decrease inversely squared, normalized by max'
    }
}
# Training loop for each weight function
for name, weight_info in weight_functions.items():
    controller = PendulumController().to(device)
    pendulum_dynamics = PendulumDynamics(controller, m, R, g).to(device)
    optimizer = optim.Adam(controller.parameters(), lr=learning_rate, weight_decay=weight_decay)
    loss_fn = make_loss_fn(weight_info['function'])

    # File paths
    function_output_dir = os.path.join(output_dir, name)
    controllers_dir = os.path.join(function_output_dir, "controllers")

    # Check if controllers directory exists and remove it
    if os.path.exists(controllers_dir):
        shutil.rmtree(controllers_dir)
    os.makedirs(controllers_dir, exist_ok=True)

    config_file = os.path.join(function_output_dir, "training_config.txt")
    log_file = os.path.join(function_output_dir, "training_log.csv")

    # Overwrite configuration and log files
    with open(config_file, "w") as f:
        f.write(f"Random Seed: {random_seed}\n")
        f.write(f"Time Span: {t_start} to {t_end}, Points: {t_points}\n")
        f.write(f"Learning Rate: {learning_rate}\n")
        f.write(f"Weight Decay: {weight_decay}\n")
        f.write("\nLoss Function:\n")
        f.write(inspect.getsource(loss_fn))
        f.write("\nTraining Cases:\n")
        f.write("[theta0, omega0, alpha0, desired_theta]\n")
        for case in state_0.cpu().numpy():
            f.write(f"{case.tolist()}\n")

    with open(log_file, "w", newline="") as csvfile:
        csv_writer = csv.writer(csvfile)
        csv_writer.writerow(["Epoch", "Loss"])

    # Training loop
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        state_traj = odeint(pendulum_dynamics, state_0, t_span, method='rk4')
        loss = loss_fn(state_traj, t_span)
        loss.backward()
        optimizer.step()

        # Logging
        with open(log_file, "a", newline="") as csvfile:
            csv_writer = csv.writer(csvfile)
            csv_writer.writerow([epoch, loss.item()])

        # Save the model
        model_file = os.path.join(controllers_dir, f"controller_{epoch}.pth")
        torch.save(controller.state_dict(), model_file)
        print(f"{model_file} saved with loss: {loss}")

print("Training complete. Models and logs are saved under respective directories for each loss function.")

File diff suppressed because it is too large