Implement Huber loss #1444

Merged (3 commits, Mar 13, 2024)
1 change: 1 addition & 0 deletions burn-book/src/building-blocks/module.md
@@ -162,3 +162,4 @@ Burn comes with built-in modules that you can use to build your own modules.
| ------------------ | --------------------- |
| `CrossEntropyLoss` | `nn.CrossEntropyLoss` |
| `MseLoss` | `nn.MSELoss` |
| `HuberLoss` | `nn.HuberLoss` |
179 changes: 179 additions & 0 deletions crates/burn-core/src/nn/loss/huber.rs
@@ -0,0 +1,179 @@
use crate as burn;

use crate::{config::Config, module::Module};
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;
use core::marker::PhantomData;

use super::Reduction;

/// Configuration to create a [Huber loss](HuberLoss).
#[derive(Config, Debug)]
pub struct HuberLossConfig {
    /// The bound where the Huber loss function changes from quadratic to linear behaviour.
    pub delta: f32,
}

impl HuberLossConfig {
    /// Initialize [Huber loss](HuberLoss).
    pub fn init<B: Backend>(&self, device: &B::Device) -> HuberLoss<B> {
        // The device is not needed as of now, but we might want to prepare some data
        // on it later, and taking it is consistent with the other loss functions.
        let _ = device;
        self.assertions();
        HuberLoss {
            delta: self.delta,
            lin_bias: self.delta * self.delta * 0.5,
            _backend: PhantomData,
        }
    }

    fn assertions(&self) {
        assert!(
            self.delta >= 0., // A NaN delta also fails this comparison, so it is rejected too
            "Delta for Huber loss must be a non-negative number."
        );
    }
}

/// Calculate the Huber loss between the inputs and the target.
///
/// The loss for each element of the residuals `r = targets - predictions` is given by
///
/// ```text
/// L(r) = 0.5 * r^2                  if |r| <= d
/// L(r) = 0.5 * d^2 + d * (|r| - d)  if |r| > d
/// ```
///
/// where `d` is the configured `delta`. In particular, this is equal to the
/// [L2 Loss](super::MseLoss) for residuals with magnitude smaller than `delta`,
/// but behaves linearly instead of quadratically for large residuals.
///
/// This loss function is less sensitive to outliers than the mean squared error loss.
///
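/// For example, with `delta = 1.0` a residual of `0.5` lies in the quadratic branch
/// and contributes `0.5 * 0.5^2 = 0.125`, while a residual of `2.0` lies in the
/// linear branch and contributes `0.5 * 1^2 + 1 * (2 - 1) = 1.5`.
///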
/// See also: <https://en.wikipedia.org/wiki/Huber_loss>
#[derive(Module, Debug)]
pub struct HuberLoss<B: Backend> {
    delta: f32,
    lin_bias: f32, // delta * delta * 0.5 precomputed
    _backend: PhantomData<B>,
}

impl<B: Backend> HuberLoss<B> {
    /// Compute the loss element-wise for the predictions and targets, then reduce
    /// to a single loss value.
    ///
    /// `Reduction::Auto` behaves as `Reduction::Mean`.
    ///
    /// # Shapes
    ///
    /// - predictions: \[...dims\]
    /// - targets: \[...dims\]
    /// - output: \[1\]
    pub fn forward<const D: usize>(
        &self,
        predictions: Tensor<B, D>,
        targets: Tensor<B, D>,
        reduction: Reduction,
    ) -> Tensor<B, 1> {
        let loss = self.forward_no_reduction(predictions, targets);
        match reduction {
            Reduction::Mean | Reduction::Auto => loss.mean(),
            Reduction::Sum => loss.sum(),
        }
    }

    /// Compute the loss element-wise for the predictions and targets.
    ///
    /// # Shapes
    ///
    /// - predictions: \[...dims\]
    /// - targets: \[...dims\]
    /// - output: \[...dims\]
    pub fn forward_no_reduction<const D: usize>(
        &self,
        predictions: Tensor<B, D>,
        targets: Tensor<B, D>,
    ) -> Tensor<B, D> {
        let residuals = targets - predictions;
        self.forward_residuals(residuals)
    }

    /// Compute the loss element-wise for the given residuals.
    ///
    /// # Shapes
    ///
    /// - residuals: \[...dims\]
    /// - output: \[...dims\]
    pub fn forward_residuals<const D: usize>(&self, residuals: Tensor<B, D>) -> Tensor<B, D> {
        let is_large = residuals.clone().abs().greater_elem(self.delta);
        // We are interested in `sign(res)` where `abs(res) > self.delta`. Note that the
        // `sign()` function, in general, suffers from being undefined at 0 and is not even
        // implemented in all backends!
        // Instead, the following tensor implements `delta * sign(res)` for values outside
        // the bound:
        let softsign = residuals.clone().clamp(-self.delta, self.delta);

        // 0.5 * d^2 + d * (|res| - d) =
        //     d * |res| - 0.5 * d^2
        // Moreover |res| = sign(res) * res
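        // (Sanity check with delta = 0.5, res = -2: softsign = -0.5, giving
        // (-0.5) * (-2) - 0.125 = 0.875 = 0.5 * d^2 + d * (|res| - d).)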
        let outside = softsign.mul(residuals.clone()).sub_scalar(self.lin_bias);

        let inside = residuals.powf_scalar(2.).mul_scalar(0.5);
        inside.mask_where(is_large, outside)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn_tensor::Data;
    type TestTensor<const D: usize> = Tensor<TestBackend, D>;

    #[test]
    fn test_huber_loss() {
        let predict = Data::from([-2., -0.5, 0., 0.3, 1.]);
        let targets = Data::from([0., 0., 0., 0., 0.]);

        let device = Default::default();

        let predict = TestTensor::<1>::from_data(predict, &device);
        let targets = TestTensor::<1>::from_data(targets, &device);

        let huber = HuberLossConfig::new(0.5).init(&device);

        let loss_sum = huber.forward(predict.clone(), targets.clone(), Reduction::Sum);
        let loss = huber.forward(predict.clone(), targets.clone(), Reduction::Auto);
        let loss_no_reduction = huber.forward_no_reduction(predict, targets);

        loss_no_reduction
            .into_data()
            .assert_approx_eq(&Data::from([0.875, 0.125, 0., 0.045, 0.375]), 7);
        loss.into_data().assert_approx_eq(&Data::from([0.284]), 7);
        loss_sum
            .into_data()
            .assert_approx_eq(&Data::from([1.42]), 7);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_huber_ad_loss() {
        type TestAutodiffTensor = Tensor<crate::TestAutodiffBackend, 1>;

        let predict = Data::from([-2., -0.5, 0., 0.3, 1.]);
        let targets = Data::from([0., 0., 0., 0., 0.]);

        let device = Default::default();
        let predict = TestAutodiffTensor::from_data(predict, &device).require_grad();
        let targets = TestAutodiffTensor::from_data(targets, &device);

        let loss = HuberLossConfig::new(0.5).init(&device);
        let loss = loss.forward_no_reduction(predict.clone(), targets);

        let grads = loss.backward();
        let grads_predict = predict.grad(&grads).unwrap();

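        // The Huber gradient w.r.t. the predictions is clamp(prediction - target, -delta, delta),
        // so the out-of-bound predictions -2 and 1 are clipped to -0.5 and 0.5 below,
        // while the in-bound ones pass through unchanged.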
        grads_predict
            .to_data()
            .assert_approx_eq(&Data::from([-0.5, -0.5, 0., 0.3, 0.5]), 3);
    }
}
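
For reference, here is a minimal usage sketch of the new API from downstream code. It assumes the `burn::nn::loss` re-exports added in `mod.rs` below; the helper function `mean_huber` is hypothetical, only the `HuberLossConfig`/`HuberLoss`/`Reduction` API itself comes from this PR:

```rust
use burn::nn::loss::{HuberLossConfig, Reduction};
use burn::tensor::{backend::Backend, Tensor};

/// Hypothetical helper: mean Huber loss between predictions and targets.
fn mean_huber<B: Backend>(
    predictions: Tensor<B, 1>,
    targets: Tensor<B, 1>,
    device: &B::Device,
) -> Tensor<B, 1> {
    // delta = 0.5: residuals with |r| <= 0.5 are penalized quadratically,
    // larger ones only linearly.
    let huber = HuberLossConfig::new(0.5).init(device);
    huber.forward(predictions, targets, Reduction::Mean)
}
```

With the test values above (`delta = 0.5`, predictions `[-2., -0.5, 0., 0.3, 1.]` against zero targets), this returns the mean loss `0.284`.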
2 changes: 2 additions & 0 deletions crates/burn-core/src/nn/loss/mod.rs
@@ -1,9 +1,11 @@
mod binary_cross_entropy;
mod cross_entropy;
mod huber;
mod mse;
mod reduction;

pub use binary_cross_entropy::*;
pub use cross_entropy::*;
pub use huber::*;
pub use mse::*;
pub use reduction::*;