Base controller path: /home/judson/Neural-Networks-in-GNC/inverted_pendulum/training/controller_base.pth Time Span: 0 to 10, Points: 1000 Learning Rate: 0.1 Weight Decay: 0.0001 Loss Function: def loss_fn(state_traj, t_span): theta = state_traj[:, :, 0] # Size: [batch_size, t_points] desired_theta = state_traj[:, :, 3] # Size: [batch_size, t_points] min_weight = 0.01 # Weights are on the range [min_weight, 1] weights = weight_fn(t_span, min_val=min_weight) # Initially Size: [t_points] # Reshape or expand weights to match theta dimensions weights = weights.view(-1, 1) # Now Size: [t_points, 1] # Calculate the weighted loss return torch.mean(weights * (theta - desired_theta) ** 2) Weight Function: def quadratic(t_span: Union[torch.Tensor, List[float]], t_max: float = None, min_val: float = 0.01) -> torch.Tensor: t_span = t_span.clone().detach() if isinstance(t_span, torch.Tensor) else torch.tensor(t_span) t_max = t_max if t_max is not None else t_span[-1] return min_val + ((1 - min_val) / (t_max)**2) * t_span**2 Training Cases: [theta0, omega0, alpha0, desired_theta] [0.5235987901687622, 0.0, 0.0, 0.0] [-0.5235987901687622, 0.0, 0.0, 0.0] [2.094395160675049, 0.0, 0.0, 0.0] [-2.094395160675049, 0.0, 0.0, 0.0] [0.0, 1.0471975803375244, 0.0, 0.0] [0.0, -1.0471975803375244, 0.0, 0.0] [0.0, 6.2831854820251465, 0.0, 0.0] [0.0, -6.2831854820251465, 0.0, 0.0] [0.0, 0.0, 0.0, 6.2831854820251465] [0.0, 0.0, 0.0, -6.2831854820251465] [0.0, 0.0, 0.0, 1.5707963705062866] [0.0, 0.0, 0.0, -1.5707963705062866] [0.0, 0.0, 0.0, 1.0471975803375244] [0.0, 0.0, 0.0, -1.0471975803375244] [0.7853981852531433, 3.1415927410125732, 0.0, 0.0] [-0.7853981852531433, -3.1415927410125732, 0.0, 0.0] [1.5707963705062866, -3.1415927410125732, 0.0, 1.0471975803375244] [-1.5707963705062866, 3.1415927410125732, 0.0, -1.0471975803375244] [0.7853981852531433, 3.1415927410125732, 0.0, 6.2831854820251465] [-0.7853981852531433, -3.1415927410125732, 0.0, 6.2831854820251465] [1.5707963705062866, 
-3.1415927410125732, 0.0, 12.566370964050293] [-1.5707963705062866, 3.1415927410125732, 0.0, -12.566370964050293]