Inverted-Pendulum-Neural-Ne.../training/base_loss/five/training_config.txt

Base controller path: /home/judson/Neural-Networks-in-GNC/inverted_pendulum/training/controller_base.pth
Time Span: 0 to 10, Points: 1000
Learning Rate: 0.1
Weight Decay: 0.0001
Loss Function Name: five
Loss Function Exponent: 5
Current Loss Function (wrapper) Source Code:
def current_loss_fn(state_traj):
    theta = state_traj[:, :, 0]          # [batch_size, t_points]
    desired_theta = state_traj[:, :, 3]  # [batch_size, t_points]
    return torch.mean(loss_fn(theta, desired_theta))
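For context on the wrapper's indexing, a minimal runnable sketch (the mock trajectory data is mine; the [batch_size, t_points, 4] state layout [theta, omega, alpha, desired_theta] follows the training-case format in this file, and loss_fn is wired directly to the exponent-5 normalized loss):

```python
import math
import torch

def normalized_loss(theta, desired_theta, exponent, min_val=0.01, delta=1.0):
    # Same arithmetic as the normalized_loss defined later in this file.
    error = torch.abs(theta - desired_theta)
    numerator = (error + delta) ** exponent - delta ** exponent
    denominator = (2 * math.pi + delta) ** exponent - delta ** exponent
    return min_val + (1 - min_val) * (numerator / denominator)

def current_loss_fn(state_traj):
    theta = state_traj[:, :, 0]          # [batch_size, t_points]
    desired_theta = state_traj[:, :, 3]  # [batch_size, t_points]
    return torch.mean(normalized_loss(theta, desired_theta, exponent=5))

# Mock batch: 2 trajectories, 3 time points, state = [theta, omega, alpha, desired_theta].
state_traj = torch.tensor([
    [[0.10, 0.0, 0.0, 0.0], [0.05, 0.0, 0.0, 0.0], [0.00, 0.0, 0.0, 0.0]],
    [[0.20, 0.0, 0.0, 0.0], [0.10, 0.0, 0.0, 0.0], [0.05, 0.0, 0.0, 0.0]],
])
loss = current_loss_fn(state_traj)
print(loss.item())  # a scalar in [min_val, 1]
```

Averaging over both the batch and time dimensions means every time point of every trajectory contributes equally to the gradient.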
Specific Loss Function Source Code:
def five_loss(theta: torch.Tensor, desired_theta: torch.Tensor, min_val: float = 0.01) -> torch.Tensor:
    return normalized_loss(theta, desired_theta, exponent=5, min_val=min_val)
Normalized Loss Function Source Code:
def normalized_loss(theta: torch.Tensor, desired_theta: torch.Tensor, exponent: float, min_val: float = 0.01, delta: float = 1) -> torch.Tensor:
    """
    Computes a normalized loss that maps the error (|theta - desired_theta|) on [0, 2π]
    to the range [min_val, 1]. To avoid an infinite gradient at error=0 for exponents < 1,
    a shift 'delta' is added.

    The loss is given by:

        loss = min_val + (1 - min_val) * ((error + delta)^exponent - delta^exponent)
                                       / ((2π + delta)^exponent - delta^exponent)

    so that:
      - When error = 0:  loss = min_val
      - When error = 2π: loss = 1
    """
    error = torch.abs(theta - desired_theta)
    numerator = (error + delta) ** exponent - delta ** exponent
    denominator = (2 * math.pi + delta) ** exponent - delta ** exponent
    return min_val + (1 - min_val) * (numerator / denominator)
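As a sanity check on the formula, a scalar Python mirror of normalized_loss (torch ops replaced with plain math; the `_scalar` suffix is mine) reproduces the two endpoint values stated in the docstring:

```python
import math

def normalized_loss_scalar(theta: float, desired_theta: float, exponent: float,
                           min_val: float = 0.01, delta: float = 1.0) -> float:
    # Same arithmetic as normalized_loss above, on plain floats.
    error = abs(theta - desired_theta)
    numerator = (error + delta) ** exponent - delta ** exponent
    denominator = (2 * math.pi + delta) ** exponent - delta ** exponent
    return min_val + (1 - min_val) * (numerator / denominator)

print(normalized_loss_scalar(0.0, 0.0, exponent=5))          # 0.01 (error = 0  -> min_val)
print(normalized_loss_scalar(2 * math.pi, 0.0, exponent=5))  # ~1.0 (error = 2π -> 1)
```

Note that with exponent=5 the numerator grows like error^5 for large errors, so small errors near the setpoint are penalized very weakly relative to large ones.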
Training Cases:
[theta0, omega0, alpha0, desired_theta]
[0.5235987901687622, 0.0, 0.0, 0.0]
[-0.5235987901687622, 0.0, 0.0, 0.0]
[2.094395160675049, 0.0, 0.0, 0.0]
[-2.094395160675049, 0.0, 0.0, 0.0]
[0.0, 1.0471975803375244, 0.0, 0.0]
[0.0, -1.0471975803375244, 0.0, 0.0]
[0.0, 6.2831854820251465, 0.0, 0.0]
[0.0, -6.2831854820251465, 0.0, 0.0]
[0.0, 0.0, 0.0, 6.2831854820251465]
[0.0, 0.0, 0.0, -6.2831854820251465]
[0.0, 0.0, 0.0, 1.5707963705062866]
[0.0, 0.0, 0.0, -1.5707963705062866]
[0.0, 0.0, 0.0, 1.0471975803375244]
[0.0, 0.0, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 0.0]
[-0.7853981852531433, -3.1415927410125732, 0.0, 0.0]
[1.5707963705062866, -3.1415927410125732, 0.0, 1.0471975803375244]
[-1.5707963705062866, 3.1415927410125732, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 6.2831854820251465]
[-0.7853981852531433, -3.1415927410125732, 0.0, 6.2831854820251465]
[1.5707963705062866, -3.1415927410125732, 0.0, 12.566370964050293]
[-1.5707963705062866, 3.1415927410125732, 0.0, -12.566370964050293]
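The long decimal tails above are consistent with float32 storage: each value appears to be a single-precision rendering of a simple fraction of π (my reading of the values, not stated in the config). A quick standard-library check:

```python
import math
import struct

def f32(x: float) -> float:
    # Round to single precision, then widen back to double, matching how a
    # float32 tensor value prints when converted to a Python float.
    return struct.unpack('f', struct.pack('f', x))[0]

print(f32(math.pi / 6))  # 0.5235987901687622  (theta0 = ±π/6 cases)
print(f32(math.pi))      # 3.1415927410125732  (omega0 = ±π cases)
print(f32(2 * math.pi))  # 6.2831854820251465  (desired_theta = ±2π cases)
```

The remaining values match π/4, π/3, π/2, 2π/3, and 4π under the same conversion.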