Inverted-Pendulum-Neural-Ne.../time_weighting/inverse_cubed/training_config.txt

Base controller path: /home/judson/Neural-Networks-in-GNC/inverted_pendulum/training/controller_base.pth
Time Span: 0 to 10, Points: 1000
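A minimal sketch of the time grid this line describes. It assumes the 1000 points are evenly spaced and include both endpoints, as torch.linspace(0, 10, 1000) would produce; the file does not say how the grid was built. It uses plain Python so it runs without torch.

```python
# Assumption: evenly spaced grid that includes both endpoints.
t_start, t_end, t_points = 0.0, 10.0, 1000

# With endpoints included there are (t_points - 1) intervals.
dt = (t_end - t_start) / (t_points - 1)
t_span = [t_start + i * dt for i in range(t_points)]
```

Under that assumption the step is 10/999, slightly more than 0.01.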
Learning Rate: 0.1
Weight Decay: 0.0001
Loss Function:
def loss_fn(state_traj, t_span):
    theta = state_traj[:, :, 0]          # Size: [batch_size, t_points]
    desired_theta = state_traj[:, :, 3]  # Size: [batch_size, t_points]
    min_weight = 0.01  # Weights are on the range [min_weight, 1]
    weights = weight_fn(t_span, min_val=min_weight)  # Initially Size: [t_points]
    # Reshape weights into a column so they can broadcast against theta
    weights = weights.view(-1, 1)  # Now Size: [t_points, 1]
    # Note: a [t_points, 1] column broadcasts against theta only if theta's first dimension is t_points
    # Calculate the weighted loss
    return torch.mean(weights * (theta - desired_theta) ** 2)
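The effect of the time weighting can be seen on a single short trajectory. The sketch below is a dependency-free stand-in for the torch expression above, not the code used in training; the helper name weighted_mse and all numeric values are hypothetical and chosen only for illustration.

```python
def weighted_mse(theta, desired_theta, weights):
    """Mean of weights[i] * (theta[i] - desired_theta[i])**2 over one trajectory."""
    terms = [w * (th - d) ** 2 for th, d, w in zip(theta, desired_theta, weights)]
    return sum(terms) / len(terms)

# Illustrative values only (not taken from a training run).
theta = [0.5, 0.2, 0.1]          # angle at three time points
desired = [0.0, 0.0, 0.0]        # target angle
early_heavy = [1.0, 0.505, 0.01] # weights falling linearly from 1 to 0.01
uniform = [1.0, 1.0, 1.0]        # unweighted reference

loss_weighted = weighted_mse(theta, desired, early_heavy)
loss_uniform = weighted_mse(theta, desired, uniform)
```

With weights that decay over time, errors late in the trajectory contribute less, so the weighted loss here is smaller than the uniform one.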
Weight Function:
def linear_mirrored(t_span: Union[torch.Tensor, List[float]], t_max: float = None, min_val: float = 0.01) -> torch.Tensor:
    t_span = t_span.clone().detach() if isinstance(t_span, torch.Tensor) else torch.tensor(t_span)
    t_max = t_max if t_max is not None else t_span[-1]
    return min_val + ((1 - min_val) / (t_max)**1) * (-t_span + t_max)**1
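To make the shape of this weighting concrete, the same formula can be evaluated for single time values. This is a scalar sketch in plain Python; the name linear_mirrored_scalar is hypothetical, and t_max = 10 is assumed from the time span listed above.

```python
def linear_mirrored_scalar(t, t_max, min_val=0.01):
    # Same expression as linear_mirrored, written for one float instead of a tensor.
    return min_val + ((1 - min_val) / t_max) * (t_max - t)

w_start = linear_mirrored_scalar(0.0, 10.0)   # weight at the first time point
w_mid = linear_mirrored_scalar(5.0, 10.0)     # weight halfway through
w_end = linear_mirrored_scalar(10.0, 10.0)    # weight at t_max
```

The weight starts at 1 for t = 0 and falls linearly to min_val at t = t_max, so early tracking error is penalized most.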
Training Cases:
[theta0, omega0, alpha0, desired_theta]
[0.5235987901687622, 0.0, 0.0, 0.0]
[-0.5235987901687622, 0.0, 0.0, 0.0]
[2.094395160675049, 0.0, 0.0, 0.0]
[-2.094395160675049, 0.0, 0.0, 0.0]
[0.0, 1.0471975803375244, 0.0, 0.0]
[0.0, -1.0471975803375244, 0.0, 0.0]
[0.0, 6.2831854820251465, 0.0, 0.0]
[0.0, -6.2831854820251465, 0.0, 0.0]
[0.0, 0.0, 0.0, 6.2831854820251465]
[0.0, 0.0, 0.0, -6.2831854820251465]
[0.0, 0.0, 0.0, 1.5707963705062866]
[0.0, 0.0, 0.0, -1.5707963705062866]
[0.0, 0.0, 0.0, 1.0471975803375244]
[0.0, 0.0, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 0.0]
[-0.7853981852531433, -3.1415927410125732, 0.0, 0.0]
[1.5707963705062866, -3.1415927410125732, 0.0, 1.0471975803375244]
[-1.5707963705062866, 3.1415927410125732, 0.0, -1.0471975803375244]
[0.7853981852531433, 3.1415927410125732, 0.0, 6.2831854820251465]
[-0.7853981852531433, -3.1415927410125732, 0.0, 6.2831854820251465]
[1.5707963705062866, -3.1415927410125732, 0.0, 12.566370964050293]
[-1.5707963705062866, 3.1415927410125732, 0.0, -12.566370964050293]
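The magnitudes in the training cases appear to be single-precision roundings of simple multiples of pi (for example pi/6, or 30 degrees). That reading is an inference from the numbers, not something the file states, and the file does not state units either. The sketch below checks the inference; the name logged_vs_exact is hypothetical.

```python
import math

# (value as logged, candidate exact value); the pairing is an assumption.
logged_vs_exact = [
    (0.5235987901687622, math.pi / 6),      # 30 degrees
    (0.7853981852531433, math.pi / 4),      # 45 degrees
    (1.0471975803375244, math.pi / 3),      # 60 degrees
    (1.5707963705062866, math.pi / 2),      # 90 degrees
    (2.094395160675049, 2 * math.pi / 3),   # 120 degrees
    (3.1415927410125732, math.pi),          # 180 degrees
    (6.2831854820251465, 2 * math.pi),      # 360 degrees
    (12.566370964050293, 4 * math.pi),      # 720 degrees
]

# Largest relative difference between a logged value and its candidate.
max_rel_err = max(abs(v - e) / e for v, e in logged_vs_exact)
```

A maximum relative error well below 1e-6 is consistent with values that were stored as 32-bit floats before being written out.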