Plotted controller max normalized across epochs. Also added training with average-normalized loss weights.
BIN analysis/average_normalized/IC_-3.14_0.0_0.0_0.0/constant.png (new file, 972 KiB)
BIN analysis/average_normalized/IC_-3.14_0.0_0.0_0.0/exponential.png (new file, 2.2 MiB)
BIN analysis/average_normalized/IC_-3.14_0.0_0.0_0.0/linear.png (new file, 945 KiB)
BIN analysis/average_normalized/IC_-3.14_0.0_0.0_0.0/quadratic.png (new file, 2.5 MiB)
BIN (two images removed: previously 4.8 MiB and 8.0 MiB)
@ -9,6 +9,13 @@ from multiprocessing import Pool, cpu_count
 # Define PendulumController class
 from PendulumController import PendulumController
 
+# Constants
+g = 9.81   # Gravity
+R = 1.0    # Length of the pendulum
+m = 10.0   # Mass
+dt = 0.02  # Time step
+num_steps = 500  # Simulation time steps
+
 # ODE solver (RK4 method)
 def pendulum_ode_step(state, dt, desired_theta, controller):
     theta, omega, alpha = state
@ -44,41 +51,9 @@ def pendulum_ode_step(state, dt, desired_theta, controller):
     new_state = state + (k1 + 2*k2 + 2*k3 + k4) / 6.0
     return new_state
 
-# Constants
-g = 9.81   # Gravity
-R = 1.0    # Length of the pendulum
-m = 10.0   # Mass
-dt = 0.02  # Time step
-num_steps = 500  # Simulation time steps
-
-# Directory containing controller files
-loss_function = "quadratic"
-controller_dir = f"/home/judson/Neural-Networks-in-GNC/inverted_pendulum/training/normalized/training/{loss_function}/controllers"
-#controller_dir = f"C:/Users/Judson/Desktop/New Gitea/Neural-Networks-in-GNC/inverted_pendulum/training/{loss_function}/controllers"
-controller_files = sorted([f for f in os.listdir(controller_dir) if f.startswith("controller_") and f.endswith(".pth")])
-
-# Sorting controllers by epoch
-controller_epochs = [int(f.split('_')[1].split('.')[0]) for f in controller_files]
-sorted_controllers = [x for _, x in sorted(zip(controller_epochs, controller_files))]
-
-# **Epoch Range Selection**
-epoch_range = (0, 1000)  # Set your desired range (e.g., (0, 5000) or (0, 100))
-
-filtered_controllers = [
-    f for f in sorted_controllers
-    if epoch_range[0] <= int(f.split('_')[1].split('.')[0]) <= epoch_range[1]
-]
-
-# **Granularity Control: Select every Nth controller**
-N = 1  # Change this value to adjust granularity (e.g., every 5th controller)
-selected_controllers = filtered_controllers[::N]  # Take every Nth controller within the range
-
-# Initial condition
-# theta0, omega0, alpha0, desired_theta = (-np.pi, -2*np.pi, 0.0, -1.3*np.pi)  # Example initial condition
-theta0, omega0, alpha0, desired_theta = (-np.pi, 0.0, 0.0, 0.0)  # Example initial condition
-
-# Parallel function must return epoch explicitly
-def run_simulation(controller_file):
+def run_simulation(params):
+    controller_file, initial_condition = params
+    theta0, omega0, alpha0, desired_theta = initial_condition
     epoch = int(controller_file.split('_')[1].split('.')[0])
 
     # Load controller
@ -96,44 +71,62 @@ def run_simulation(controller_file):
 
     return epoch, theta_vals  # Return epoch with data
 
-# Parallel processing
+# Named initial conditions
+initial_conditions = {
+    "small_perturbation": (0.1*np.pi, 0.0, 0.0, 0.0),
+    "large_perturbation": (-np.pi, 0.0, 0.0, 0),
+    "overshoot_vertical_test": (-0.1*np.pi, 2*np.pi, 0.0, 0.0),
+    "overshoot_angle_test": (0.2*np.pi, 2*np.pi, 0.0, 0.3*np.pi),
+    "extreme_perturbation": (4*np.pi, 0.0, 0.0, 0),
+}
+
+# Loss functions to iterate over
+loss_functions = ["constant", "linear", "quadratic", "exponential", "inverse", "inverse_squared"]
+
+epoch_start = 0  # Start of the epoch range
+epoch_end = 500  # End of the epoch range
+epoch_step = 5   # Interval between epochs
+
 if __name__ == "__main__":
+    for condition_name, initial_condition in initial_conditions.items():
+        full_path = f"/home/judson/Neural-Networks-in-GNC/inverted_pendulum/analysis/max_normalized/{condition_name}"
+        os.makedirs(full_path, exist_ok=True)  # Create directory if it does not exist
+
+        for loss_function in loss_functions:
+            controller_dir = f"/home/judson/Neural-Networks-in-GNC/inverted_pendulum/training/normalized/max_normalized/{loss_function}/controllers"
+            controller_files = sorted([f for f in os.listdir(controller_dir) if f.startswith("controller_") and f.endswith(".pth")])
+            # Extract epoch numbers and filter based on the defined range and interval
+            epoch_numbers = [int(f.split('_')[1].split('.')[0]) for f in controller_files]
+            selected_epochs = [e for e in epoch_numbers if epoch_start <= e <= epoch_end and (e - epoch_start) % epoch_step == 0]
+
+            # Filter the controller files to include only those within the selected epochs
+            selected_controllers = [f for f in controller_files if int(f.split('_')[1].split('.')[0]) in selected_epochs]
+            selected_controllers.sort(key=lambda f: int(f.split('_')[1].split('.')[0]))
+
+            # Setup parallel processing
             num_workers = min(cpu_count(), 16)  # Limit to 16 workers max
-    print(f"Using {num_workers} parallel workers...")
+            print(f"Using {num_workers} parallel workers for {loss_function} with initial condition {condition_name}...")
 
             with Pool(processes=num_workers) as pool:
-        results = pool.map(run_simulation, selected_controllers)
+                params = [(controller_file, initial_condition) for controller_file in selected_controllers]
+                results = pool.map(run_simulation, params)
 
-    # Sort results by epoch to ensure correct order
-    results.sort(key=lambda x: x[0])
-    epochs, theta_over_epochs = zip(*results)  # Unzip sorted results
+            results.sort(key=lambda x: x[0])
+            epochs, theta_over_epochs = zip(*results)
 
-    # Convert results to NumPy arrays
-    theta_over_epochs = np.array(theta_over_epochs)
-
-    # Create 3D line plot
-    fig = plt.figure(figsize=(10, 7))
+            fig = plt.figure(figsize=(7, 5))
             ax = fig.add_subplot(111, projection='3d')
+            time_steps = np.arange(num_steps) * dt
 
-    time_steps = np.arange(num_steps) * dt  # X-axis (time)
-
-    # Plot each controller as a separate line
-    for epoch, theta_vals in zip(epochs, theta_over_epochs):
-        ax.plot(
-            [epoch] * len(time_steps),  # Y-axis (epoch stays constant for each line)
-            time_steps,  # X-axis (time)
-            theta_vals,  # Z-axis (theta evolution)
-            label=f"Epoch {epoch}" if epoch % (N * 10) == 0 else "",  # Label some lines for clarity
-        )
-
-    # Labels
-    ax.set_xlabel("Epoch")
-    ax.set_ylabel("Time (s)")
-    ax.set_zlabel("Theta (rad)")
-    ax.set_title(f"Pendulum Angle Evolution for {loss_function}")
+            # Plot the epochs in reverse order because we view it where epoch 0 is in front
+            for epoch, theta_vals in reversed(list(zip(epochs, theta_over_epochs))):
+                ax.plot([epoch] * len(time_steps), time_steps, theta_vals)
 
             # Add a horizontal line at desired_theta across all epochs and time steps
             epochs_array = np.array([epoch for epoch, _ in zip(epochs, theta_over_epochs)])
+            desired_theta = initial_condition[-1]
             ax.plot(
                 epochs_array,  # X-axis spanning all epochs
                 [time_steps.max()] * len(epochs_array),  # Y-axis at the maximum time step
@ -141,9 +134,37 @@ if __name__ == "__main__":
                 color='r', linestyle='--', linewidth=2, label='Desired Theta at End Time'
             )
 
-    # Improve visibility
+            ax.set_xlabel("Epoch")
+            ax.set_ylabel("Time (s)")
+            ax.set_zlabel("Theta (rad)")
+            condition_text = f"IC_{'_'.join(map(lambda x: str(round(x, 2)), initial_condition))}"
+            ax.set_title(f"Pendulum Angle Evolution for {loss_function} and {condition_text}")
+
+            # Calculate the range of theta values across all epochs
+            theta_values = np.concatenate(theta_over_epochs)
+            theta_min = np.min(theta_values)
+            theta_max = np.max(theta_values)
+
+            # Determine the desired range around the desired_theta
+            desired_range_min = desired_theta - 1 * np.pi
+            desired_range_max = desired_theta + 1 * np.pi
+
+            # Check if current theta values fall outside the desired range
+            if theta_min < desired_range_min:
+                desired_range_min = desired_range_min
+            else:
+                desired_range_min = theta_min
+
+            if theta_max > desired_range_max:
+                desired_range_max = desired_range_max
+            else:
+                desired_range_max = theta_max
+
+            ax.set_zlim(desired_range_min, desired_range_max)
+
             ax.view_init(elev=20, azim=-135)  # Adjust 3D perspective
 
-    plt.savefig(f"{loss_function}.png", dpi=600)
-    #plt.show()
-    print(f"Saved plot as '{loss_function}.png'.")
+            plot_filename = os.path.join(full_path, f"{loss_function}.png")
+            plt.savefig(plot_filename, dpi=300)
+            plt.close()
+            print(f"Saved plot as '{plot_filename}'.")
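Reviewer note: the refactored run_simulation now takes a (controller_file, initial_condition) tuple, which is what lets one worker pool sweep every named initial condition. A minimal smoke test of the new signature, run without the pool (a sketch, not part of the commit; the controller_dir value and the controller_0.pth checkpoint below are assumptions):

    import numpy as np
    # Hypothetical single-run check: one checkpoint, one initial condition.
    controller_dir = "training/normalized/max_normalized/quadratic/controllers"  # assumed path
    ic = (0.1 * np.pi, 0.0, 0.0, 0.0)  # the "small_perturbation" case
    epoch, theta_vals = run_simulation(("controller_0.pth", ic))
    print(epoch, len(theta_vals), theta_vals[-1])  # expect: 0, 500, final angle in rad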
analysis/controller_across_epochs_old.py (new file, 149 lines)
@ -0,0 +1,149 @@
import os
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from multiprocessing import Pool, cpu_count

# Define PendulumController class
from PendulumController import PendulumController

# ODE solver (RK4 method)
def pendulum_ode_step(state, dt, desired_theta, controller):
    theta, omega, alpha = state

    def compute_torque(th, om, al):
        inp = torch.tensor([[th, om, al, desired_theta]], dtype=torch.float32)
        with torch.no_grad():
            torque = controller(inp)
            torque = torch.clamp(torque, -250, 250)
        return float(torque)

    def derivatives(state, torque):
        th, om, al = state
        a = (g / R) * np.sin(th) + torque / (m * R**2)
        return np.array([om, a, 0])  # dtheta, domega, dalpha

    # Compute RK4 steps
    torque1 = compute_torque(theta, omega, alpha)
    k1 = dt * derivatives(state, torque1)

    state_k2 = state + 0.5 * k1
    torque2 = compute_torque(state_k2[0], state_k2[1], state_k2[2])
    k2 = dt * derivatives(state_k2, torque2)

    state_k3 = state + 0.5 * k2
    torque3 = compute_torque(state_k3[0], state_k3[1], state_k3[2])
    k3 = dt * derivatives(state_k3, torque3)

    state_k4 = state + k3
    torque4 = compute_torque(state_k4[0], state_k4[1], state_k4[2])
    k4 = dt * derivatives(state_k4, torque4)

    new_state = state + (k1 + 2*k2 + 2*k3 + k4) / 6.0
    return new_state

# Constants
g = 9.81   # Gravity
R = 1.0    # Length of the pendulum
m = 10.0   # Mass
dt = 0.02  # Time step
num_steps = 500  # Simulation time steps

# Directory containing controller files
loss_function = "quadratic"
controller_dir = f"/home/judson/Neural-Networks-in-GNC/inverted_pendulum/training/normalized/training/{loss_function}/controllers"
#controller_dir = f"C:/Users/Judson/Desktop/New Gitea/Neural-Networks-in-GNC/inverted_pendulum/training/{loss_function}/controllers"
controller_files = sorted([f for f in os.listdir(controller_dir) if f.startswith("controller_") and f.endswith(".pth")])

# Sorting controllers by epoch
controller_epochs = [int(f.split('_')[1].split('.')[0]) for f in controller_files]
sorted_controllers = [x for _, x in sorted(zip(controller_epochs, controller_files))]

# **Epoch Range Selection**
epoch_range = (0, 1000)  # Set your desired range (e.g., (0, 5000) or (0, 100))

filtered_controllers = [
    f for f in sorted_controllers
    if epoch_range[0] <= int(f.split('_')[1].split('.')[0]) <= epoch_range[1]
]

# **Granularity Control: Select every Nth controller**
N = 1  # Change this value to adjust granularity (e.g., every 5th controller)
selected_controllers = filtered_controllers[::N]  # Take every Nth controller within the range

# Initial condition
# theta0, omega0, alpha0, desired_theta = (-np.pi, -2*np.pi, 0.0, -1.3*np.pi)  # Example initial condition
theta0, omega0, alpha0, desired_theta = (-np.pi, 0.0, 0.0, 0.0)  # Example initial condition

# Parallel function must return epoch explicitly
def run_simulation(controller_file):
    epoch = int(controller_file.split('_')[1].split('.')[0])

    # Load controller
    controller = PendulumController()
    controller.load_state_dict(torch.load(os.path.join(controller_dir, controller_file)))
    controller.eval()

    # Run simulation
    state = np.array([theta0, omega0, alpha0])
    theta_vals = []

    for _ in range(num_steps):
        theta_vals.append(state[0])
        state = pendulum_ode_step(state, dt, desired_theta, controller)

    return epoch, theta_vals  # Return epoch with data

# Parallel processing
if __name__ == "__main__":
    num_workers = min(cpu_count(), 16)  # Limit to 16 workers max
    print(f"Using {num_workers} parallel workers...")

    with Pool(processes=num_workers) as pool:
        results = pool.map(run_simulation, selected_controllers)

    # Sort results by epoch to ensure correct order
    results.sort(key=lambda x: x[0])
    epochs, theta_over_epochs = zip(*results)  # Unzip sorted results

    # Convert results to NumPy arrays
    theta_over_epochs = np.array(theta_over_epochs)

    # Create 3D line plot
    fig = plt.figure(figsize=(10, 7))
    ax = fig.add_subplot(111, projection='3d')

    time_steps = np.arange(num_steps) * dt  # X-axis (time)

    # Plot each controller as a separate line
    for epoch, theta_vals in zip(epochs, theta_over_epochs):
        ax.plot(
            [epoch] * len(time_steps),  # Y-axis (epoch stays constant for each line)
            time_steps,  # X-axis (time)
            theta_vals,  # Z-axis (theta evolution)
            label=f"Epoch {epoch}" if epoch % (N * 10) == 0 else "",  # Label some lines for clarity
        )

    # Labels
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Time (s)")
    ax.set_zlabel("Theta (rad)")
    ax.set_title(f"Pendulum Angle Evolution for {loss_function}")

    # Add a horizontal line at desired_theta across all epochs and time steps
    epochs_array = np.array([epoch for epoch, _ in zip(epochs, theta_over_epochs)])
    ax.plot(
        epochs_array,  # X-axis spanning all epochs
        [time_steps.max()] * len(epochs_array),  # Y-axis at the maximum time step
        [desired_theta] * len(epochs_array),  # Constant Z-axis value of desired_theta
        color='r', linestyle='--', linewidth=2, label='Desired Theta at End Time'
    )

    # Improve visibility
    ax.view_init(elev=20, azim=-135)  # Adjust 3D perspective

    plt.savefig(f"{loss_function}.png", dpi=600)
    #plt.show()
    print(f"Saved plot as '{loss_function}.png'.")
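Reviewer note: the RK4 stepper re-queries the network for a fresh torque at each intermediate stage rather than reusing the k1 torque. A quick way to sanity-check the integrator (a sketch under stated assumptions, not part of the commit) is to drive it with a zero-torque stand-in controller; near theta = 0 the dynamics reduce to theta'' = (g/R)*theta, whose solution from rest is theta(t) = theta0*cosh(sqrt(g/R)*t):

    import numpy as np
    import torch

    class ZeroController:
        """Hypothetical test double: always returns zero torque."""
        def __call__(self, inp):
            return torch.zeros(1)

    state = np.array([0.01, 0.0, 0.0])  # tiny angle, at rest
    steps = 50
    for _ in range(steps):
        state = pendulum_ode_step(state, dt, 0.0, ZeroController())

    t = steps * dt
    # Should agree closely while the small-angle approximation holds.
    print(state[0], 0.01 * np.cosh(np.sqrt(g / R) * t))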
BIN (four images removed: previously 9.4 MiB, 8.7 MiB, 8.3 MiB, 7.5 MiB)
BIN analysis/max_normalized/extreme_perturbation/constant.png (new file, 564 KiB)
BIN analysis/max_normalized/extreme_perturbation/exponential.png (new file, 561 KiB)
BIN analysis/max_normalized/extreme_perturbation/inverse.png (new file, 524 KiB)
BIN analysis/max_normalized/extreme_perturbation/inverse_squared.png (new file, 534 KiB)
BIN analysis/max_normalized/extreme_perturbation/linear.png (new file, 546 KiB)
BIN analysis/max_normalized/extreme_perturbation/quadratic.png (new file, 611 KiB)
BIN analysis/max_normalized/large_perturbation/constant.png (new file, 733 KiB)
BIN analysis/max_normalized/large_perturbation/exponential.png (new file, 832 KiB)
BIN analysis/max_normalized/large_perturbation/inverse.png (new file, 794 KiB)
BIN analysis/max_normalized/large_perturbation/inverse_squared.png (new file, 752 KiB)
BIN analysis/max_normalized/large_perturbation/linear.png (new file, 707 KiB)
BIN analysis/max_normalized/large_perturbation/quadratic.png (new file, 814 KiB)
BIN analysis/max_normalized/overshoot_angle_test/constant.png (new file, 494 KiB)
BIN analysis/max_normalized/overshoot_angle_test/exponential.png (new file, 578 KiB)
BIN analysis/max_normalized/overshoot_angle_test/inverse.png (new file, 505 KiB)
BIN analysis/max_normalized/overshoot_angle_test/inverse_squared.png (new file, 502 KiB)
BIN analysis/max_normalized/overshoot_angle_test/linear.png (new file, 518 KiB)
BIN analysis/max_normalized/overshoot_angle_test/quadratic.png (new file, 521 KiB)
BIN analysis/max_normalized/overshoot_vertical_test/constant.png (new file, 517 KiB)
BIN analysis/max_normalized/overshoot_vertical_test/exponential.png (new file, 538 KiB)
BIN analysis/max_normalized/overshoot_vertical_test/inverse.png (new file, 508 KiB)
BIN analysis/max_normalized/overshoot_vertical_test/inverse_squared.png (new file, 526 KiB)
BIN analysis/max_normalized/overshoot_vertical_test/linear.png (new file, 514 KiB)
BIN analysis/max_normalized/overshoot_vertical_test/quadratic.png (new file, 559 KiB)
BIN analysis/max_normalized/small_perturbation/constant.png (new file, 513 KiB)
BIN analysis/max_normalized/small_perturbation/exponential.png (new file, 474 KiB)
BIN analysis/max_normalized/small_perturbation/inverse.png (new file, 489 KiB)
BIN analysis/max_normalized/small_perturbation/inverse_squared.png (new file, 498 KiB)
BIN analysis/max_normalized/small_perturbation/linear.png (new file, 497 KiB)
BIN analysis/max_normalized/small_perturbation/quadratic.png (new file, 503 KiB)
BIN (one image removed: previously 9.8 MiB)
training/normalized/PendulumController.py (new file, 17 lines)
@ -0,0 +1,17 @@
import torch
import torch.nn as nn

class PendulumController(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(4, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        raw_torque = self.net(x)
        return torch.clamp(raw_torque, -250, 250)
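Reviewer note: a quick usage sketch (illustrative, not from the commit). The network maps a [theta, omega, alpha, desired_theta] row to a single torque, hard-clamped to [-250, 250]:

    import torch

    controller = PendulumController()
    state = torch.tensor([[0.1, 0.0, 0.0, 0.0]])  # batch of one: [theta, omega, alpha, desired_theta]
    torque = controller(state)                    # shape [1, 1], guaranteed within [-250, 250]
    print(torque.item())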
training/normalized/PendulumDynamics.py (new file, 26 lines)
@ -0,0 +1,26 @@
import torch
import torch.nn as nn

class PendulumDynamics(nn.Module):
    def __init__(self, controller, m: 'float' = 1, R: 'float' = 1, g: 'float' = 9.81):
        super().__init__()
        self.controller = controller
        self.m: 'float' = m
        self.R: 'float' = R
        self.g: 'float' = g

    def forward(self, t, state):
        # Get the current values from the state
        theta, omega, alpha, desired_theta = state[:, 0], state[:, 1], state[:, 2], state[:, 3]

        # Make the input stack for the controller
        input = torch.stack([theta, omega, alpha, desired_theta], dim=1)

        # Get the torque (the output of the neural network)
        tau = self.controller(input).squeeze(-1)

        # Relax alpha
        alpha_desired = (self.g / self.R) * torch.sin(theta) + tau / (self.m * self.R**2)
        dalpha = alpha_desired - alpha

        return torch.stack([omega, alpha, dalpha, torch.zeros_like(desired_theta)], dim=1)
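Reviewer note: the alpha handling is worth flagging. Instead of feeding the controller an algebraic acceleration, alpha is carried as a state variable that relaxes toward the physical value (g/R)*sin(theta) + tau/(m*R^2), which keeps the system a plain first-order ODE in four states. Rolling it out with torchdiffeq, as the trainer at the bottom of this commit does (values here illustrative only):

    import torch
    from torchdiffeq import odeint

    dynamics = PendulumDynamics(PendulumController(), m=10.0, R=1.0, g=9.81)
    # One case: hanging near straight down, target upright.
    state0 = torch.tensor([[3.14159, 0.0, 0.0, 0.0]])  # [theta, omega, alpha, desired_theta]
    t_span = torch.linspace(0.0, 10.0, 1000)
    traj = odeint(dynamics, state0, t_span)            # shape [1000, 1, 4], differentiable
    print(traj[-1, 0, 0])                              # final theta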
BIN training/normalized/__pycache__/PendulumDynamics.cpython-310.pyc (new file)
@ -0,0 +1,41 @@
Random Seed: 4529
Time Span: 0 to 10, Points: 1000
Learning Rate: 0.1
Weight Decay: 0.0001

Loss Function:
def loss_fn(state_traj, t_span):
    theta = state_traj[:, :, 0]          # Size: [batch_size, t_points]
    desired_theta = state_traj[:, :, 3]  # Size: [batch_size, t_points]
    weights = weight_fn(t_span)          # Initially Size: [t_points]

    # Reshape or expand weights to match theta dimensions
    weights = weights.view(-1, 1)        # Now Size: [batch_size, t_points]

    # Calculate the weighted loss
    return torch.mean(weights * (theta - desired_theta) ** 2)

Training Cases:
[theta0, omega0, alpha0, desired_theta]
[0.5235987901687622, 0.0, 0.0, 0.0]
[-0.5235987901687622, 0.0, 0.0, 0.0]
[2.094395160675049, 0.0, 0.0, 0.0]
[-2.094395160675049, 0.0, 0.0, 0.0]
[0.0, 1.0471975803375244, 0.0, 0.0]
[0.0, -1.0471975803375244, 0.0, 0.0]
[0.0, 6.2831854820251465, 0.0, 0.0]
[0.0, -6.2831854820251465, 0.0, 0.0]
[0.0, 0.0, 0.0, 6.2831854820251465]
[0.0, 0.0, 0.0, -6.2831854820251465]
[0.0, 0.0, 0.0, 1.5707963705062866]
[0.0, 0.0, 0.0, -1.5707963705062866]
[0.0, 0.0, 0.0, 1.0471975803375244]
[0.0, 0.0, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 0.0]
[-0.7853981852531433, -3.1415927410125732, 0.0, 0.0]
[1.5707963705062866, -3.1415927410125732, 0.0, 1.0471975803375244]
[-1.5707963705062866, 3.1415927410125732, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 6.2831854820251465]
[-0.7853981852531433, -3.1415927410125732, 0.0, 6.2831854820251465]
[1.5707963705062866, -3.1415927410125732, 0.0, 12.566370964050293]
[-1.5707963705062866, 3.1415927410125732, 0.0, -12.566370964050293]
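Reviewer note: the weight_fn here is one of the "average normalized" profiles defined in average_normalized_trainer.py at the bottom of this commit. Each profile is divided by its own mean so every loss variant applies an average weight of 1 over the trajectory; only the shape of the time weighting differs. A small illustration (not from the log):

    import torch

    t = torch.linspace(0, 10, 5)
    w = (t / t.max()) / (t / t.max()).mean()  # the 'linear' profile
    print(w)         # tensor([0.0000, 0.5000, 1.0000, 1.5000, 2.0000])
    print(w.mean())  # tensor(1.)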
training/normalized/average_normalized/constant/training_log.csv (new file, 1002 lines)
@ -0,0 +1,41 @@
(41-line header identical to the log shown above: same random seed, hyperparameters, loss function, and training cases)
training/normalized/average_normalized/exponential/training_log.csv (new file, 1002 lines)
@ -0,0 +1,41 @@
(41-line header identical to the log shown above: same random seed, hyperparameters, loss function, and training cases)
training/normalized/average_normalized/inverse/training_log.csv (new file, 1002 lines)
@ -0,0 +1,41 @@
(41-line header identical to the log shown above: same random seed, hyperparameters, loss function, and training cases)
@ -0,0 +1,608 @@
Epoch,Loss
0,733.2171630859375
1,747.2186279296875
2,173.32571411132812
3,113.02052307128906
4,52.3023796081543
5,29.964460372924805
6,22.173477172851562
7,17.20752716064453
8,14.393138885498047
9,13.495450019836426
10,12.72439193725586
11,12.004809379577637
12,11.283649444580078
13,10.60710620880127
14,9.974536895751953
15,9.39864730834961
16,8.96867847442627
17,8.593585968017578
18,8.431589126586914
19,8.36359977722168
20,8.257978439331055
21,8.131260871887207
22,7.987368583679199
23,7.828018665313721
24,7.654465198516846
25,7.468804836273193
26,7.282363414764404
27,7.1038055419921875
28,6.926774024963379
29,6.756581783294678
30,6.605012893676758
31,6.480014324188232
32,6.387184143066406
33,6.323246955871582
34,6.280386924743652
35,6.2463226318359375
36,6.210468769073486
37,6.165651798248291
38,6.10935640335083
39,6.044148921966553
40,5.975587844848633
41,5.909937381744385
42,5.85064697265625
43,5.797428131103516
44,5.749576091766357
45,5.7068305015563965
46,5.668132305145264
47,5.631911754608154
48,5.5969085693359375
49,5.562211513519287
50,5.527288436889648
51,5.491824626922607
52,5.455735206604004
53,5.4193949699401855
54,5.383573055267334
55,5.350064754486084
56,5.319507598876953
57,5.289746284484863
58,5.259618759155273
59,5.228979110717773
60,5.198094844818115
61,5.167768955230713
62,5.1394476890563965
63,5.113860607147217
64,5.0905303955078125
65,5.070118427276611
66,5.05070686340332
67,5.031530857086182
68,5.012385845184326
69,4.993386268615723
70,4.974751949310303
71,4.956557273864746
72,4.939281940460205
73,4.923177719116211
74,4.9080328941345215
75,4.893461227416992
76,4.879059791564941
77,4.86448860168457
78,4.8493499755859375
79,4.833917617797852
80,4.818392753601074
81,4.80300760269165
82,4.788031101226807
83,4.773463249206543
84,4.759511947631836
85,4.746511459350586
86,4.734636306762695
87,4.72379207611084
88,4.713471412658691
89,4.7037153244018555
90,4.694935321807861
91,4.687223434448242
92,4.680534362792969
93,4.674868106842041
94,4.670145511627197
95,4.666247367858887
96,4.6630682945251465
97,4.660345554351807
98,4.6577630043029785
99,4.655178070068359
100,4.65253210067749
101,4.649806499481201
102,4.64708137512207
103,4.644402980804443
104,4.641839504241943
105,4.639404296875
106,4.637064456939697
107,4.634809970855713
108,4.632608890533447
109,4.63045597076416
110,4.6283416748046875
111,4.626272201538086
112,4.624259948730469
113,4.622305870056152
114,4.620428085327148
115,4.618651866912842
116,4.616915702819824
117,4.615218162536621
118,4.613545894622803
119,4.6118950843811035
120,4.610260963439941
121,4.608641147613525
122,4.607039928436279
123,4.6054606437683105
124,4.603922367095947
125,4.602435111999512
126,4.601007461547852
127,4.599658966064453
128,4.598391532897949
129,4.59716796875
130,4.595986843109131
131,4.594844341278076
132,4.593727111816406
133,4.592639446258545
134,4.591578960418701
135,4.590541839599609
136,4.589527130126953
137,4.588528633117676
138,4.587544918060303
139,4.586572647094727
140,4.585610389709473
141,4.58465576171875
142,4.583714008331299
143,4.582784652709961
144,4.58186149597168
145,4.580949783325195
146,4.58004903793335
147,4.579157829284668
148,4.578279972076416
149,4.5774149894714355
150,4.5765604972839355
151,4.575720310211182
152,4.574892044067383
153,4.574074745178223
154,4.573271751403809
155,4.57248067855835
156,4.571700096130371
157,4.570932388305664
158,4.570174694061279
159,4.569428443908691
160,4.568695545196533
161,4.567972183227539
162,4.567262649536133
163,4.566561698913574
164,4.565873622894287
165,4.565197944641113
166,4.564536094665527
167,4.563884258270264
168,4.563241958618164
169,4.562610626220703
170,4.561987400054932
171,4.561375141143799
172,4.56077241897583
173,4.560176849365234
174,4.559585094451904
175,4.558999061584473
176,4.558421611785889
177,4.557851791381836
178,4.55728816986084
179,4.556730270385742
180,4.556179523468018
181,4.555635452270508
182,4.555098533630371
183,4.554568767547607
184,4.554043769836426
185,4.553525447845459
186,4.553014755249023
187,4.552509307861328
188,4.552009105682373
189,4.551515102386475
190,4.551027297973633
191,4.550544738769531
192,4.55006742477417
193,4.549595832824707
194,4.549131393432617
195,4.5486741065979
196,4.548222064971924
197,4.547779083251953
198,4.547341823577881
199,4.546910762786865
200,4.546485424041748
201,4.546065330505371
202,4.545650959014893
203,4.545242786407471
204,4.544842720031738
205,4.5444488525390625
206,4.544060230255127
207,4.54367733001709
208,4.543309211730957
209,4.5429511070251465
210,4.542597770690918
211,4.542244911193848
212,4.541895389556885
213,4.541547775268555
214,4.541203498840332
215,4.540863513946533
216,4.540529727935791
217,4.5402045249938965
218,4.539884090423584
219,4.539568901062012
220,4.539254665374756
221,4.538944721221924
222,4.538638591766357
223,4.538336277008057
224,4.538040637969971
225,4.537751197814941
226,4.537464618682861
227,4.537179470062256
228,4.536898612976074
229,4.536619186401367
230,4.536343574523926
231,4.536070823669434
232,4.535802364349365
233,4.535538196563721
234,4.535276889801025
235,4.535017967224121
236,4.534761428833008
237,4.534510135650635
238,4.5342607498168945
239,4.534013271331787
240,4.533768653869629
241,4.533526420593262
242,4.533287525177002
243,4.533048629760742
244,4.532813549041748
245,4.532580375671387
246,4.532350063323975
247,4.5321245193481445
248,4.531898021697998
249,4.531674861907959
250,4.531455039978027
251,4.531236171722412
252,4.531019687652588
253,4.530806064605713
254,4.530594348907471
255,4.530384540557861
256,4.530177593231201
257,4.529972076416016
258,4.529767990112305
259,4.529566287994385
260,4.529365539550781
261,4.529166221618652
262,4.5289692878723145
263,4.528773784637451
264,4.5285797119140625
265,4.52838659286499
266,4.528195858001709
267,4.528006553649902
268,4.52781867980957
269,4.527633190155029
270,4.527448654174805
271,4.527266502380371
272,4.527083396911621
273,4.52690315246582
274,4.526723861694336
275,4.526546001434326
276,4.526369571685791
277,4.526192665100098
278,4.526019096374512
279,4.525848388671875
280,4.525679588317871
281,4.5255126953125
282,4.5253472328186035
283,4.525182723999023
284,4.525018692016602
285,4.524855613708496
286,4.524693489074707
287,4.524532794952393
288,4.524373531341553
289,4.524214267730713
290,4.5240559577941895
291,4.523898601531982
292,4.523742198944092
293,4.523586273193359
294,4.523431301116943
295,4.52327823638916
296,4.523126602172852
297,4.522974491119385
298,4.522823333740234
299,4.5226731300354
300,4.522523403167725
301,4.522375106811523
302,4.522228240966797
303,4.522083282470703
304,4.521938800811768
305,4.521796703338623
306,4.521655559539795
307,4.521514415740967
308,4.521374225616455
309,4.521234035491943
310,4.521094799041748
311,4.520956993103027
312,4.520818710327148
313,4.520682334899902
314,4.520545959472656
315,4.520410060882568
316,4.520275115966797
317,4.5201416015625
318,4.520009517669678
319,4.519878387451172
320,4.519747734069824
321,4.519618034362793
322,4.51948881149292
323,4.519359588623047
324,4.519231796264648
325,4.519104957580566
326,4.518978118896484
327,4.518850803375244
328,4.5187249183654785
329,4.518599510192871
330,4.518474578857422
331,4.518350601196289
332,4.518225193023682
333,4.518101215362549
334,4.517977237701416
335,4.517852783203125
336,4.51772928237915
337,4.517606258392334
338,4.517481803894043
339,4.517359733581543
340,4.517236709594727
341,4.517114162445068
342,4.51699161529541
343,4.516868591308594
344,4.516744613647461
345,4.516623020172119
346,4.516499996185303
347,4.5163774490356445
348,4.516254901885986
349,4.5161333084106445
350,4.5160112380981445
351,4.5158891677856445
352,4.515769004821777
353,4.5156474113464355
354,4.51552677154541
355,4.515407085418701
356,4.515288829803467
357,4.515170574188232
358,4.515053749084473
359,4.5149383544921875
360,4.514822483062744
361,4.514708042144775
362,4.5145955085754395
363,4.514482498168945
364,4.514371395111084
365,4.5142598152160645
366,4.5141496658325195
367,4.514039993286133
368,4.513929843902588
369,4.513820171356201
370,4.513709545135498
371,4.513599872589111
372,4.51348876953125
373,4.5133771896362305
374,4.513265132904053
375,4.513152122497559
376,4.51303768157959
377,4.512922763824463
378,4.512807846069336
379,4.512691974639893
380,4.512577056884766
381,4.512462139129639
382,4.512347221374512
383,4.512232780456543
384,4.512118339538574
385,4.5120038986206055
386,4.511888027191162
387,4.511774063110352
388,4.511661052703857
389,4.5115485191345215
390,4.511435508728027
391,4.511322975158691
392,4.5112104415893555
393,4.511097431182861
394,4.510985851287842
395,4.510873317718506
396,4.510761260986328
397,4.510649681091309
398,4.510537624359131
399,4.510426998138428
400,4.510315895080566
401,4.510205268859863
402,4.510096073150635
403,4.509986400604248
404,4.5098772048950195
405,4.509767532348633
406,4.509658336639404
407,4.509550094604492
408,4.509440898895264
409,4.509332180023193
410,4.509223937988281
411,4.509115219116211
412,4.509006500244141
413,4.508897304534912
414,4.508789539337158
415,4.508681774139404
416,4.508574485778809
417,4.508468151092529
418,4.508361339569092
419,4.5082550048828125
420,4.508149147033691
421,4.5080437660217285
422,4.507938385009766
423,4.5078325271606445
424,4.507728099822998
425,4.507623672485352
426,4.507519245147705
427,4.507416725158691
428,4.5073137283325195
429,4.507210731506348
430,4.50710916519165
431,4.507007122039795
432,4.506906032562256
433,4.506804466247559
434,4.506704330444336
435,4.50660514831543
436,4.506505966186523
437,4.506407260894775
438,4.5063090324401855
439,4.50621223449707
440,4.506114959716797
441,4.50601863861084
442,4.505922794342041
443,4.5058274269104
444,4.50573205947876
445,4.505638122558594
446,4.505544185638428
447,4.505451679229736
448,4.505359172821045
449,4.505267143249512
450,4.505176067352295
451,4.505085468292236
452,4.504995822906494
453,4.504907131195068
454,4.504818439483643
455,4.504730701446533
456,4.504644393920898
457,4.504558086395264
458,4.504471778869629
459,4.504387378692627
460,4.504302024841309
461,4.5042195320129395
462,4.504136085510254
463,4.504052639007568
464,4.503970623016357
465,4.503889083862305
466,4.503807544708252
467,4.503726005554199
468,4.503644943237305
469,4.503564357757568
470,4.503483295440674
471,4.503403186798096
472,4.503323078155518
473,4.503242015838623
474,4.5031633377075195
475,4.5030837059021
476,4.50300407409668
477,4.502924919128418
478,4.502845764160156
479,4.5027666091918945
480,4.502687931060791
481,4.5026092529296875
482,4.502531051635742
483,4.502452373504639
484,4.502374172210693
485,4.502295970916748
486,4.502218246459961
487,4.502140045166016
488,4.50206184387207
489,4.501983642578125
490,4.501905918121338
491,4.501828193664551
492,4.5017499923706055
493,4.501672744750977
494,4.5015950202941895
495,4.501518249511719
496,4.501441955566406
497,4.501365661621094
498,4.5012898445129395
499,4.501214981079102
500,4.501140117645264
501,4.501065254211426
502,4.500990867614746
503,4.500916004180908
504,4.5008416175842285
505,4.500767707824707
506,4.5006937980651855
507,4.500619411468506
508,4.500545978546143
509,4.500472068786621
510,4.5003981590271
511,4.500324249267578
512,4.500250816345215
513,4.500176906585693
514,4.50010347366333
515,4.500030517578125
516,4.4999566078186035
517,4.499882698059082
518,4.4998087882995605
519,4.499735355377197
520,4.499661445617676
521,4.499587535858154
522,4.499513149261475
523,4.499438762664795
524,4.499364852905273
525,4.499290466308594
526,4.499216079711914
527,4.499140739440918
528,4.499065399169922
529,4.498990058898926
530,4.498913764953613
531,4.498837947845459
532,4.498762607574463
533,4.49868631362915
534,4.4986114501953125
535,4.498536109924316
536,4.498461723327637
537,4.498386859893799
538,4.498311996459961
539,4.4982380867004395
540,4.498164653778076
541,4.498090744018555
542,4.498019218444824
543,4.497945308685303
544,4.497872829437256
545,4.497801303863525
546,4.497729301452637
547,4.49765682220459
548,4.497586250305176
549,4.49751615524292
550,4.497445106506348
551,4.49737548828125
552,4.497305870056152
553,4.49723482131958
554,4.497164726257324
555,4.497094631195068
556,4.497025012969971
557,4.496954441070557
558,4.496883392333984
559,4.496815204620361
560,4.4967451095581055
561,4.496673107147217
562,4.496603488922119
563,4.4965338706970215
564,4.496462821960449
565,4.496394157409668
566,4.496323585510254
567,4.496251583099365
568,4.496183395385742
569,4.49611234664917
570,4.496041774749756
571,4.4959716796875
572,4.495903491973877
573,4.4958343505859375
574,4.49576473236084
575,4.495695114135742
576,4.4956278800964355
577,4.4955573081970215
578,4.495490074157715
579,4.495420932769775
580,4.495352745056152
581,4.495282173156738
582,4.495216369628906
583,4.495147228240967
584,4.495075225830078
585,4.4950079917907715
586,4.494940757751465
587,4.494872093200684
588,4.494802951812744
589,4.494732856750488
590,4.494668483734131
591,4.494600296020508
592,4.494530200958252
593,4.494461536407471
594,4.494396209716797
595,4.49432897567749
596,4.494260787963867
597,4.494192123413086
598,4.4941229820251465
599,4.494059085845947
600,4.493993759155273
601,4.493923664093018
602,4.493853569030762
603,4.493789196014404
604,4.493722915649414
605,4.493655681610107
606,4.493587970733643
@ -0,0 +1,41 @@
(41-line header identical to the log shown above: same random seed, hyperparameters, loss function, and training cases)
training/normalized/average_normalized/linear/training_log.csv (new file, 1002 lines)
@ -0,0 +1,41 @@
(41-line header identical to the log shown above: same random seed, hyperparameters, loss function, and training cases)
training/normalized/average_normalized/quadratic/training_log.csv (new file, 1002 lines)
155
training/normalized/average_normalized_trainer.py
Normal file
@ -0,0 +1,155 @@
import torch
import torch.optim as optim
from torchdiffeq import odeint
import numpy as np
import os
import shutil
import csv
import inspect

from PendulumController import PendulumController
from PendulumDynamics import PendulumDynamics

# Device setup
device = torch.device("cpu")

# Initial conditions (theta0, omega0, alpha0, desired_theta)
from initial_conditions import initial_conditions
state_0 = torch.tensor(initial_conditions, dtype=torch.float32, device=device)

# Constants
m = 10.0
g = 9.81
R = 1.0

# Time grid
t_start, t_end, t_points = 0, 10, 1000
t_span = torch.linspace(t_start, t_end, t_points, device=device)

# Specify directory for storing results
output_dir = "average_normalized"
os.makedirs(output_dir, exist_ok=True)

# Use a previously generated random seed
random_seed = 4529

# Set the seeds for reproducibility
torch.manual_seed(random_seed)
np.random.seed(random_seed)

# Print the chosen random seed
print(f"Random seed for torch and numpy: {random_seed}")

# Optimizer hyperparameters (controller, dynamics, and optimizer are
# re-initialized per weight function inside the loop below)
learning_rate = 1e-1
weight_decay = 1e-4

# Training parameters
num_epochs = 1001

# Define loss functions
def make_loss_fn(weight_fn):
    def loss_fn(state_traj, t_span):
        theta = state_traj[:, :, 0]  # Size: [t_points, batch_size]
        desired_theta = state_traj[:, :, 3]  # Size: [t_points, batch_size]
        weights = weight_fn(t_span)  # Initially Size: [t_points]

        # Reshape weights so they broadcast across the batch dimension
        weights = weights.view(-1, 1)  # Now Size: [t_points, 1]

        # Calculate the weighted loss
        return torch.mean(weights * (theta - desired_theta) ** 2)

    return loss_fn

# Define and store weight functions with descriptions, normalized by average weight
weight_functions = {
    'constant': {
        'function': lambda t: torch.ones_like(t) / torch.ones_like(t).mean(),
        'description': 'Constant weight: All weights are 1, normalized by the average (remains 1)'
    },
    'linear': {
        'function': lambda t: (t / t.max()) / (t / t.max()).mean(),
        'description': 'Linear weight: Weights increase linearly from 0 to 1, normalized by the average weight'
    },
    'quadratic': {
        'function': lambda t: ((t / t.max()) ** 2) / ((t / t.max()) ** 2).mean(),
        'description': 'Quadratic weight: Weights increase quadratically from 0 to 1, normalized by the average weight'
    },
    'exponential': {
        'function': lambda t: (torch.exp(t / t.max() * 2)) / (torch.exp(t / t.max() * 2)).mean(),
        'description': 'Exponential weight: Weights increase exponentially, normalized by the average weight'
    },
    'inverse': {
        'function': lambda t: (1 / (t / t.max() + 1)) / (1 / (t / t.max() + 1)).mean(),
        'description': 'Inverse weight: Weights decrease inversely, normalized by the average weight'
    },
    'inverse_squared': {
        'function': lambda t: (1 / ((t / t.max() + 1) ** 2)) / (1 / ((t / t.max() + 1) ** 2)).mean(),
        'description': 'Inverse squared weight: Weights decrease inversely squared, normalized by the average weight'
    }
}

# Training loop for each weight function
for name, weight_info in weight_functions.items():
    controller = PendulumController().to(device)
    pendulum_dynamics = PendulumDynamics(controller, m, R, g).to(device)
    optimizer = optim.Adam(controller.parameters(), lr=learning_rate, weight_decay=weight_decay)
    loss_fn = make_loss_fn(weight_info['function'])

    # File paths
    function_output_dir = os.path.join(output_dir, name)
    controllers_dir = os.path.join(function_output_dir, "controllers")

    # Check if controllers directory exists and remove it
    if os.path.exists(controllers_dir):
        shutil.rmtree(controllers_dir)
    os.makedirs(controllers_dir, exist_ok=True)

    config_file = os.path.join(function_output_dir, "training_config.txt")
    log_file = os.path.join(function_output_dir, "training_log.csv")

    # Overwrite configuration and log files
    with open(config_file, "w") as f:
        f.write(f"Random Seed: {random_seed}\n")
        f.write(f"Time Span: {t_start} to {t_end}, Points: {t_points}\n")
        f.write(f"Learning Rate: {learning_rate}\n")
        f.write(f"Weight Decay: {weight_decay}\n")
        f.write("\nLoss Function:\n")
        f.write(inspect.getsource(loss_fn))
        f.write("\nTraining Cases:\n")
        f.write("[theta0, omega0, alpha0, desired_theta]\n")
        for case in state_0.cpu().numpy():
            f.write(f"{case.tolist()}\n")

    with open(log_file, "w", newline="") as csvfile:
        csv_writer = csv.writer(csvfile)
        csv_writer.writerow(["Epoch", "Loss"])

    # Training loop
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        state_traj = odeint(pendulum_dynamics, state_0, t_span, method='rk4')
        loss = loss_fn(state_traj, t_span)
        loss.backward()
        optimizer.step()

        # Logging
        with open(log_file, "a", newline="") as csvfile:
            csv_writer = csv.writer(csvfile)
            csv_writer.writerow([epoch, loss.item()])

        # Save the model
        model_file = os.path.join(controllers_dir, f"controller_{epoch}.pth")
        torch.save(controller.state_dict(), model_file)
        print(f"{model_file} saved with loss: {loss.item()}")

print("Training complete. Models and logs are saved under respective directories for each loss function.")
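Every profile in weight_functions above is divided by its own mean over the time grid, so each loss sees an average weight of exactly 1 and only the shape of the temporal emphasis differs between runs. A quick sanity check of that property (a sketch; assumes the weight_functions dict from the script above is in scope):

import torch

t = torch.linspace(0, 10, 1000)
for name, info in weight_functions.items():
    w = info['function'](t)
    # mean should be ~1.0 for every average-normalized profile
    print(f"{name:16s} mean={w.mean().item():.6f} max={w.max().item():.4f}")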
27
training/normalized/initial_conditions.py
Normal file
@@ -0,0 +1,27 @@
import torch
from torch import pi

initial_conditions = [
    [1/6 * pi, 0.0, 0.0, 0.0],
    [-1/6 * pi, 0.0, 0.0, 0.0],
    [2/3 * pi, 0.0, 0.0, 0.0],
    [-2/3 * pi, 0.0, 0.0, 0.0],
    [0.0, 1/3 * pi, 0.0, 0.0],
    [0.0, -1/3 * pi, 0.0, 0.0],
    [0.0, 2 * pi, 0.0, 0.0],
    [0.0, -2 * pi, 0.0, 0.0],
    [0.0, 0.0, 0.0, 2 * pi],
    [0.0, 0.0, 0.0, -2 * pi],
    [0.0, 0.0, 0.0, 1/2 * pi],
    [0.0, 0.0, 0.0, -1/2 * pi],
    [0.0, 0.0, 0.0, 1/3 * pi],
    [0.0, 0.0, 0.0, -1/3 * pi],
    [1/4 * pi, 1 * pi, 0.0, 0.0],
    [-1/4 * pi, -1 * pi, 0.0, 0.0],
    [1/2 * pi, -1 * pi, 0.0, 1/3 * pi],
    [-1/2 * pi, 1 * pi, 0.0, -1/3 * pi],
    [1/4 * pi, 1 * pi, 0.0, 2 * pi],
    [-1/4 * pi, -1 * pi, 0.0, 2 * pi],
    [1/2 * pi, -1 * pi, 0.0, 4 * pi],
    [-1/2 * pi, 1 * pi, 0.0, -4 * pi],
]
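These π-fraction cases are the same values echoed into every training_config.txt above, after float32 rounding (e.g. 1/6·π → 0.5235987901687622 and 4·π → 12.566370964050293). A one-line round-trip check (a sketch, run from the training directory):

import torch
from initial_conditions import initial_conditions

# float32 conversion reproduces the radian values logged in training_config.txt
print(torch.tensor(initial_conditions, dtype=torch.float32)[0].tolist())
# -> [0.5235987901687622, 0.0, 0.0, 0.0]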
@@ -0,0 +1,41 @@
(a 41-line training_config.txt identical to the average_normalized one shown above: same seed, hyperparameters, loss function, and training cases)
1001
training/normalized/max_normalized/constant/training_log.csv
Normal file
@@ -0,0 +1,41 @@
(training_config.txt identical to the one shown above)
1001
training/normalized/max_normalized/exponential/training_log.csv
Normal file
@@ -0,0 +1,41 @@
(training_config.txt identical to the one shown above)
1001
training/normalized/max_normalized/inverse/training_log.csv
Normal file
@@ -0,0 +1,41 @@
(training_config.txt identical to the one shown above)
1001
training/normalized/max_normalized/inverse_squared/training_log.csv
Normal file
@@ -0,0 +1,41 @@
(training_config.txt identical to the one shown above)
1001
training/normalized/max_normalized/linear/training_log.csv
Normal file
@@ -0,0 +1,41 @@
(training_config.txt identical to the one shown above)
1001
training/normalized/max_normalized/quadratic/training_log.csv
Normal file
155
training/normalized/max_normalized_trainer.py
Normal file
@@ -0,0 +1,155 @@
import torch
import torch.optim as optim
from torchdiffeq import odeint
import numpy as np
import os
import shutil
import csv
import inspect

from PendulumController import PendulumController
from PendulumDynamics import PendulumDynamics

# Device setup
device = torch.device("cpu")

# Initial conditions (theta0, omega0, alpha0, desired_theta)
from initial_conditions import initial_conditions
state_0 = torch.tensor(initial_conditions, dtype=torch.float32, device=device)

# Constants
m = 10.0
g = 9.81
R = 1.0

# Time grid
t_start, t_end, t_points = 0, 10, 1000
t_span = torch.linspace(t_start, t_end, t_points, device=device)

# Specify directory for storing results
output_dir = "training"
os.makedirs(output_dir, exist_ok=True)

# Use a previously generated random seed
random_seed = 4529

# Set the seeds for reproducibility
torch.manual_seed(random_seed)
np.random.seed(random_seed)

# Print the chosen random seed
print(f"Random seed for torch and numpy: {random_seed}")

# Optimizer hyperparameters (controller, dynamics, and optimizer are
# re-initialized per weight function inside the loop below)
learning_rate = 1e-1
weight_decay = 1e-4

# Training parameters
num_epochs = 1000

# Define loss functions
def make_loss_fn(weight_fn):
    def loss_fn(state_traj, t_span):
        theta = state_traj[:, :, 0]  # Size: [t_points, batch_size]
        desired_theta = state_traj[:, :, 3]  # Size: [t_points, batch_size]
        weights = weight_fn(t_span)  # Initially Size: [t_points]

        # Reshape weights so they broadcast across the batch dimension
        weights = weights.view(-1, 1)  # Now Size: [t_points, 1]

        # Calculate the weighted loss
        return torch.mean(weights * (theta - desired_theta) ** 2)

    return loss_fn

# Define and store weight functions with descriptions
weight_functions = {
    'constant': {
        'function': lambda t: torch.ones_like(t),
        'description': 'Constant weight: All weights are 1'
    },
    'linear': {
        'function': lambda t: (t / t.max()) / (t / t.max()).max(),
        'description': 'Linear weight: Weights increase linearly from 0 to 1, normalized by max'
    },
    'quadratic': {
        'function': lambda t: ((t / t.max()) ** 2) / ((t / t.max()) ** 2).max(),
        'description': 'Quadratic weight: Weights increase quadratically from 0 to 1, normalized by max'
    },
    'exponential': {
        'function': lambda t: (torch.exp(t / t.max() * 2)) / (torch.exp(t / t.max() * 2)).max(),
        'description': 'Exponential weight: Weights increase exponentially, normalized by max'
    },
    'inverse': {
        'function': lambda t: (1 / (t / t.max() + 1)) / (1 / (t / t.max() + 1)).max(),
        'description': 'Inverse weight: Weights decrease inversely, normalized by max'
    },
    'inverse_squared': {
        'function': lambda t: (1 / ((t / t.max() + 1) ** 2)) / (1 / ((t / t.max() + 1) ** 2)).max(),
        'description': 'Inverse squared weight: Weights decrease inversely squared, normalized by max'
    }
}

# Training loop for each weight function
for name, weight_info in weight_functions.items():
    controller = PendulumController().to(device)
    pendulum_dynamics = PendulumDynamics(controller, m, R, g).to(device)
    optimizer = optim.Adam(controller.parameters(), lr=learning_rate, weight_decay=weight_decay)
    loss_fn = make_loss_fn(weight_info['function'])

    # File paths
    function_output_dir = os.path.join(output_dir, name)
    controllers_dir = os.path.join(function_output_dir, "controllers")

    # Check if controllers directory exists and remove it
    if os.path.exists(controllers_dir):
        shutil.rmtree(controllers_dir)
    os.makedirs(controllers_dir, exist_ok=True)

    config_file = os.path.join(function_output_dir, "training_config.txt")
    log_file = os.path.join(function_output_dir, "training_log.csv")

    # Overwrite configuration and log files
    with open(config_file, "w") as f:
        f.write(f"Random Seed: {random_seed}\n")
        f.write(f"Time Span: {t_start} to {t_end}, Points: {t_points}\n")
        f.write(f"Learning Rate: {learning_rate}\n")
        f.write(f"Weight Decay: {weight_decay}\n")
        f.write("\nLoss Function:\n")
        f.write(inspect.getsource(loss_fn))
        f.write("\nTraining Cases:\n")
        f.write("[theta0, omega0, alpha0, desired_theta]\n")
        for case in state_0.cpu().numpy():
            f.write(f"{case.tolist()}\n")

    with open(log_file, "w", newline="") as csvfile:
        csv_writer = csv.writer(csvfile)
        csv_writer.writerow(["Epoch", "Loss"])

    # Training loop
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        state_traj = odeint(pendulum_dynamics, state_0, t_span, method='rk4')
        loss = loss_fn(state_traj, t_span)
        loss.backward()
        optimizer.step()

        # Logging
        with open(log_file, "a", newline="") as csvfile:
            csv_writer = csv.writer(csvfile)
            csv_writer.writerow([epoch, loss.item()])

        # Save the model
        model_file = os.path.join(controllers_dir, f"controller_{epoch}.pth")
        torch.save(controller.state_dict(), model_file)
        print(f"{model_file} saved with loss: {loss.item()}")

print("Training complete. Models and logs are saved under respective directories for each loss function.")
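The only substantive difference from average_normalized_trainer.py is the scaling: here each weight profile is divided by its maximum, so the peak weight is 1 but the mean (and hence the raw loss magnitude) varies between profiles, whereas the average-normalized trainer pins the mean at 1 instead. An illustrative comparison for the linear profile (a sketch, not part of the repo):

import torch

t = torch.linspace(0, 10, 1000)
w = t / t.max()        # raw linear ramp

w_max = w / w.max()    # max-normalized: peak 1, mean ~0.5
w_avg = w / w.mean()   # average-normalized: mean 1, peak ~2

print(w_max.max().item(), w_max.mean().item())
print(w_avg.max().item(), w_avg.mean().item())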