52 lines
2.0 KiB
Python
52 lines
2.0 KiB
Python
import os
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from mpl_toolkits.mplot3d import Axes3D
|
|
|
|
def plot_3d_epoch_evolution(epochs, theta_over_epochs, desired_theta, save_path, title, num_steps, dt):
    """Plot theta trajectories across epochs as 3-D lines and save the figure.

    Each epoch's trajectory is drawn as one line in (epoch, time, theta) space.
    The z-axis window is ``desired_theta +/- 1.5*pi`` intersected with the
    observed range of all trajectories, and every trajectory is clipped to that
    window. A dashed red line marks ``desired_theta`` at the last time sample
    for each plotted epoch, and is shown in the legend.

    Args:
        epochs: Epoch identifiers, paired positionally with ``theta_over_epochs``
            (entries beyond ``len(theta_over_epochs)`` are ignored by ``zip``).
        theta_over_epochs: Sequence of 1-D theta trajectories. Each is expected
            to have ``num_steps`` samples so it lines up with the time axis.
        desired_theta: Target theta value (radians, per the axis label).
        save_path: Output image path. Missing parent directories are created;
            a bare filename is written to the current working directory.
        title: Plot title.
        num_steps: Number of time samples per trajectory.
        dt: Spacing between consecutive time samples.

    Raises:
        ValueError: From ``np.concatenate`` if ``theta_over_epochs`` is empty.
    """
    fig = plt.figure(figsize=(7, 5))
    ax = fig.add_subplot(111, projection='3d')

    time_steps = np.arange(num_steps) * dt

    # Z window: desired_theta +/- 1.5*pi, narrowed to the range the data covers.
    theta_values = np.concatenate(theta_over_epochs)
    z_min = max(np.min(theta_values), desired_theta - 1.5 * np.pi)
    z_max = min(np.max(theta_values), desired_theta + 1.5 * np.pi)

    # Trajectories are drawn from the last epoch to the first.
    for epoch, theta_vals in reversed(list(zip(epochs, theta_over_epochs))):
        clipped = np.clip(theta_vals, z_min, z_max)
        ax.plot([epoch] * len(time_steps), time_steps, clipped)

    # Only epochs that actually have a trajectory get a reference point.
    plotted_epochs = np.array([epoch for epoch, _ in zip(epochs, theta_over_epochs)])
    ax.plot(plotted_epochs,
            [time_steps.max()] * len(plotted_epochs),
            [desired_theta] * len(plotted_epochs),
            color='r', linestyle='--', linewidth=2, label='Desired Theta at End Time')

    ax.set_xlabel("Epoch")
    ax.set_ylabel("Time (s)")
    ax.set_zlabel("Theta (rad)")
    ax.set_title(title)
    ax.set_zlim(z_min, z_max)
    ax.view_init(elev=20, azim=-135)
    # Without this the reference line's label is never displayed.
    ax.legend()

    # dirname() is '' for a bare filename; makedirs('') would raise.
    directory = os.path.dirname(save_path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    # Save/close this specific figure rather than whatever pyplot considers current.
    fig.savefig(save_path, dpi=300)
    plt.close(fig)

    print(f"Saved plot as '{save_path}'.")
|
|
|
|
def plot_final_theta_vs_epoch(epochs, final_thetas, loss_functions, save_path):
    """Plot the final theta per epoch for each loss function and save the figure.

    Args:
        epochs: x-axis values shared by every curve.
        final_thetas: Sequence of y-series, one per loss function. Each is
            expected to have the same length as ``epochs``.
        loss_functions: Legend labels, paired positionally with ``final_thetas``
            (surplus entries in either sequence are ignored by ``zip``).
        save_path: Output image path. Missing parent directories are created;
            a bare filename is written to the current working directory.
    """
    fig, ax = plt.subplots()

    for final_theta, label in zip(final_thetas, loss_functions):
        ax.plot(epochs, final_theta, label=label)

    ax.set_xlabel("Epoch")
    ax.set_ylabel("Final Theta (rad)")
    ax.legend()

    # dirname() is '' for a bare filename; makedirs('') would raise.
    directory = os.path.dirname(save_path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    # Save/close this specific figure rather than whatever pyplot considers current.
    fig.savefig(save_path)
    plt.close(fig)
|