Add prod and prod_dim tensor ops (#1460)
antimora committed Mar 12, 2024
1 parent 80aac1d commit 7a98b2f
Showing 20 changed files with 452 additions and 68 deletions.
116 changes: 59 additions & 57 deletions burn-book/src/building-blocks/tensor.md
@@ -134,37 +134,37 @@ for the sake of simplicity, we ignore type signatures. For more details, refer t

Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.

| Burn | PyTorch Equivalent |
| ------------------------------------- | ------------------------------------ |
| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` |
| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` |
| `Tensor::from_primitive(primitive)` | N/A |
| `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` |
| `tensor.all()` | `tensor.all()` |
| `tensor.all_dim(dim)` | `tensor.all(dim)` |
| `tensor.any()` | `tensor.any()` |
| `tensor.any_dim(dim)` | `tensor.any(dim)` |
| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
| `tensor.device()` | `tensor.device` |
| `tensor.dims()` | `tensor.size()` |
| `tensor.equal(other)` | `x == y` |
| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
| `tensor.into_data()` | N/A |
| `tensor.into_primitive()` | N/A |
| `tensor.into_scalar()` | `tensor.item()` |
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
| `tensor.not_equal(other)` | `x != y` |
| `tensor.permute(axes)` | `tensor.permute(axes)` |
| `tensor.repeat(2, 4)` | `tensor.repeat([1, 1, 4])` |
| `tensor.reshape(shape)` | `tensor.view(shape)` |
| `tensor.shape()` | `tensor.shape` |
| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` |
| `tensor.to_data()` | N/A |
| `tensor.to_device(device)` | `tensor.to(device)` |
| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` |
| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` |
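
As a quick illustration (not part of this commit), these calls compose like ordinary Rust methods. A minimal sketch assuming a generic backend `B`; the helper name is illustrative:

```rust
use burn::tensor::{backend::Backend, Tensor};

// Exercise a few of the basic operations from the table above.
fn basic_ops<B: Backend>(tensor: Tensor<B, 2>) {
    let dims = tensor.dims(); // like PyTorch's `tensor.size()`
    let flat: Tensor<B, 1> = tensor.clone().flatten(0, 1); // collapse both dimensions
    let lifted: Tensor<B, 3> = tensor.unsqueeze(); // prepend a dim, like `tensor.unsqueeze(0)`
    println!("{:?} -> {:?} / {:?}", dims, flat.dims(), lifted.dims());
}
```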

### Numeric Operations

@@ -203,13 +203,13 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
| `tensor.mask_fill(mask, value)` | `tensor.masked_fill(mask, value)` |
| `tensor.mask_where(mask, value_tensor)` | `torch.where(mask, value_tensor, tensor)` |
| `tensor.max()` | `tensor.max()` |
| `tensor.max_dim(dim)` | `tensor.max(dim, keepdim=True)` |
| `tensor.max_dim_with_indices(dim)` | N/A |
| `tensor.max_pair(other)` | `torch.Tensor.max(a,b)` |
| `tensor.mean()` | `tensor.mean()` |
| `tensor.mean_dim(dim)` | `tensor.mean(dim, keepdim=True)` |
| `tensor.min()` | `tensor.min()` |
| `tensor.min_dim(dim)` | `tensor.min(dim, keepdim=True)` |
| `tensor.min_dim_with_indices(dim)` | N/A |
| `tensor.min_pair(other)` | `torch.Tensor.min(a,b)` |
| `tensor.mul(other)` or `tensor * other` | `tensor * other` |
@@ -218,14 +218,16 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
| `tensor.not_equal_elem(scalar)` | `tensor.ne(scalar)` |
| `tensor.powf(other)` or `tensor.powi(intother)` | `tensor.pow(other)` |
| `tensor.powf_scalar(scalar)` or `tensor.powi_scalar(intscalar)` | `tensor.pow(scalar)` |
| `tensor.prod()` | `tensor.prod()` |
| `tensor.prod_dim(dim)` | `tensor.prod(dim, keepdim=True)` |
| `tensor.scatter(dim, indices, values)` | `tensor.scatter_add(dim, indices, values)` |
| `tensor.select(dim, indices)` | `tensor.index_select(dim, indices)` |
| `tensor.select_assign(dim, indices, values)` | N/A |
| `tensor.sign()` | `tensor.sign()` |
| `tensor.sub(other)` or `tensor - other` | `tensor - other` |
| `tensor.sub_scalar(scalar)` or `tensor - scalar` | `tensor - scalar` |
| `tensor.sum()` | `tensor.sum()` |
| `tensor.sum_dim(dim)` | `tensor.sum(dim, keepdim=True)` |
| `tensor.tril(diagonal)` | `torch.tril(tensor, diagonal)` |
| `tensor.triu(diagonal)` | `torch.triu(tensor, diagonal)` |
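
The two `prod` rows are the operations introduced by this commit: `prod` multiplies every element down to a single-element tensor, while `prod_dim` reduces along one dimension but keeps it with size 1, mirroring PyTorch's `keepdim=True`. A minimal sketch (illustrative, not taken from the diff), assuming a generic backend `B`:

```rust
use burn::tensor::{backend::Backend, Int, Tensor};

// Product reductions on float and int tensors.
fn products<B: Backend>(float: Tensor<B, 2>, int: Tensor<B, 2, Int>) {
    let total: Tensor<B, 1> = float.clone().prod(); // shape [1]
    let per_row: Tensor<B, 2> = float.prod_dim(1); // shape [rows, 1]
    let int_total: Tensor<B, 1, Int> = int.clone().prod();
    let int_per_col: Tensor<B, 2, Int> = int.prod_dim(0); // shape [1, cols]
    println!(
        "{:?} {:?} {:?} {:?}",
        total.dims(),
        per_row.dims(),
        int_total.dims(),
        int_per_col.dims()
    );
}
```

Backend support is uneven at this point: the Candle backend still hits a `todo!` for the `Int` variants, and the autodiff backward pass for `float_prod` is deferred to issue [#1458](https://github.com/tracel-ai/burn/issues/1458), as the TODOs further down in this diff note.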

@@ -269,35 +271,35 @@ Those operations are only available for `Int` tensors.
| ------------------------------------------------ | ------------------------------------------------------- |
| `Tensor::arange(5..10, device)`                  | `tensor.arange(start=5, end=10, device=device)`         |
| `Tensor::arange_step(5..10, 2, device)`          | `tensor.arange(start=5, end=10, step=2, device=device)` |
| `tensor.float()`                                 | `tensor.to(torch.float)`                                 |
| `Tensor::from_ints(ints)`                        | N/A                                                      |
| `tensor.int_random(shape, distribution, device)` | N/A |
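
A minimal sketch of the `Int` creation and conversion calls above, assuming a device handle for backend `B` (the helper name is illustrative):

```rust
use burn::tensor::{backend::Backend, Int, Tensor};

// Build integer ranges on a device, then convert to float.
fn int_ops<B: Backend>(device: &B::Device) {
    let ramp = Tensor::<B, 1, Int>::arange(5..10, device); // [5, 6, 7, 8, 9]
    let evens = Tensor::<B, 1, Int>::arange_step(0..10, 2, device); // [0, 2, 4, 6, 8]
    let as_float: Tensor<B, 1> = ramp.float(); // like `tensor.to(torch.float)`
    println!("{:?} {:?}", evens.dims(), as_float.dims());
}
```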

### Bool Operations

Those operations are only available for `Bool` tensors.

| Burn API | PyTorch Equivalent |
| ------------------- | ------------------------------- |
| `tensor.float()` | `tensor.to(torch.float)` |
| `tensor.int()` | `tensor.to(torch.long)` |
| `tensor.not()` | `tensor.logical_not()` |
| `tensor.argwhere()` | `tensor.argwhere()` |
| `tensor.nonzero()` | `tensor.nonzero(as_tuple=True)` |
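
A minimal sketch of the boolean conversions above, assuming `argwhere` returns its index tensor synchronously (the helper name is illustrative):

```rust
use burn::tensor::{backend::Backend, Bool, Int, Tensor};

// Convert a boolean mask and extract the indices of its `true` entries.
fn bool_ops<B: Backend>(mask: Tensor<B, 2, Bool>) {
    let as_float: Tensor<B, 2> = mask.clone().float(); // 0.0 / 1.0
    let as_int: Tensor<B, 2, Int> = mask.clone().int(); // 0 / 1
    let coords = mask.argwhere(); // 2-D `Int` tensor of coordinates
    println!("{:?} {:?} {:?}", as_float.dims(), as_int.dims(), coords.dims());
}
```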

### Activation Functions

| Burn API | PyTorch Equivalent |
| ---------------------------------------- | ------------------------------------------ |
| `activation::gelu(tensor)` | `nn.functional.gelu(tensor)` |
| `activation::log_sigmoid(tensor)` | `nn.functional.log_sigmoid(tensor)` |
| `activation::log_softmax(tensor, dim)` | `nn.functional.log_softmax(tensor, dim)` |
| `activation::mish(tensor)` | `nn.functional.mish(tensor)` |
| `activation::prelu(tensor,alpha)` | `nn.functional.prelu(tensor,weight)` |
| `activation::quiet_softmax(tensor, dim)` | `nn.functional.quiet_softmax(tensor, dim)` |
| `activation::relu(tensor)` | `nn.functional.relu(tensor)` |
| `activation::sigmoid(tensor)` | `nn.functional.sigmoid(tensor)` |
| `activation::silu(tensor)` | `nn.functional.silu(tensor)` |
| `activation::softmax(tensor, dim)` | `nn.functional.softmax(tensor, dim)` |
| `activation::softplus(tensor, beta)` | `nn.functional.softplus(tensor, beta)` |
| `activation::tanh(tensor)` | `nn.functional.tanh(tensor)` |
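
The activation functions are free functions in the `activation` module rather than tensor methods. A minimal sketch with an illustrative helper:

```rust
use burn::tensor::{activation, backend::Backend, Tensor};

// Apply a few activations to a batch of logits.
fn activations<B: Backend>(logits: Tensor<B, 2>) {
    let probs = activation::softmax(logits.clone(), 1); // rows sum to 1
    let hidden = activation::relu(logits.clone());
    let gated = activation::silu(logits);
    println!("{:?} {:?} {:?}", probs.dims(), hidden.dims(), gated.dims());
}
```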
8 changes: 8 additions & 0 deletions crates/burn-autodiff/src/ops/int_tensor.rs
@@ -356,4 +356,12 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
    fn int_sign<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, D> {
        B::int_sign(tensor)
    }

    // Int tensors carry no gradients, so the new ops forward straight to the inner backend.
    fn int_prod<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
        B::int_prod(tensor)
    }

    fn int_prod_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
        B::int_prod_dim(tensor, dim)
    }
}
3 changes: 3 additions & 0 deletions crates/burn-autodiff/src/ops/tensor.rs
@@ -2379,6 +2379,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
            .parents([&tensor])
            .stateless(B::float_sign(tensor.primitive))
    }

    // TODO: Implement float_prod and float_sum
    // https://github.com/tracel-ai/burn/issues/1458
}

#[derive(Debug, Clone)]
8 changes: 8 additions & 0 deletions crates/burn-candle/src/ops/int_tensor.rs
@@ -313,6 +313,14 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
        CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap())
    }

    fn int_prod<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
        todo!("prod is not implemented for Candle IntTensor (see https://github.com/tracel-ai/burn/issues/1454)")
    }

    fn int_prod_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
        todo!("prod_dim is not implemented for Candle IntTensor (see https://github.com/tracel-ai/burn/issues/1454)")
    }

    fn int_mean_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
        // Candle implements scalar a/b as a * (1/b). With ints 1/b is rounded to 0 so we always obtain 0.
        panic!("Not supported by Candle")
41 changes: 41 additions & 0 deletions crates/burn-fusion/src/ops/int.rs
@@ -1060,6 +1060,47 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
        out
    }

    fn int_prod<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
        unary_int_ops!(ProdOps, B::int_prod);

        let stream = tensor.stream;
        let out = tensor.client.tensor_uninitialized(vec![1]);

        let desc = UnaryOperationDescription {
            input: tensor.into_description(),
            out: out.to_description_out(),
        };
        out.client.register(
            vec![stream],
            OperationDescription::NumericInt(NumericOperationDescription::Prod(desc.clone())),
            ProdOps::<D>::new(desc),
        );

        out
    }

    fn int_prod_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
        scalar_int_ops!(ProdDimOps, B::int_prod_dim, usize, noconvert);

        let stream = tensor.stream;
        let mut shape = tensor.shape.clone();
        shape[dim] = 1;
        let out = tensor.client.tensor_uninitialized(shape);

        let desc = ScalarOperationDescription {
            lhs: tensor.into_description(),
            rhs: dim,
            out: out.to_description_out(),
        };
        out.client.register(
            vec![stream],
            OperationDescription::NumericInt(NumericOperationDescription::ProdDim(desc.clone())),
            ProdDimOps::<D>::new(desc),
        );

        out
    }

    fn int_mean<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
        unary_int_ops!(MeanOps, B::int_mean);

13 changes: 13 additions & 0 deletions crates/burn-fusion/src/stream/context.rs
@@ -614,6 +614,19 @@ impl<E: Element> NumericOperationDescription<E> {
                    out: desc.out.to_relative(converter),
                })
            }
            NumericOperationDescription::Prod(desc) => {
                NumericOperationDescription::Prod(UnaryOperationDescription {
                    input: desc.input.to_relative(converter),
                    out: desc.out.to_relative(converter),
                })
            }
            NumericOperationDescription::ProdDim(desc) => {
                NumericOperationDescription::ProdDim(ScalarOperationDescription {
                    lhs: desc.lhs.to_relative(converter),
                    rhs: desc.rhs, // Dim should stay the same.
                    out: desc.out.to_relative(converter),
                })
            }
            NumericOperationDescription::EqualElem(desc) => {
                NumericOperationDescription::EqualElem(ScalarOperationDescription {
                    lhs: desc.lhs.to_relative(converter),