64 lines
2.7 KiB
Plaintext
Base controller path: /home/judson/Neural-Networks-in-GNC/inverted_pendulum/training/controller_base.pth
|
|
Time Span: 0 to 10, Points: 1000
|
|
Learning Rate: 0.1
|
|
Weight Decay: 0.0001
|
|
|
|
Loss Function Name: one_half
|
|
Loss Function Exponent: 0.5
|
|
|
|
Current Loss Function (wrapper) Source Code:
|
|
def current_loss_fn(state_traj):
|
|
theta = state_traj[:, :, 0] # [batch_size, t_points]
|
|
desired_theta = state_traj[:, :, 3] # [batch_size, t_points]
|
|
return torch.mean(loss_fn(theta, desired_theta))
|
|
|
|
Specific Loss Function Source Code:
|
|
def one_half_loss(theta: torch.Tensor, desired_theta: torch.Tensor, min_val: float = 0.01) -> torch.Tensor:
|
|
return normalized_loss(theta, desired_theta, exponent=1/2, min_val=min_val)
|
|
|
|
Normalized Loss Function Source Code:
|
|
def normalized_loss(theta: torch.Tensor, desired_theta: torch.Tensor, exponent: float, min_val: float = 0.01, delta: float = 1) -> torch.Tensor:
|
|
"""
|
|
Computes a normalized loss that maps the error (|theta - desired_theta|) on [0, 2π]
|
|
to the range [min_val, 1]. To avoid an infinite gradient at error=0 for exponents < 1,
|
|
a shift 'delta' is added.
|
|
|
|
The loss is given by:
|
|
|
|
loss = min_val + (1 - min_val) * ( ((error + delta)^exponent - delta^exponent)
|
|
/ ((2π + delta)^exponent - delta^exponent) )
|
|
|
|
so that:
|
|
- When error = 0: loss = min_val
|
|
- When error = 2π: loss = 1
|
|
"""
|
|
error = torch.abs(theta - desired_theta)
|
|
numerator = (error + delta) ** exponent - delta ** exponent
|
|
denominator = (2 * math.pi + delta) ** exponent - delta ** exponent
|
|
return min_val + (1 - min_val) * (numerator / denominator)
|
|
|
|
Training Cases:
|
|
[theta0, omega0, alpha0, desired_theta]
|
|
[0.5235987901687622, 0.0, 0.0, 0.0]
|
|
[-0.5235987901687622, 0.0, 0.0, 0.0]
|
|
[2.094395160675049, 0.0, 0.0, 0.0]
|
|
[-2.094395160675049, 0.0, 0.0, 0.0]
|
|
[0.0, 1.0471975803375244, 0.0, 0.0]
|
|
[0.0, -1.0471975803375244, 0.0, 0.0]
|
|
[0.0, 6.2831854820251465, 0.0, 0.0]
|
|
[0.0, -6.2831854820251465, 0.0, 0.0]
|
|
[0.0, 0.0, 0.0, 6.2831854820251465]
|
|
[0.0, 0.0, 0.0, -6.2831854820251465]
|
|
[0.0, 0.0, 0.0, 1.5707963705062866]
|
|
[0.0, 0.0, 0.0, -1.5707963705062866]
|
|
[0.0, 0.0, 0.0, 1.0471975803375244]
|
|
[0.0, 0.0, 0.0, -1.0471975803375244]
|
|
[0.7853981852531433, 3.1415927410125732, 0.0, 0.0]
|
|
[-0.7853981852531433, -3.1415927410125732, 0.0, 0.0]
|
|
[1.5707963705062866, -3.1415927410125732, 0.0, 1.0471975803375244]
|
|
[-1.5707963705062866, 3.1415927410125732, 0.0, -1.0471975803375244]
|
|
[0.7853981852531433, 3.1415927410125732, 0.0, 6.2831854820251465]
|
|
[-0.7853981852531433, -3.1415927410125732, 0.0, 6.2831854820251465]
|
|
[1.5707963705062866, -3.1415927410125732, 0.0, 12.566370964050293]
|
|
[-1.5707963705062866, 3.1415927410125732, 0.0, -12.566370964050293]
|