Plotted max-normalized controller behavior across epochs. Also added average-normalized training runs and logs.
BIN  analysis/average_normalized/IC_-3.14_0.0_0.0_0.0/constant.png     (new file, 972 KiB)
BIN  analysis/average_normalized/IC_-3.14_0.0_0.0_0.0/exponential.png  (new file, 2.2 MiB)
BIN  analysis/average_normalized/IC_-3.14_0.0_0.0_0.0/linear.png       (new file, 945 KiB)
BIN  analysis/average_normalized/IC_-3.14_0.0_0.0_0.0/quadratic.png    (new file, 2.5 MiB)
BIN  [two deleted images, 4.8 MiB and 8.0 MiB; their paths were not captured in this view]
analysis/controller_across_epochs.py  (modified; path inferred — the previous version is preserved below as controller_across_epochs_old.py)
@@ -9,6 +9,13 @@ from multiprocessing import Pool, cpu_count
 # Define PendulumController class
 from PendulumController import PendulumController
 
+# Constants
+g = 9.81  # Gravity
+R = 1.0  # Length of the pendulum
+m = 10.0  # Mass
+dt = 0.02  # Time step
+num_steps = 500  # Simulation time steps
+
 # ODE solver (RK4 method)
 def pendulum_ode_step(state, dt, desired_theta, controller):
     theta, omega, alpha = state
@@ -44,41 +51,9 @@ def pendulum_ode_step(state, dt, desired_theta, controller):
     new_state = state + (k1 + 2*k2 + 2*k3 + k4) / 6.0
     return new_state
 
-# Constants
-g = 9.81  # Gravity
-R = 1.0  # Length of the pendulum
-m = 10.0  # Mass
-dt = 0.02  # Time step
-num_steps = 500  # Simulation time steps
-
-# Directory containing controller files
-loss_function = "quadratic"
-controller_dir = f"/home/judson/Neural-Networks-in-GNC/inverted_pendulum/training/normalized/training/{loss_function}/controllers"
-#controller_dir = f"C:/Users/Judson/Desktop/New Gitea/Neural-Networks-in-GNC/inverted_pendulum/training/{loss_function}/controllers"
-controller_files = sorted([f for f in os.listdir(controller_dir) if f.startswith("controller_") and f.endswith(".pth")])
-
-# Sorting controllers by epoch
-controller_epochs = [int(f.split('_')[1].split('.')[0]) for f in controller_files]
-sorted_controllers = [x for _, x in sorted(zip(controller_epochs, controller_files))]
-
-# **Epoch Range Selection**
-epoch_range = (0, 1000)  # Set your desired range (e.g., (0, 5000) or (0, 100))
-
-filtered_controllers = [
-    f for f in sorted_controllers
-    if epoch_range[0] <= int(f.split('_')[1].split('.')[0]) <= epoch_range[1]
-]
-
-# **Granularity Control: Select every Nth controller**
-N = 1  # Change this value to adjust granularity (e.g., every 5th controller)
-selected_controllers = filtered_controllers[::N]  # Take every Nth controller within the range
-
-# Initial condition
-# theta0, omega0, alpha0, desired_theta = (-np.pi, -2*np.pi, 0.0, -1.3*np.pi)  # Example initial condition
-theta0, omega0, alpha0, desired_theta = (-np.pi, 0.0, 0.0, 0.0)  # Example initial condition
-
 # Parallel function must return epoch explicitly
-def run_simulation(controller_file):
+def run_simulation(params):
+    controller_file, initial_condition = params
+    theta0, omega0, alpha0, desired_theta = initial_condition
     epoch = int(controller_file.split('_')[1].split('.')[0])
 
     # Load controller
@@ -96,54 +71,100 @@ def run_simulation(controller_file):
     return epoch, theta_vals  # Return epoch with data
 
-# Parallel processing
+# Named initial conditions
+initial_conditions = {
+    "small_perturbation": (0.1*np.pi, 0.0, 0.0, 0.0),
+    "large_perturbation": (-np.pi, 0.0, 0.0, 0),
+    "overshoot_vertical_test": (-0.1*np.pi, 2*np.pi, 0.0, 0.0),
+    "overshoot_angle_test": (0.2*np.pi, 2*np.pi, 0.0, 0.3*np.pi),
+    "extreme_perturbation": (4*np.pi, 0.0, 0.0, 0),
+}
+
+# Loss functions to iterate over
+loss_functions = ["constant", "linear", "quadratic", "exponential", "inverse", "inverse_squared"]
+
+
+epoch_start = 0  # Start of the epoch range
+epoch_end = 500  # End of the epoch range
+epoch_step = 5  # Interval between epochs
+
 if __name__ == "__main__":
-    num_workers = min(cpu_count(), 16)  # Limit to 16 workers max
-    print(f"Using {num_workers} parallel workers...")
-
-    with Pool(processes=num_workers) as pool:
-        results = pool.map(run_simulation, selected_controllers)
-
-    # Sort results by epoch to ensure correct order
-    results.sort(key=lambda x: x[0])
-    epochs, theta_over_epochs = zip(*results)  # Unzip sorted results
-
-    # Convert results to NumPy arrays
-    theta_over_epochs = np.array(theta_over_epochs)
-
-    # Create 3D line plot
-    fig = plt.figure(figsize=(10, 7))
-    ax = fig.add_subplot(111, projection='3d')
-
-    time_steps = np.arange(num_steps) * dt  # X-axis (time)
-
-    # Plot each controller as a separate line
-    for epoch, theta_vals in zip(epochs, theta_over_epochs):
-        ax.plot(
-            [epoch] * len(time_steps),  # Y-axis (epoch stays constant for each line)
-            time_steps,  # X-axis (time)
-            theta_vals,  # Z-axis (theta evolution)
-            label=f"Epoch {epoch}" if epoch % (N * 10) == 0 else "",  # Label some lines for clarity
-        )
-
-    # Labels
-    ax.set_xlabel("Epoch")
-    ax.set_ylabel("Time (s)")
-    ax.set_zlabel("Theta (rad)")
-    ax.set_title(f"Pendulum Angle Evolution for {loss_function}")
-
-    # Add a horizontal line at desired_theta across all epochs and time steps
-    epochs_array = np.array([epoch for epoch, _ in zip(epochs, theta_over_epochs)])
-    ax.plot(
-        epochs_array,  # X-axis spanning all epochs
-        [time_steps.max()] * len(epochs_array),  # Y-axis at the maximum time step
-        [desired_theta] * len(epochs_array),  # Constant Z-axis value of desired_theta
-        color='r', linestyle='--', linewidth=2, label='Desired Theta at End Time'
-    )
-
-    # Improve visibility
-    ax.view_init(elev=20, azim=-135)  # Adjust 3D perspective
-
-    plt.savefig(f"{loss_function}.png", dpi=600)
-    #plt.show()
-    print(f"Saved plot as '{loss_function}.png'.")
+    for condition_name, initial_condition in initial_conditions.items():
+        full_path = f"/home/judson/Neural-Networks-in-GNC/inverted_pendulum/analysis/max_normalized/{condition_name}"
+        os.makedirs(full_path, exist_ok=True)  # Create directory if it does not exist
+
+        for loss_function in loss_functions:
+            controller_dir = f"/home/judson/Neural-Networks-in-GNC/inverted_pendulum/training/normalized/max_normalized/{loss_function}/controllers"
+            controller_files = sorted([f for f in os.listdir(controller_dir) if f.startswith("controller_") and f.endswith(".pth")])
+            # Extract epoch numbers and filter based on the defined range and interval
+            epoch_numbers = [int(f.split('_')[1].split('.')[0]) for f in controller_files]
+            selected_epochs = [e for e in epoch_numbers if epoch_start <= e <= epoch_end and (e - epoch_start) % epoch_step == 0]
+
+            # Filter the controller files to include only those within the selected epochs
+            selected_controllers = [f for f in controller_files if int(f.split('_')[1].split('.')[0]) in selected_epochs]
+            selected_controllers.sort(key=lambda f: int(f.split('_')[1].split('.')[0]))
+
+            # Setup parallel processing
+            num_workers = min(cpu_count(), 16)  # Limit to 16 workers max
+            print(f"Using {num_workers} parallel workers for {loss_function} with initial condition {condition_name}...")
+
+            with Pool(processes=num_workers) as pool:
+                params = [(controller_file, initial_condition) for controller_file in selected_controllers]
+                results = pool.map(run_simulation, params)
+
+            results.sort(key=lambda x: x[0])
+            epochs, theta_over_epochs = zip(*results)
+
+            fig = plt.figure(figsize=(7, 5))
+            ax = fig.add_subplot(111, projection='3d')
+            time_steps = np.arange(num_steps) * dt
+
+            # Plot the epochs in reverse order because we view it where epoch 0 is in front
+            for epoch, theta_vals in reversed(list(zip(epochs, theta_over_epochs))):
+                ax.plot([epoch] * len(time_steps), time_steps, theta_vals)
+
+            # Add a horizontal line at desired_theta across all epochs and time steps
+            epochs_array = np.array([epoch for epoch, _ in zip(epochs, theta_over_epochs)])
+            desired_theta = initial_condition[-1]
+            ax.plot(
+                epochs_array,  # X-axis spanning all epochs
+                [time_steps.max()] * len(epochs_array),  # Y-axis at the maximum time step
+                [desired_theta] * len(epochs_array),  # Constant Z-axis value of desired_theta
+                color='r', linestyle='--', linewidth=2, label='Desired Theta at End Time'
+            )
+
+            ax.set_xlabel("Epoch")
+            ax.set_ylabel("Time (s)")
+            ax.set_zlabel("Theta (rad)")
+            condition_text = f"IC_{'_'.join(map(lambda x: str(round(x, 2)), initial_condition))}"
+            ax.set_title(f"Pendulum Angle Evolution for {loss_function} and {condition_text}")
+
+            # Calculate the range of theta values across all epochs
+            theta_values = np.concatenate(theta_over_epochs)
+            theta_min = np.min(theta_values)
+            theta_max = np.max(theta_values)
+
+            # Determine the desired range around the desired_theta
+            desired_range_min = desired_theta - 1 * np.pi
+            desired_range_max = desired_theta + 1 * np.pi
+
+            # Check if current theta values fall outside the desired range
+            if theta_min < desired_range_min:
+                desired_range_min = desired_range_min
+            else:
+                desired_range_min = theta_min
+
+            if theta_max > desired_range_max:
+                desired_range_max = desired_range_max
+            else:
+                desired_range_max = theta_max
+
+            ax.set_zlim(desired_range_min, desired_range_max)
+
+            ax.view_init(elev=20, azim=-135)  # Adjust 3D perspective
+
+            plot_filename = os.path.join(full_path, f"{loss_function}.png")
+            plt.savefig(plot_filename, dpi=300)
+            plt.close()
+            print(f"Saved plot as '{plot_filename}'.")
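A side note on the z-limit branches added above: each if/else keeps one arm that assigns a variable to itself. A behavior-equivalent, more direct form (a sketch, not what this diff applies) is:

    # Clamp the z-axis to at most +/- pi around desired_theta, shrinking it
    # when the data occupies a narrower band (equivalent to the if/else above).
    desired_range_min = max(theta_min, desired_theta - np.pi)
    desired_range_max = min(theta_max, desired_theta + np.pi)
    ax.set_zlim(desired_range_min, desired_range_max)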
analysis/controller_across_epochs_old.py  (new file, 149 lines)
@@ -0,0 +1,149 @@
import os
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from multiprocessing import Pool, cpu_count

# Define PendulumController class
from PendulumController import PendulumController

# ODE solver (RK4 method)
def pendulum_ode_step(state, dt, desired_theta, controller):
    theta, omega, alpha = state

    def compute_torque(th, om, al):
        inp = torch.tensor([[th, om, al, desired_theta]], dtype=torch.float32)
        with torch.no_grad():
            torque = controller(inp)
            torque = torch.clamp(torque, -250, 250)
        return float(torque)

    def derivatives(state, torque):
        th, om, al = state
        a = (g / R) * np.sin(th) + torque / (m * R**2)
        return np.array([om, a, 0])  # dtheta, domega, dalpha

    # Compute RK4 steps
    torque1 = compute_torque(theta, omega, alpha)
    k1 = dt * derivatives(state, torque1)

    state_k2 = state + 0.5 * k1
    torque2 = compute_torque(state_k2[0], state_k2[1], state_k2[2])
    k2 = dt * derivatives(state_k2, torque2)

    state_k3 = state + 0.5 * k2
    torque3 = compute_torque(state_k3[0], state_k3[1], state_k3[2])
    k3 = dt * derivatives(state_k3, torque3)

    state_k4 = state + k3
    torque4 = compute_torque(state_k4[0], state_k4[1], state_k4[2])
    k4 = dt * derivatives(state_k4, torque4)

    new_state = state + (k1 + 2*k2 + 2*k3 + k4) / 6.0
    return new_state

# Constants
g = 9.81  # Gravity
R = 1.0  # Length of the pendulum
m = 10.0  # Mass
dt = 0.02  # Time step
num_steps = 500  # Simulation time steps

# Directory containing controller files
loss_function = "quadratic"
controller_dir = f"/home/judson/Neural-Networks-in-GNC/inverted_pendulum/training/normalized/training/{loss_function}/controllers"
#controller_dir = f"C:/Users/Judson/Desktop/New Gitea/Neural-Networks-in-GNC/inverted_pendulum/training/{loss_function}/controllers"
controller_files = sorted([f for f in os.listdir(controller_dir) if f.startswith("controller_") and f.endswith(".pth")])

# Sorting controllers by epoch
controller_epochs = [int(f.split('_')[1].split('.')[0]) for f in controller_files]
sorted_controllers = [x for _, x in sorted(zip(controller_epochs, controller_files))]

# **Epoch Range Selection**
epoch_range = (0, 1000)  # Set your desired range (e.g., (0, 5000) or (0, 100))

filtered_controllers = [
    f for f in sorted_controllers
    if epoch_range[0] <= int(f.split('_')[1].split('.')[0]) <= epoch_range[1]
]

# **Granularity Control: Select every Nth controller**
N = 1  # Change this value to adjust granularity (e.g., every 5th controller)
selected_controllers = filtered_controllers[::N]  # Take every Nth controller within the range

# Initial condition
# theta0, omega0, alpha0, desired_theta = (-np.pi, -2*np.pi, 0.0, -1.3*np.pi)  # Example initial condition
theta0, omega0, alpha0, desired_theta = (-np.pi, 0.0, 0.0, 0.0)  # Example initial condition

# Parallel function must return epoch explicitly
def run_simulation(controller_file):
    epoch = int(controller_file.split('_')[1].split('.')[0])

    # Load controller
    controller = PendulumController()
    controller.load_state_dict(torch.load(os.path.join(controller_dir, controller_file)))
    controller.eval()

    # Run simulation
    state = np.array([theta0, omega0, alpha0])
    theta_vals = []

    for _ in range(num_steps):
        theta_vals.append(state[0])
        state = pendulum_ode_step(state, dt, desired_theta, controller)

    return epoch, theta_vals  # Return epoch with data

# Parallel processing
if __name__ == "__main__":
    num_workers = min(cpu_count(), 16)  # Limit to 16 workers max
    print(f"Using {num_workers} parallel workers...")

    with Pool(processes=num_workers) as pool:
        results = pool.map(run_simulation, selected_controllers)

    # Sort results by epoch to ensure correct order
    results.sort(key=lambda x: x[0])
    epochs, theta_over_epochs = zip(*results)  # Unzip sorted results

    # Convert results to NumPy arrays
    theta_over_epochs = np.array(theta_over_epochs)

    # Create 3D line plot
    fig = plt.figure(figsize=(10, 7))
    ax = fig.add_subplot(111, projection='3d')

    time_steps = np.arange(num_steps) * dt  # X-axis (time)

    # Plot each controller as a separate line
    for epoch, theta_vals in zip(epochs, theta_over_epochs):
        ax.plot(
            [epoch] * len(time_steps),  # Y-axis (epoch stays constant for each line)
            time_steps,  # X-axis (time)
            theta_vals,  # Z-axis (theta evolution)
            label=f"Epoch {epoch}" if epoch % (N * 10) == 0 else "",  # Label some lines for clarity
        )

    # Labels
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Time (s)")
    ax.set_zlabel("Theta (rad)")
    ax.set_title(f"Pendulum Angle Evolution for {loss_function}")

    # Add a horizontal line at desired_theta across all epochs and time steps
    epochs_array = np.array([epoch for epoch, _ in zip(epochs, theta_over_epochs)])
    ax.plot(
        epochs_array,  # X-axis spanning all epochs
        [time_steps.max()] * len(epochs_array),  # Y-axis at the maximum time step
        [desired_theta] * len(epochs_array),  # Constant Z-axis value of desired_theta
        color='r', linestyle='--', linewidth=2, label='Desired Theta at End Time'
    )

    # Improve visibility
    ax.view_init(elev=20, azim=-135)  # Adjust 3D perspective

    plt.savefig(f"{loss_function}.png", dpi=600)
    #plt.show()
    print(f"Saved plot as '{loss_function}.png'.")
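For reference, pendulum_ode_step in both versions advances the state $s = (\theta, \omega, \alpha)$ under the torque-driven pendulum model

$$\ddot{\theta} \;=\; \frac{g}{R}\sin\theta \;+\; \frac{\tau}{mR^{2}}, \qquad \tau = \operatorname{clamp}\bigl(u_{\phi}(\theta,\omega,\alpha,\theta_{d}),\,-250,\,250\bigr),$$

where $u_{\phi}$ denotes the neural controller, using one classical RK4 step per $dt$:

$$s_{n+1} \;=\; s_{n} + \tfrac{1}{6}\bigl(k_{1} + 2k_{2} + 2k_{3} + k_{4}\bigr), \qquad k_{i} = \Delta t\, f\bigl(s^{(i)}\bigr),$$

with the torque re-evaluated at each intermediate state. Note that the derivative vector returned is $(\omega, a, 0)$, so the $\alpha$ component is carried through a step unchanged.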
BIN  [four deleted images: 9.4 MiB, 8.7 MiB, 8.3 MiB, 7.5 MiB; their paths were not captured in this view]
BIN  analysis/max_normalized/extreme_perturbation/constant.png        (new file, 564 KiB)
BIN  analysis/max_normalized/extreme_perturbation/exponential.png     (new file, 561 KiB)
BIN  analysis/max_normalized/extreme_perturbation/inverse.png         (new file, 524 KiB)
BIN  analysis/max_normalized/extreme_perturbation/inverse_squared.png (new file, 534 KiB)
BIN  analysis/max_normalized/extreme_perturbation/linear.png          (new file, 546 KiB)
BIN  analysis/max_normalized/extreme_perturbation/quadratic.png       (new file, 611 KiB)
BIN  analysis/max_normalized/large_perturbation/constant.png          (new file, 733 KiB)
BIN  analysis/max_normalized/large_perturbation/exponential.png       (new file, 832 KiB)
BIN  analysis/max_normalized/large_perturbation/inverse.png           (new file, 794 KiB)
BIN  analysis/max_normalized/large_perturbation/inverse_squared.png   (new file, 752 KiB)
BIN  analysis/max_normalized/large_perturbation/linear.png            (new file, 707 KiB)
BIN  analysis/max_normalized/large_perturbation/quadratic.png         (new file, 814 KiB)
BIN  analysis/max_normalized/overshoot_angle_test/constant.png        (new file, 494 KiB)
BIN  analysis/max_normalized/overshoot_angle_test/exponential.png     (new file, 578 KiB)
BIN  analysis/max_normalized/overshoot_angle_test/inverse.png         (new file, 505 KiB)
BIN  analysis/max_normalized/overshoot_angle_test/inverse_squared.png (new file, 502 KiB)
BIN  analysis/max_normalized/overshoot_angle_test/linear.png          (new file, 518 KiB)
BIN  analysis/max_normalized/overshoot_angle_test/quadratic.png       (new file, 521 KiB)
BIN  analysis/max_normalized/overshoot_vertical_test/constant.png     (new file, 517 KiB)
BIN  analysis/max_normalized/overshoot_vertical_test/exponential.png  (new file, 538 KiB)
BIN  analysis/max_normalized/overshoot_vertical_test/inverse.png      (new file, 508 KiB)
BIN  analysis/max_normalized/overshoot_vertical_test/inverse_squared.png (new file, 526 KiB; path inferred — the name line was lost in extraction)
BIN  analysis/max_normalized/overshoot_vertical_test/linear.png       (new file, 514 KiB)
BIN  analysis/max_normalized/overshoot_vertical_test/quadratic.png    (new file, 559 KiB)
BIN  analysis/max_normalized/small_perturbation/constant.png          (new file, 513 KiB)
BIN  analysis/max_normalized/small_perturbation/exponential.png       (new file, 474 KiB)
BIN  analysis/max_normalized/small_perturbation/inverse.png           (new file, 489 KiB)
BIN  analysis/max_normalized/small_perturbation/inverse_squared.png   (new file, 498 KiB)
BIN  analysis/max_normalized/small_perturbation/linear.png            (new file, 497 KiB)
BIN  analysis/max_normalized/small_perturbation/quadratic.png         (new file, 503 KiB)
BIN  [one deleted image, 9.8 MiB; its path was not captured in this view]
training/normalized/PendulumController.py  (new file, 17 lines)
@@ -0,0 +1,17 @@
import torch
import torch.nn as nn

class PendulumController(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(4, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        raw_torque = self.net(x)
        return torch.clamp(raw_torque, -250, 250)
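A minimal smoke test for this module (the input values are hypothetical, not part of the commit):

    import torch
    from PendulumController import PendulumController

    controller = PendulumController()
    controller.eval()

    # One input row: [theta, omega, alpha, desired_theta]
    state = torch.tensor([[-3.14, 0.0, 0.0, 0.0]])
    with torch.no_grad():
        torque = controller(state)  # shape [1, 1], clamped to [-250, 250]
    print(float(torque))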
training/normalized/PendulumDynamics.py  (new file, 26 lines)
@@ -0,0 +1,26 @@
import torch
import torch.nn as nn

class PendulumDynamics(nn.Module):
    def __init__(self, controller, m: float = 1, R: float = 1, g: float = 9.81):
        super().__init__()
        self.controller = controller
        self.m: float = m
        self.R: float = R
        self.g: float = g

    def forward(self, t, state):
        # Get the current values from the state
        theta, omega, alpha, desired_theta = state[:, 0], state[:, 1], state[:, 2], state[:, 3]

        # Make the input stack for the controller
        input = torch.stack([theta, omega, alpha, desired_theta], dim=1)

        # Get the torque (the output of the neural network)
        tau = self.controller(input).squeeze(-1)

        # Relax alpha toward the rigid-body angular acceleration
        alpha_desired = (self.g / self.R) * torch.sin(theta) + tau / (self.m * self.R**2)
        dalpha = alpha_desired - alpha

        # desired_theta is constant, so its time derivative is zero
        return torch.stack([omega, alpha, dalpha, torch.zeros_like(desired_theta)], dim=1)
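PendulumDynamics is written to be driven by torchdiffeq's odeint, as the trainer below does; note that alpha is not set directly to the rigid-body acceleration but relaxed toward it (dalpha = alpha_desired - alpha). A minimal sketch with a batch of one hypothetical initial condition:

    import torch
    from torchdiffeq import odeint
    from PendulumController import PendulumController
    from PendulumDynamics import PendulumDynamics

    controller = PendulumController()
    dynamics = PendulumDynamics(controller, m=10.0, R=1.0, g=9.81)

    # One training case: [theta0, omega0, alpha0, desired_theta]
    state_0 = torch.tensor([[-3.14, 0.0, 0.0, 0.0]])
    t_span = torch.linspace(0, 10, 1000)

    # state_traj has shape [t_points, batch, 4]; theta is state_traj[:, :, 0]
    state_traj = odeint(dynamics, state_0, t_span, method='rk4')
    print(state_traj.shape)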
BIN  training/normalized/__pycache__/PendulumDynamics.cpython-310.pyc  (new file)
training/normalized/average_normalized/constant/training_config.txt  (new file, 41 lines; path inferred — the name line was lost in extraction)
@@ -0,0 +1,41 @@
Random Seed: 4529
Time Span: 0 to 10, Points: 1000
Learning Rate: 0.1
Weight Decay: 0.0001

Loss Function:
    def loss_fn(state_traj, t_span):
        theta = state_traj[:, :, 0]  # Size: [batch_size, t_points]
        desired_theta = state_traj[:, :, 3]  # Size: [batch_size, t_points]
        weights = weight_fn(t_span)  # Initially Size: [t_points]

        # Reshape or expand weights to match theta dimensions
        weights = weights.view(-1, 1)  # Now Size: [batch_size, t_points]

        # Calculate the weighted loss
        return torch.mean(weights * (theta - desired_theta) ** 2)

Training Cases:
[theta0, omega0, alpha0, desired_theta]
[0.5235987901687622, 0.0, 0.0, 0.0]
[-0.5235987901687622, 0.0, 0.0, 0.0]
[2.094395160675049, 0.0, 0.0, 0.0]
[-2.094395160675049, 0.0, 0.0, 0.0]
[0.0, 1.0471975803375244, 0.0, 0.0]
[0.0, -1.0471975803375244, 0.0, 0.0]
[0.0, 6.2831854820251465, 0.0, 0.0]
[0.0, -6.2831854820251465, 0.0, 0.0]
[0.0, 0.0, 0.0, 6.2831854820251465]
[0.0, 0.0, 0.0, -6.2831854820251465]
[0.0, 0.0, 0.0, 1.5707963705062866]
[0.0, 0.0, 0.0, -1.5707963705062866]
[0.0, 0.0, 0.0, 1.0471975803375244]
[0.0, 0.0, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 0.0]
[-0.7853981852531433, -3.1415927410125732, 0.0, 0.0]
[1.5707963705062866, -3.1415927410125732, 0.0, 1.0471975803375244]
[-1.5707963705062866, 3.1415927410125732, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 6.2831854820251465]
[-0.7853981852531433, -3.1415927410125732, 0.0, 6.2831854820251465]
[1.5707963705062866, -3.1415927410125732, 0.0, 12.566370964050293]
[-1.5707963705062866, 3.1415927410125732, 0.0, -12.566370964050293]
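A note on the logged loss: state_traj comes from torchdiffeq's odeint and is time-major, so theta here actually has size [t_points, batch_size] — the inline size comments in the log have the two axes transposed. With weights reshaped to [t_points, 1] the broadcast is nonetheless correct, and the quantity minimized is

$$\mathcal{L} \;=\; \frac{1}{T\,B}\sum_{t=1}^{T}\sum_{b=1}^{B} w(t_{t})\,\bigl(\theta_{t,b} - \theta^{*}_{b}\bigr)^{2},$$

where $w$ is the (average-normalized) weight function evaluated on the time grid, $T = 1000$ time points, and $B = 22$ training cases.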
training/normalized/average_normalized/constant/training_log.csv  (new file, 1002 lines; content not shown in this view)

training/normalized/average_normalized/exponential/training_config.txt  (new file, 41 lines; path inferred)
    [content identical, line for line, to constant/training_config.txt above — the logged seed, time span, optimizer settings, loss_fn source, and training cases are the same for every weight function, since inspect.getsource captures only the shared loss_fn wrapper, not the weight_fn itself]
training/normalized/average_normalized/exponential/training_log.csv  (new file, 1002 lines; content not shown in this view)

training/normalized/average_normalized/inverse/training_config.txt  (new file, 41 lines; path inferred)
    [content identical to constant/training_config.txt above]
training/normalized/average_normalized/inverse/training_log.csv  (new file, 1002 lines; content not shown in this view)

training/normalized/average_normalized/inverse_squared/training_config.txt  (new file, 41 lines; path inferred)
    [content identical to constant/training_config.txt above]
training/normalized/average_normalized/inverse_squared/training_log.csv  (new file, 608 lines; path inferred — this is the only average_normalized log whose header line was lost, and the only one shorter than 1002 lines)
@@ -0,0 +1,608 @@
Epoch,Loss
0,733.2171630859375
1,747.2186279296875
2,173.32571411132812
3,113.02052307128906
4,52.3023796081543
5,29.964460372924805
6,22.173477172851562
7,17.20752716064453
8,14.393138885498047
9,13.495450019836426
[rows for epochs 10-599 omitted here for brevity; the loss decreases monotonically from 12.724 at epoch 10 to 4.4941 at epoch 599]
600,4.493993759155273
601,4.493923664093018
602,4.493853569030762
603,4.493789196014404
604,4.493722915649414
605,4.493655681610107
606,4.493587970733643
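A quick way to inspect one of these logs (a sketch; any of the training_log.csv paths above works):

    import csv
    import matplotlib.pyplot as plt

    epochs, losses = [], []
    with open("training/normalized/average_normalized/constant/training_log.csv") as f:
        reader = csv.reader(f)
        next(reader)  # skip the "Epoch,Loss" header
        for epoch, loss in reader:
            epochs.append(int(epoch))
            losses.append(float(loss))

    plt.semilogy(epochs, losses)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.show()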
training/normalized/average_normalized/linear/training_config.txt  (new file, 41 lines; path inferred)
    [content identical to constant/training_config.txt above]
training/normalized/average_normalized/linear/training_log.csv  (new file, 1002 lines; content not shown in this view)

training/normalized/average_normalized/quadratic/training_config.txt  (new file, 41 lines; path inferred)
    [content identical to constant/training_config.txt above]
training/normalized/average_normalized/quadratic/training_log.csv  (new file, 1002 lines; content not shown in this view)
155
training/normalized/average_normalized_trainer.py
Normal file
@ -0,0 +1,155 @@
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
from torchdiffeq import odeint
|
||||
import numpy as np
|
||||
import os
|
||||
import shutil
|
||||
import csv
|
||||
import inspect
|
||||
|
||||
from PendulumController import PendulumController
|
||||
from PendulumDynamics import PendulumDynamics
|
||||
|
||||
# Device setup
|
||||
device = torch.device("cpu")
|
||||
|
||||
# Initial conditions (theta0, omega0, alpha0, desired_theta)
|
||||
from initial_conditions import initial_conditions
|
||||
state_0 = torch.tensor(initial_conditions, dtype=torch.float32, device=device)
|
||||
|
||||
# Device setup
|
||||
device = torch.device("cpu")
|
||||
|
||||
# Constants
|
||||
m = 10.0
|
||||
g = 9.81
|
||||
R = 1.0
|
||||
|
||||
# Time grid
|
||||
t_start, t_end, t_points = 0, 10, 1000
|
||||
t_span = torch.linspace(t_start, t_end, t_points, device=device)
|
||||
|
||||
# Specify directory for storing results
|
||||
output_dir = "average_normalized"
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Use a previously generated random seed
|
||||
random_seed = 4529
|
||||
|
||||
# Set the seeds for reproducibility
|
||||
torch.manual_seed(random_seed)
|
||||
np.random.seed(random_seed)
|
||||
|
||||
# Print the chosen random seed
|
||||
print(f"Random seed for torch and numpy: {random_seed}")
|
||||
|
||||
# Initialize controller and dynamics
|
||||
controller = PendulumController().to(device)
|
||||
pendulum_dynamics = PendulumDynamics(controller, m, R, g).to(device)
|
||||
|
||||
# Optimizer setup
|
||||
learning_rate = 1e-1
|
||||
weight_decay = 1e-4
|
||||
optimizer = optim.Adam(controller.parameters(), lr=learning_rate, weight_decay=weight_decay)
|
||||
|
||||
# Training parameters
|
||||
num_epochs = 1001
|
||||
|
||||
# Define loss functions
|
||||
def make_loss_fn(weight_fn):
|
||||
def loss_fn(state_traj, t_span):
|
||||
theta = state_traj[:, :, 0] # Size: [batch_size, t_points]
|
||||
desired_theta = state_traj[:, :, 3] # Size: [batch_size, t_points]
|
||||
weights = weight_fn(t_span) # Initially Size: [t_points]
|
||||
|
||||
# Reshape or expand weights to match theta dimensions
|
||||
weights = weights.view(-1, 1) # Now Size: [batch_size, t_points]
|
||||
|
||||
# Calculate the weighted loss
|
||||
return torch.mean(weights * (theta - desired_theta) ** 2)
|
||||
|
||||
return loss_fn
|
||||
|
||||
# Define and store weight functions with descriptions, normalized by average weight
|
||||
weight_functions = {
|
||||
'constant': {
|
||||
'function': lambda t: torch.ones_like(t) / torch.ones_like(t).mean(),
|
||||
'description': 'Constant weight: All weights are 1, normalized by the average (remains 1)'
|
||||
},
|
||||
'linear': {
|
||||
'function': lambda t: (t / t.max()) / (t / t.max()).mean(),
|
||||
'description': 'Linear weight: Weights increase linearly from 0 to 1, normalized by the average weight'
|
||||
},
|
||||
'quadratic': {
|
||||
'function': lambda t: ((t / t.max()) ** 2) / ((t / t.max()) ** 2).mean(),
|
||||
'description': 'Quadratic weight: Weights increase quadratically from 0 to 1, normalized by the average weight'
|
||||
},
|
||||
'exponential': {
|
||||
'function': lambda t: (torch.exp(t / t.max() * 2)) / (torch.exp(t / t.max() * 2)).mean(),
|
||||
'description': 'Exponential weight: Weights increase exponentially, normalized by the average weight'
|
||||
},
|
||||
'inverse': {
|
||||
'function': lambda t: (1 / (t / t.max() + 1)) / (1 / (t / t.max() + 1)).mean(),
|
||||
'description': 'Inverse weight: Weights decrease inversely, normalized by the average weight'
|
||||
},
|
||||
'inverse_squared': {
|
||||
'function': lambda t: (1 / ((t / t.max() + 1) ** 2)) / (1 / ((t / t.max() + 1) ** 2)).mean(),
|
||||
'description': 'Inverse squared weight: Weights decrease inversely squared, normalized by the average weight'
|
||||
}
|
||||
}
# Training loop for each weight function
for name, weight_info in weight_functions.items():
    controller = PendulumController().to(device)
    pendulum_dynamics = PendulumDynamics(controller, m, R, g).to(device)
    optimizer = optim.Adam(controller.parameters(), lr=learning_rate, weight_decay=weight_decay)
    loss_fn = make_loss_fn(weight_info['function'])

    # File paths
    function_output_dir = os.path.join(output_dir, name)
    controllers_dir = os.path.join(function_output_dir, "controllers")

    # Remove any stale controllers directory, then recreate it
    if os.path.exists(controllers_dir):
        shutil.rmtree(controllers_dir)
    os.makedirs(controllers_dir, exist_ok=True)

    config_file = os.path.join(function_output_dir, "training_config.txt")
    log_file = os.path.join(function_output_dir, "training_log.csv")

    # Overwrite configuration and log files
    with open(config_file, "w") as f:
        f.write(f"Random Seed: {random_seed}\n")
        f.write(f"Time Span: {t_start} to {t_end}, Points: {t_points}\n")
        f.write(f"Learning Rate: {learning_rate}\n")
        f.write(f"Weight Decay: {weight_decay}\n")
        f.write("\nLoss Function:\n")
        f.write(inspect.getsource(loss_fn))
        f.write("\nTraining Cases:\n")
        f.write("[theta0, omega0, alpha0, desired_theta]\n")
        for case in state_0.cpu().numpy():
            f.write(f"{case.tolist()}\n")

    with open(log_file, "w", newline="") as csvfile:
        csv_writer = csv.writer(csvfile)
        csv_writer.writerow(["Epoch", "Loss"])

    # Training loop
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        state_traj = odeint(pendulum_dynamics, state_0, t_span, method='rk4')
        loss = loss_fn(state_traj, t_span)
        loss.backward()
        optimizer.step()

        # Logging
        with open(log_file, "a", newline="") as csvfile:
            csv_writer = csv.writer(csvfile)
            csv_writer.writerow([epoch, loss.item()])

        # Save the model
        model_file = os.path.join(controllers_dir, f"controller_{epoch}.pth")
        torch.save(controller.state_dict(), model_file)
        print(f"{model_file} saved with loss: {loss.item()}")

print("Training complete. Models and logs are saved under respective directories for each loss function.")
27
training/normalized/initial_conditions.py
Normal file
@ -0,0 +1,27 @@
import torch
from torch import pi

initial_conditions = [
    [1/6 * pi, 0.0, 0.0, 0.0],
    [-1/6 * pi, 0.0, 0.0, 0.0],
    [2/3 * pi, 0.0, 0.0, 0.0],
    [-2/3 * pi, 0.0, 0.0, 0.0],
    [0.0, 1/3 * pi, 0.0, 0.0],
    [0.0, -1/3 * pi, 0.0, 0.0],
    [0.0, 2 * pi, 0.0, 0.0],
    [0.0, -2 * pi, 0.0, 0.0],
    [0.0, 0.0, 0.0, 2 * pi],
    [0.0, 0.0, 0.0, -2 * pi],
    [0.0, 0.0, 0.0, 1/2 * pi],
    [0.0, 0.0, 0.0, -1/2 * pi],
    [0.0, 0.0, 0.0, 1/3 * pi],
    [0.0, 0.0, 0.0, -1/3 * pi],
    [1/4 * pi, 1 * pi, 0.0, 0.0],
    [-1/4 * pi, -1 * pi, 0.0, 0.0],
    [1/2 * pi, -1 * pi, 0.0, 1/3 * pi],
    [-1/2 * pi, 1 * pi, 0.0, -1/3 * pi],
    [1/4 * pi, 1 * pi, 0.0, 2 * pi],
    [-1/4 * pi, -1 * pi, 0.0, 2 * pi],
    [1/2 * pi, -1 * pi, 0.0, 4 * pi],
    [-1/2 * pi, 1 * pi, 0.0, -4 * pi],
]
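# Illustrative check (an assumption, not in the original file): each row is
# [theta0, omega0, alpha0, desired_theta], so the trainer's
# torch.tensor(initial_conditions) call yields a batch of shape [22, 4].
assert torch.tensor(initial_conditions).shape == (22, 4)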
41
training/normalized/max_normalized/constant/training_config.txt
Normal file
@ -0,0 +1,41 @@
Random Seed: 4529
Time Span: 0 to 10, Points: 1000
Learning Rate: 0.1
Weight Decay: 0.0001

Loss Function:
    def loss_fn(state_traj, t_span):
        theta = state_traj[:, :, 0]          # [t_points, batch_size]
        desired_theta = state_traj[:, :, 3]  # [t_points, batch_size]
        weights = weight_fn(t_span)          # [t_points]

        # Reshape weights to [t_points, 1] so they broadcast across the batch dimension
        weights = weights.view(-1, 1)

        # Weighted mean-squared tracking error
        return torch.mean(weights * (theta - desired_theta) ** 2)

Training Cases:
[theta0, omega0, alpha0, desired_theta]
[0.5235987901687622, 0.0, 0.0, 0.0]
[-0.5235987901687622, 0.0, 0.0, 0.0]
[2.094395160675049, 0.0, 0.0, 0.0]
[-2.094395160675049, 0.0, 0.0, 0.0]
[0.0, 1.0471975803375244, 0.0, 0.0]
[0.0, -1.0471975803375244, 0.0, 0.0]
[0.0, 6.2831854820251465, 0.0, 0.0]
[0.0, -6.2831854820251465, 0.0, 0.0]
[0.0, 0.0, 0.0, 6.2831854820251465]
[0.0, 0.0, 0.0, -6.2831854820251465]
[0.0, 0.0, 0.0, 1.5707963705062866]
[0.0, 0.0, 0.0, -1.5707963705062866]
[0.0, 0.0, 0.0, 1.0471975803375244]
[0.0, 0.0, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 0.0]
[-0.7853981852531433, -3.1415927410125732, 0.0, 0.0]
[1.5707963705062866, -3.1415927410125732, 0.0, 1.0471975803375244]
[-1.5707963705062866, 3.1415927410125732, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 6.2831854820251465]
[-0.7853981852531433, -3.1415927410125732, 0.0, 6.2831854820251465]
[1.5707963705062866, -3.1415927410125732, 0.0, 12.566370964050293]
[-1.5707963705062866, 3.1415927410125732, 0.0, -12.566370964050293]
1001
training/normalized/max_normalized/constant/training_log.csv
Normal file
41
training/normalized/max_normalized/exponential/training_config.txt
Normal file
1001
training/normalized/max_normalized/exponential/training_log.csv
Normal file
41
training/normalized/max_normalized/inverse/training_config.txt
Normal file
1001
training/normalized/max_normalized/inverse/training_log.csv
Normal file
41
training/normalized/max_normalized/inverse_squared/training_config.txt
Normal file
1001
training/normalized/max_normalized/inverse_squared/training_log.csv
Normal file
41
training/normalized/max_normalized/linear/training_config.txt
Normal file
1001
training/normalized/max_normalized/linear/training_log.csv
Normal file
41
training/normalized/max_normalized/quadratic/training_config.txt
Normal file
1001
training/normalized/max_normalized/quadratic/training_log.csv
Normal file
155
training/normalized/max_normalized_trainer.py
Normal file
@ -0,0 +1,155 @@
import torch
import torch.optim as optim
from torchdiffeq import odeint
import numpy as np
import os
import shutil
import csv
import inspect

from PendulumController import PendulumController
from PendulumDynamics import PendulumDynamics

# Device setup
device = torch.device("cpu")

# Initial conditions (theta0, omega0, alpha0, desired_theta)
from initial_conditions import initial_conditions
state_0 = torch.tensor(initial_conditions, dtype=torch.float32, device=device)

# Constants
m = 10.0
g = 9.81
R = 1.0

# Time grid
t_start, t_end, t_points = 0, 10, 1000
t_span = torch.linspace(t_start, t_end, t_points, device=device)

# Specify directory for storing results
output_dir = "training"
os.makedirs(output_dir, exist_ok=True)

# Use a previously generated random seed
random_seed = 4529

# Set the seeds for reproducibility
torch.manual_seed(random_seed)
np.random.seed(random_seed)

# Print the chosen random seed
print(f"Random seed for torch and numpy: {random_seed}")

# Initialize controller and dynamics
controller = PendulumController().to(device)
pendulum_dynamics = PendulumDynamics(controller, m, R, g).to(device)

# Optimizer setup
learning_rate = 1e-1
weight_decay = 1e-4
optimizer = optim.Adam(controller.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Training parameters
num_epochs = 1000

# Define loss functions
def make_loss_fn(weight_fn):
    def loss_fn(state_traj, t_span):
        theta = state_traj[:, :, 0]          # [t_points, batch_size]
        desired_theta = state_traj[:, :, 3]  # [t_points, batch_size]
        weights = weight_fn(t_span)          # [t_points]

        # Reshape weights to [t_points, 1] so they broadcast across the batch dimension
        weights = weights.view(-1, 1)

        # Weighted mean-squared tracking error
        return torch.mean(weights * (theta - desired_theta) ** 2)

    return loss_fn
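# Illustrative shape check (an assumption, not part of the original script):
# odeint returns trajectories as [t_points, batch_size, state_dim], so the
# [t_points, 1] weight column broadcasts across the 22 training cases.
# demo_traj and demo_loss are names introduced here for demonstration only.
demo_traj = torch.zeros(t_points, 22, 4)   # placeholder trajectory
demo_loss = make_loss_fn(lambda t: torch.ones_like(t))(demo_traj, t_span)
assert demo_loss.shape == ()               # the weighted MSE reduces to a scalar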
# Define and store weight functions with descriptions
weight_functions = {
    'constant': {
        'function': lambda t: torch.ones_like(t),
        'description': 'Constant weight: All weights are 1'
    },
    'linear': {
        'function': lambda t: (t / t.max()) / (t / t.max()).max(),
        'description': 'Linear weight: Weights increase linearly from 0 to 1, normalized by max'
    },
    'quadratic': {
        'function': lambda t: ((t / t.max()) ** 2) / ((t / t.max()) ** 2).max(),
        'description': 'Quadratic weight: Weights increase quadratically from 0 to 1, normalized by max'
    },
    'exponential': {
        'function': lambda t: torch.exp(t / t.max() * 2) / torch.exp(t / t.max() * 2).max(),
        'description': 'Exponential weight: Weights increase exponentially, normalized by max'
    },
    'inverse': {
        'function': lambda t: (1 / (t / t.max() + 1)) / (1 / (t / t.max() + 1)).max(),
        'description': 'Inverse weight: Weights decrease inversely, normalized by max'
    },
    'inverse_squared': {
        'function': lambda t: (1 / ((t / t.max() + 1) ** 2)) / (1 / ((t / t.max() + 1) ** 2)).max(),
        'description': 'Inverse squared weight: Weights decrease inversely squared, normalized by max'
    }
}
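# Illustrative contrast with the average-normalized trainer (an assumption,
# not in the original script): max-normalization caps every profile at 1, so a
# schedule that spends most of the horizon near 0 (e.g. quadratic) carries a
# much smaller average weight, which rescales the loss magnitude between runs.
w_quad = weight_functions['quadratic']['function'](t_span)
print(w_quad.max().item(), w_quad.mean().item())  # ~1.0 vs ~0.33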
# Training loop for each weight function
for name, weight_info in weight_functions.items():
    controller = PendulumController().to(device)
    pendulum_dynamics = PendulumDynamics(controller, m, R, g).to(device)
    optimizer = optim.Adam(controller.parameters(), lr=learning_rate, weight_decay=weight_decay)
    loss_fn = make_loss_fn(weight_info['function'])

    # File paths
    function_output_dir = os.path.join(output_dir, name)
    controllers_dir = os.path.join(function_output_dir, "controllers")

    # Remove any stale controllers directory, then recreate it
    if os.path.exists(controllers_dir):
        shutil.rmtree(controllers_dir)
    os.makedirs(controllers_dir, exist_ok=True)

    config_file = os.path.join(function_output_dir, "training_config.txt")
    log_file = os.path.join(function_output_dir, "training_log.csv")

    # Overwrite configuration and log files
    with open(config_file, "w") as f:
        f.write(f"Random Seed: {random_seed}\n")
        f.write(f"Time Span: {t_start} to {t_end}, Points: {t_points}\n")
        f.write(f"Learning Rate: {learning_rate}\n")
        f.write(f"Weight Decay: {weight_decay}\n")
        f.write("\nLoss Function:\n")
        f.write(inspect.getsource(loss_fn))
        f.write("\nTraining Cases:\n")
        f.write("[theta0, omega0, alpha0, desired_theta]\n")
        for case in state_0.cpu().numpy():
            f.write(f"{case.tolist()}\n")

    with open(log_file, "w", newline="") as csvfile:
        csv_writer = csv.writer(csvfile)
        csv_writer.writerow(["Epoch", "Loss"])

    # Training loop
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        state_traj = odeint(pendulum_dynamics, state_0, t_span, method='rk4')
        loss = loss_fn(state_traj, t_span)
        loss.backward()
        optimizer.step()

        # Logging
        with open(log_file, "a", newline="") as csvfile:
            csv_writer = csv.writer(csvfile)
            csv_writer.writerow([epoch, loss.item()])

        # Save the model
        model_file = os.path.join(controllers_dir, f"controller_{epoch}.pth")
        torch.save(controller.state_dict(), model_file)
        print(f"{model_file} saved with loss: {loss.item()}")

print("Training complete. Models and logs are saved under respective directories for each loss function.")