"""Normalized, shifted power-law losses on an error domain of [0, 2π].

Running this module as a script plots every loss in ``base_loss_functions``
over [0, 2π] and saves the figure to ``test.png``.
"""

import math

import torch


def normalized_loss(
    theta: torch.Tensor,
    desired_theta: torch.Tensor,
    exponent: float,
    min_val: float = 0.01,
    delta: float = 1,
    max_error: float = 2 * math.pi,
) -> torch.Tensor:
    """Compute a shifted power-law loss normalized to ``[min_val, 1]``.

    The error ``|theta - desired_theta|`` is mapped through::

        loss = min_val + (1 - min_val) * (
            ((error + delta) ** exponent - delta ** exponent)
            / ((max_error + delta) ** exponent - delta ** exponent)
        )

    so that ``error = 0`` gives ``min_val`` and ``error = max_error`` gives
    ``1``. The shift ``delta`` keeps the gradient finite at ``error = 0`` for
    exponents below 1, where the derivative of ``error ** exponent`` would
    otherwise diverge.

    Args:
        theta: Input tensor.
        desired_theta: Target tensor, broadcastable against ``theta``.
        exponent: Power applied to the shifted error. Must be non-zero.
        min_val: Loss value at zero error.
        delta: Shift added to the error before exponentiation.
        max_error: Error that maps to a loss of exactly 1 (default 2π).

    Returns:
        Element-wise loss tensor with the broadcast shape of the inputs.
        Errors larger than ``max_error`` are not clamped, so (for
        ``delta >= 0``) they map to values above 1.

    Raises:
        ValueError: If ``exponent`` is 0 or ``max_error`` is not positive;
            the normalizing denominator would be zero or ill-defined.
    """
    if exponent == 0:
        # (max_error + delta)**0 - delta**0 == 0 -> division by zero / NaNs.
        raise ValueError("exponent must be non-zero")
    if max_error <= 0:
        raise ValueError(f"max_error must be positive, got {max_error}")

    error = torch.abs(theta - desired_theta)
    offset = delta ** exponent  # value of the shifted power at zero error
    numerator = (error + delta) ** exponent - offset
    denominator = (max_error + delta) ** exponent - offset
    return min_val + (1 - min_val) * (numerator / denominator)


def one_fifth_loss(theta: torch.Tensor, desired_theta: torch.Tensor, min_val: float = 0.01) -> torch.Tensor:
    """Normalized loss with exponent 1/5."""
    return normalized_loss(theta, desired_theta, exponent=1 / 5, min_val=min_val)


def one_fourth_loss(theta: torch.Tensor, desired_theta: torch.Tensor, min_val: float = 0.01) -> torch.Tensor:
    """Normalized loss with exponent 1/4."""
    return normalized_loss(theta, desired_theta, exponent=1 / 4, min_val=min_val)


def one_third_loss(theta: torch.Tensor, desired_theta: torch.Tensor, min_val: float = 0.01) -> torch.Tensor:
    """Normalized loss with exponent 1/3."""
    return normalized_loss(theta, desired_theta, exponent=1 / 3, min_val=min_val)


def one_half_loss(theta: torch.Tensor, desired_theta: torch.Tensor, min_val: float = 0.01) -> torch.Tensor:
    """Normalized loss with exponent 1/2."""
    return normalized_loss(theta, desired_theta, exponent=1 / 2, min_val=min_val)


def abs_loss(theta: torch.Tensor, desired_theta: torch.Tensor, min_val: float = 0.01) -> torch.Tensor:
    """Normalized loss with exponent 1 (linear in the error)."""
    return normalized_loss(theta, desired_theta, exponent=1, min_val=min_val)


def square_loss(theta: torch.Tensor, desired_theta: torch.Tensor, min_val: float = 0.01) -> torch.Tensor:
    """Normalized loss with exponent 2."""
    return normalized_loss(theta, desired_theta, exponent=2, min_val=min_val)


def three_loss(theta: torch.Tensor, desired_theta: torch.Tensor, min_val: float = 0.01) -> torch.Tensor:
    """Normalized loss with exponent 3."""
    return normalized_loss(theta, desired_theta, exponent=3, min_val=min_val)


def fourth_loss(theta: torch.Tensor, desired_theta: torch.Tensor, min_val: float = 0.01) -> torch.Tensor:
    """Normalized loss with exponent 4."""
    return normalized_loss(theta, desired_theta, exponent=4, min_val=min_val)


def five_loss(theta: torch.Tensor, desired_theta: torch.Tensor, min_val: float = 0.01) -> torch.Tensor:
    """Normalized loss with exponent 5."""
    return normalized_loss(theta, desired_theta, exponent=5, min_val=min_val)


# Maps a short name to (exponent, loss function), ordered by exponent.
base_loss_functions = {
    'one_fifth': (1 / 5, one_fifth_loss),
    'one_fourth': (1 / 4, one_fourth_loss),
    'one_third': (1 / 3, one_third_loss),
    'one_half': (1 / 2, one_half_loss),
    'one': (1, abs_loss),
    'two': (2, square_loss),
    'three': (3, three_loss),
    'four': (4, fourth_loss),
    'five': (5, five_loss),
}


def main() -> None:
    """Plot every loss in ``base_loss_functions`` over [0, 2π] and save it to test.png."""
    # Imported here so that importing the loss functions does not require
    # matplotlib or initialise a plotting backend.
    import matplotlib.pyplot as plt

    errors = torch.linspace(0, 2 * math.pi, 1000)
    desired = torch.zeros_like(errors)  # desired_theta = 0, so |theta - desired| == errors

    plt.figure(figsize=(10, 6))
    for name, (exponent, loss_fn) in base_loss_functions.items():
        losses = loss_fn(errors, desired, min_val=0.01)
        plt.plot(errors.numpy(), losses.numpy(), label=f"{name} (exp={exponent})")

    plt.xlabel("Error (|theta - desired_theta|)")
    plt.ylabel("Normalized Loss")
    plt.title("Shifted + Normalized Base Loss Functions on Domain [0, 2π]")
    plt.legend()
    plt.grid(True)
    plt.savefig("test.png")
    plt.show()


if __name__ == "__main__":
    main()