From 53eb3ecfa9f11d95be6bb97f33bb1d8cc2e7c510 Mon Sep 17 00:00:00 2001
From: WorldSEnder
Date: Wed, 13 Mar 2024 17:55:46 +0000
Subject: [PATCH] Implement Huber loss (#1444)

* Implement Huber loss

Instead of using a sign or abs function, the implementation uses clamping
to compute the loss outside the bounds. This is better for the autodiff
backend.

* mention Huber loss in the book

* unify naming of residuals in comments
---
 burn-book/src/building-blocks/module.md |   1 +
 crates/burn-core/src/nn/loss/huber.rs   | 178 ++++++++++++++++++++++++
 crates/burn-core/src/nn/loss/mod.rs     |   2 +
 3 files changed, 181 insertions(+)
 create mode 100644 crates/burn-core/src/nn/loss/huber.rs

diff --git a/burn-book/src/building-blocks/module.md b/burn-book/src/building-blocks/module.md
index a445c012da..e1a09c9019 100644
--- a/burn-book/src/building-blocks/module.md
+++ b/burn-book/src/building-blocks/module.md
@@ -162,3 +162,4 @@ Burn comes with built-in modules that you can use to build your own modules.
 | ------------------ | --------------------- |
 | `CrossEntropyLoss` | `nn.CrossEntropyLoss` |
 | `MseLoss`          | `nn.MSELoss`          |
+| `HuberLoss`        | `nn.HuberLoss`        |
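The clamping trick described in the commit message can be checked on plain scalars: for `|r| > d`, `clamp(r, -d, d)` equals `d * sign(r)`, so `clamp(r, -d, d) * r = d * |r|` and neither `sign()` nor `abs()` is needed to evaluate the linear branch. Below is a minimal, self-contained sketch of that identity in plain Rust; `huber_scalar` is an illustrative name and is not part of the patch.

```rust
/// Scalar version of the clamp-based Huber loss used in the patch below.
fn huber_scalar(r: f64, delta: f64) -> f64 {
    // Equals `delta * sign(r)` whenever `|r| > delta`.
    let softsign = r.clamp(-delta, delta);
    if r.abs() > delta {
        // Linear branch: softsign * r = delta * |r|, minus the 0.5 * delta^2 bias.
        softsign * r - 0.5 * delta * delta
    } else {
        // Quadratic branch: 0.5 * r^2.
        0.5 * r * r
    }
}

fn main() {
    // Inside the bound: quadratic.
    assert_eq!(huber_scalar(0.3, 0.5), 0.5 * 0.3 * 0.3);
    // Outside the bound: linear, 0.5 * 2.0 - 0.125 = 0.875.
    assert_eq!(huber_scalar(2.0, 0.5), 0.875);
}
```

The tensor implementation below avoids the `if` entirely: it computes both branches and selects with `mask_where`, so the autodiff backend differentiates a composition of `clamp`, `mul`, and `powf_scalar` rather than a `sign()` with a jump at 0.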
diff --git a/crates/burn-core/src/nn/loss/huber.rs b/crates/burn-core/src/nn/loss/huber.rs
new file mode 100644
index 0000000000..055910ed41
--- /dev/null
+++ b/crates/burn-core/src/nn/loss/huber.rs
@@ -0,0 +1,178 @@
+use crate as burn;
+
+use crate::{config::Config, module::Module};
+use burn_tensor::backend::Backend;
+use burn_tensor::Tensor;
+use core::marker::PhantomData;
+
+use super::Reduction;
+
+/// Configuration to create a [Huber loss](HuberLoss).
+#[derive(Config, Debug)]
+pub struct HuberLossConfig {
+    /// The bound where the Huber loss function changes from quadratic to linear behaviour.
+    pub delta: f32,
+}
+
+impl HuberLossConfig {
+    /// Initialize [Huber loss](HuberLoss).
+    pub fn init<B: Backend>(&self, device: &B::Device) -> HuberLoss<B> {
+        // device is not needed as of now, but we might want to prepare some data on it,
+        // and it's consistent with the other loss functions
+        let _ = device;
+        self.assertions();
+        HuberLoss {
+            delta: self.delta,
+            lin_bias: self.delta * self.delta * 0.5,
+            _backend: PhantomData,
+        }
+    }
+
+    fn assertions(&self) {
+        assert!(
+            self.delta >= 0., // This comparison also rejects NaN
+            "Delta for Huber loss must be a non-negative number."
+        );
+    }
+}
+
+/// Calculate the Huber loss between the inputs and the target.
+///
+/// The loss for each element of the residuals `r = targets - predictions` is given by
+///
+/// ```text
+/// L(r) = 0.5 * r^2                  if |r| <= d
+/// L(r) = 0.5 * d^2 + d * (|r| - d)  if |r| >  d
+/// ```
+///
+/// where `d` is the configured `delta`. In particular, this is equal to the
+/// [L2 Loss](super::MseLoss) for residuals with magnitude smaller than `delta`,
+/// but behaves linearly instead of quadratically for large residuals.
+///
+/// This loss function is less sensitive to outliers than the mean squared error loss.
+///
+/// See also: <https://en.wikipedia.org/wiki/Huber_loss>
+#[derive(Module, Debug)]
+pub struct HuberLoss<B: Backend> {
+    delta: f32,
+    lin_bias: f32, // delta * delta * 0.5 precomputed
+    _backend: PhantomData<B>,
+}
+
+impl<B: Backend> HuberLoss<B> {
+    /// Compute the loss element-wise for the predictions and targets, then reduce
+    /// to a single loss value.
+    ///
+    /// `Reduction::Auto` behaves as `Reduction::Mean`.
+    ///
+    /// # Shapes
+    ///
+    /// - predictions: \[...dims\]
+    /// - targets: \[...dims\]
+    /// - output: \[1\]
+    pub fn forward<const D: usize>(
+        &self,
+        predictions: Tensor<B, D>,
+        targets: Tensor<B, D>,
+        reduction: Reduction,
+    ) -> Tensor<B, 1> {
+        let loss = self.forward_no_reduction(predictions, targets);
+        match reduction {
+            Reduction::Mean | Reduction::Auto => loss.mean(),
+            Reduction::Sum => loss.sum(),
+        }
+    }
+    /// Compute the loss element-wise for the predictions and targets.
+    ///
+    /// # Shapes
+    ///
+    /// - predictions: \[...dims\]
+    /// - targets: \[...dims\]
+    /// - output: \[...dims\]
+    pub fn forward_no_reduction<const D: usize>(
+        &self,
+        predictions: Tensor<B, D>,
+        targets: Tensor<B, D>,
+    ) -> Tensor<B, D> {
+        let residuals = targets - predictions;
+        self.forward_residuals(residuals)
+    }
+    /// Compute the loss element-wise for the given residuals.
+    ///
+    /// # Shapes
+    ///
+    /// - residuals: \[...dims\]
+    /// - output: \[...dims\]
+    pub fn forward_residuals<const D: usize>(&self, residuals: Tensor<B, D>) -> Tensor<B, D> {
+        let is_large = residuals.clone().abs().greater_elem(self.delta);
+        // We are interested in `sign(r)` when `abs(r) > self.delta`. Note that the
+        // `sign()` function, in general, suffers from a jump at 0.
+        // Instead, the following tensor implements `delta * sign(r)` for values outside
+        // the bound:
+        let softsign = residuals.clone().clamp(-self.delta, self.delta);
+
+        // 0.5 * d^2 + d * (|r| - d) =
+        // d * |r| - 0.5 * d^2
+        // Moreover, |r| = sign(r) * r
+        let outside = softsign.mul(residuals.clone()).sub_scalar(self.lin_bias);
+
+        let inside = residuals.powf_scalar(2.).mul_scalar(0.5);
+        inside.mask_where(is_large, outside)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::TestBackend;
+    use burn_tensor::Data;
+    type TestTensor<const D: usize> = Tensor<TestBackend, D>;
+
+    #[test]
+    fn test_huber_loss() {
+        let predict = Data::from([-2., -0.5, 0., 0.3, 1.]);
+        let targets = Data::from([0., 0., 0., 0., 0.]);
+
+        let device = Default::default();
+
+        let predict = TestTensor::<1>::from_data(predict, &device);
+        let targets = TestTensor::<1>::from_data(targets, &device);
+
+        let huber = HuberLossConfig::new(0.5).init(&device);
+
+        let loss_sum = huber.forward(predict.clone(), targets.clone(), Reduction::Sum);
+        let loss = huber.forward(predict.clone(), targets.clone(), Reduction::Auto);
+        let loss_no_reduction = huber.forward_no_reduction(predict, targets);
+
+        loss_no_reduction
+            .into_data()
+            .assert_approx_eq(&Data::from([0.875, 0.125, 0., 0.045, 0.375]), 7);
+        loss.into_data().assert_approx_eq(&Data::from([0.284]), 7);
+        loss_sum
+            .into_data()
+            .assert_approx_eq(&Data::from([1.42]), 7);
+    }
+
+    #[cfg(feature = "std")]
+    #[test]
+    fn test_huber_ad_loss() {
+        type TestAutodiffTensor = Tensor<crate::TestAutodiffBackend, 1>;
+
+        let predict = Data::from([-2., -0.5, 0., 0.3, 1.]);
+        let targets = Data::from([0., 0., 0., 0., 0.]);
+
+        let device = Default::default();
+        let predict = TestAutodiffTensor::from_data(predict, &device).require_grad();
+        let targets = TestAutodiffTensor::from_data(targets, &device);
+
+        let loss = HuberLossConfig::new(0.5).init(&device);
+        let loss = loss.forward_no_reduction(predict.clone(), targets);
+
+        let grads = loss.backward();
+        let grads_predict = predict.grad(&grads).unwrap();
+
+        grads_predict
+            .to_data()
+            .assert_approx_eq(&Data::from([-0.5, -0.5, 0., 0.3, 0.5]), 3);
+    }
+}
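The unreduced values asserted in the tests can be reproduced by hand: with `delta = 0.5` and all-zero targets, predictions `[-2, -0.5, 0, 0.3, 1]` give residuals `[2, 0.5, 0, -0.3, -1]`, and only `2` and `-1` lie strictly outside the bound. A standalone check in plain `f32`, independent of Burn (the `1e-6` tolerances are illustrative):

```rust
// Recompute the expected test values for delta = 0.5 with plain f32 arithmetic.
fn main() {
    let delta = 0.5_f32;
    let lin_bias = 0.5 * delta * delta; // 0.125, as precomputed in HuberLoss
    let predictions = [-2.0_f32, -0.5, 0.0, 0.3, 1.0];
    let expected = [0.875_f32, 0.125, 0.0, 0.045, 0.375];

    let mut sum = 0.0;
    for (p, e) in predictions.iter().zip(expected) {
        let r = 0.0 - p; // residual = target - prediction; all targets are zero
        let softsign = r.clamp(-delta, delta);
        let loss = if r.abs() > delta {
            softsign * r - lin_bias // linear branch
        } else {
            0.5 * r * r // quadratic branch
        };
        assert!((loss - e).abs() < 1e-6);
        sum += loss;
    }
    // Matches the reduced expectations: Sum = 1.42, Mean = 0.284.
    assert!((sum - 1.42).abs() < 1e-6);
    assert!((sum / 5.0 - 0.284).abs() < 1e-6);
}
```

The gradient expectations follow the same pattern: for the residual `r = targets - predictions`, `dL/dp = -clamp(r, -delta, delta)`, which evaluates to `[-0.5, -0.5, 0., 0.3, 0.5]` for these inputs.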
diff --git a/crates/burn-core/src/nn/loss/mod.rs b/crates/burn-core/src/nn/loss/mod.rs
index 5b37df84f2..cca7b4541b 100644
--- a/crates/burn-core/src/nn/loss/mod.rs
+++ b/crates/burn-core/src/nn/loss/mod.rs
@@ -1,9 +1,11 @@
 mod binary_cross_entropy;
 mod cross_entropy;
+mod huber;
 mod mse;
 mod reduction;
 
 pub use binary_cross_entropy::*;
 pub use cross_entropy::*;
+pub use huber::*;
 pub use mse::*;
 pub use reduction::*;
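With the module declared and re-exported, `HuberLoss` is constructed and called like the other loss functions. A minimal usage sketch, assuming the usual `burn::nn::loss` and `burn::tensor` re-exports of the `burn-core` and `burn-tensor` items above; the wrapper function and the `delta` value are illustrative:

```rust
use burn::nn::loss::{HuberLoss, HuberLossConfig, Reduction};
use burn::tensor::{backend::Backend, Tensor};

/// Reduces 1-D predictions and targets to a single Huber loss value.
fn huber_loss<B: Backend>(
    predictions: Tensor<B, 1>,
    targets: Tensor<B, 1>,
    device: &B::Device,
) -> Tensor<B, 1> {
    // Quadratic for residuals in [-0.5, 0.5], linear outside.
    let loss: HuberLoss<B> = HuberLossConfig::new(0.5).init(device);
    // `Reduction::Auto` reduces to the mean.
    loss.forward(predictions, targets, Reduction::Auto)
}
```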