Base controller path: /home/judson/Neural-Networks-in-GNC/inverted_pendulum/training/controller_base.pth
Time Span: 0 to 10, Points: 1000
Learning Rate: 0.1
Weight Decay: 0.0001
Loss Function Name: two
Loss Function Exponent: 2

Current Loss Function (wrapper) Source Code:

    def current_loss_fn(state_traj):
        theta = state_traj[:, :, 0]          # [batch_size, t_points]
        desired_theta = state_traj[:, :, 3]  # [batch_size, t_points]
        return torch.mean(loss_fn(theta, desired_theta))

Specific Loss Function Source Code:

    def square_loss(theta: torch.Tensor,
                    desired_theta: torch.Tensor,
                    min_val: float = 0.01) -> torch.Tensor:
        return normalized_loss(theta, desired_theta, exponent=2, min_val=min_val)

Normalized Loss Function Source Code:

    def normalized_loss(theta: torch.Tensor,
                        desired_theta: torch.Tensor,
                        exponent: float,
                        min_val: float = 0.01,
                        delta: float = 1) -> torch.Tensor:
        """
        Computes a normalized loss that maps the error (|theta - desired_theta|)
        on [0, 2π] to the range [min_val, 1]. To avoid an infinite gradient at
        error = 0 for exponents < 1, a shift 'delta' is added.

        The loss is given by:

            loss = min_val + (1 - min_val)
                   * ((error + delta)^exponent - delta^exponent)
                   / ((2π + delta)^exponent - delta^exponent)

        so that:
            - When error = 0:  loss = min_val
            - When error = 2π: loss = 1
        """
        error = torch.abs(theta - desired_theta)
        numerator = (error + delta) ** exponent - delta ** exponent
        denominator = (2 * math.pi + delta) ** exponent - delta ** exponent
        return min_val + (1 - min_val) * (numerator / denominator)

Training Cases: [theta0, omega0, alpha0, desired_theta]

    [0.5235987901687622, 0.0, 0.0, 0.0]
    [-0.5235987901687622, 0.0, 0.0, 0.0]
    [2.094395160675049, 0.0, 0.0, 0.0]
    [-2.094395160675049, 0.0, 0.0, 0.0]
    [0.0, 1.0471975803375244, 0.0, 0.0]
    [0.0, -1.0471975803375244, 0.0, 0.0]
    [0.0, 6.2831854820251465, 0.0, 0.0]
    [0.0, -6.2831854820251465, 0.0, 0.0]
    [0.0, 0.0, 0.0, 6.2831854820251465]
    [0.0, 0.0, 0.0, -6.2831854820251465]
    [0.0, 0.0, 0.0, 1.5707963705062866]
    [0.0, 0.0, 0.0, -1.5707963705062866]
    [0.0, 0.0, 0.0, 1.0471975803375244]
    [0.0, 0.0, 0.0, -1.0471975803375244]
    [0.7853981852531433, 3.1415927410125732, 0.0, 0.0]
    [-0.7853981852531433, -3.1415927410125732, 0.0, 0.0]
    [1.5707963705062866, -3.1415927410125732, 0.0, 1.0471975803375244]
    [-1.5707963705062866, 3.1415927410125732, 0.0, -1.0471975803375244]
    [0.7853981852531433, 3.1415927410125732, 0.0, 6.2831854820251465]
    [-0.7853981852531433, -3.1415927410125732, 0.0, 6.2831854820251465]
    [1.5707963705062866, -3.1415927410125732, 0.0, 12.566370964050293]
    [-1.5707963705062866, 3.1415927410125732, 0.0, -12.566370964050293]
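The endpoint claims in the normalized_loss docstring (loss = min_val at error 0, loss = 1 at error 2π) can be checked with a scalar version of the same mapping. This is a minimal sketch using only the standard library; normalized_loss_scalar is an illustrative name, not a function from the training code, and tensor handling is deliberately omitted.

```python
import math

def normalized_loss_scalar(error: float, exponent: float,
                           min_val: float = 0.01, delta: float = 1.0) -> float:
    # Same mapping as normalized_loss above, applied to a single
    # |theta - desired_theta| value instead of a tensor.
    numerator = (error + delta) ** exponent - delta ** exponent
    denominator = (2 * math.pi + delta) ** exponent - delta ** exponent
    return min_val + (1 - min_val) * (numerator / denominator)

# Endpoints from the docstring, with the logged exponent = 2:
print(normalized_loss_scalar(0.0, exponent=2))          # expected ≈ min_val (0.01)
print(normalized_loss_scalar(2 * math.pi, exponent=2))  # expected ≈ 1.0
```

At error = 0 the numerator vanishes, leaving min_val; at error = 2π the numerator equals the denominator, so the scaled term contributes exactly (1 - min_val). The delta shift keeps the derivative finite at error = 0 for exponents below 1, as the docstring notes.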
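For intuition about the wrapper's reduction: each trajectory point stores the channels [theta, omega, alpha, desired_theta], and current_loss_fn averages the per-point loss over both batch and time, matching the state_traj[:, :, 0] and state_traj[:, :, 3] indexing above. The following torch-free sketch mirrors that reduction over a toy list-based trajectory; all names and the intermediate trajectory values are made up for illustration.

```python
import math

def point_loss(theta, desired_theta, exponent=2.0, min_val=0.01, delta=1.0):
    # Scalar analogue of normalized_loss for one trajectory point.
    error = abs(theta - desired_theta)
    num = (error + delta) ** exponent - delta ** exponent
    den = (2 * math.pi + delta) ** exponent - delta ** exponent
    return min_val + (1 - min_val) * (num / den)

def mean_trajectory_loss(state_traj):
    # state_traj: [batch][t_points][4], channels [theta, omega, alpha, desired_theta].
    # Averaging over every (batch, time) pair mirrors torch.mean in current_loss_fn.
    losses = [point_loss(pt[0], pt[3]) for traj in state_traj for pt in traj]
    return sum(losses) / len(losses)

# Toy batch of two 3-point trajectories seeded from the first pair of training
# cases (theta0 = ±pi/6, desired_theta = 0); later points are invented.
toy = [
    [[math.pi / 6, 0.0, 0.0, 0.0], [0.3, -0.5, 0.0, 0.0], [0.05, -0.2, 0.0, 0.0]],
    [[-math.pi / 6, 0.0, 0.0, 0.0], [-0.3, 0.5, 0.0, 0.0], [-0.05, 0.2, 0.0, 0.0]],
]
print(mean_trajectory_loss(toy))  # some value in [min_val, 1]
```

Because the loss depends only on |theta - desired_theta|, the mirrored trajectory pair above contributes identical per-point losses, which is consistent with the sign-symmetric pairs in the training-case list.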