Skip to content

Commit

Permalink
Fix tch view data corruption (#1434)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed Mar 8, 2024
1 parent 61c0474 commit 2de270f
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 12 deletions.
97 changes: 85 additions & 12 deletions crates/burn-tch/src/tensor.rs
Expand Up @@ -9,13 +9,58 @@ use std::{marker::PhantomData, sync::Arc};
#[allow(clippy::arc_with_non_send_sync)]
pub type StorageRef = Arc<*mut c_void>;

/// A reference to a tensor storage.
#[derive(PartialEq, Debug, Clone)]
pub enum Storage {
/// When a tensor is a partial view of another tensor.
View {
/// Storage reference for the whole buffer.
buffer_ref: StorageRef,
/// Storage reference for the partial buffer.
view_ref: StorageRef,
},
/// When a tensor use all of its buffer.
Owned {
/// Storage reference for the whole buffer.
buffer_ref: StorageRef,
},
}

impl Storage {
/// Check if the storage can be used inplace.
pub fn can_mut(&self) -> bool {
match self {
Storage::View {
buffer_ref: start_ref,
view_ref,
} => Arc::strong_count(start_ref) == 1 && Arc::strong_count(view_ref) == 1,
Storage::Owned {
buffer_ref: start_ref,
} => Arc::strong_count(start_ref) == 1,
}
}

/// Get the whole buffer reference.
pub fn buffer_ref(&self) -> &StorageRef {
match self {
Storage::View {
buffer_ref: start_ref,
view_ref: _,
} => start_ref,
Storage::Owned {
buffer_ref: start_ref,
} => start_ref,
}
}
}

/// A tensor that uses the tch backend.
#[derive(Debug, PartialEq)]
pub struct TchTensor<E: tch::kind::Element, const D: usize> {
/// Handle to the tensor. Call methods on this field.
pub tensor: tch::Tensor,
/// The tensor's storage
pub storage: StorageRef,
pub storage: Storage,
phantom: PhantomData<E>,
}

Expand All @@ -27,26 +72,49 @@ impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
/// instead.
pub fn new(tensor: tch::Tensor) -> Self {
#[allow(clippy::arc_with_non_send_sync)]
let data = Arc::new(tensor.data_ptr());
let storage = Storage::Owned {
buffer_ref: Arc::new(tensor.data_ptr()),
};

Self {
tensor,
phantom: PhantomData,
storage: data,
storage,
}
}

/// Create a tensor that was created from an operation executed on a parent tensor.
///
/// If the child tensor shared the same storage as its parent, it will be cloned, effectively
/// tracking how much tensors point to the same memory space.
pub fn from_existing(tensor: tch::Tensor, storage_parent: StorageRef) -> Self {
pub fn from_existing(tensor: tch::Tensor, storage_parent: Storage) -> Self {
let storage_child = tensor.data_ptr();
let mut is_a_new_tensor = true;

match &storage_parent {
Storage::View {
buffer_ref: start_ref,
view_ref,
} => {
if storage_child == *start_ref.as_ref() || storage_child == *view_ref.as_ref() {
is_a_new_tensor = false;
}
}
Storage::Owned {
buffer_ref: start_ref,
} => {
if storage_child == *start_ref.as_ref() {
is_a_new_tensor = false;
}
}
};

#[allow(clippy::arc_with_non_send_sync)]
let storage = match storage_child == *storage_parent {
true => storage_parent.clone(),
false => Arc::new(storage_child),
let storage = match is_a_new_tensor {
true => Storage::Owned {
#[allow(clippy::arc_with_non_send_sync)]
buffer_ref: Arc::new(storage_child),
},
false => storage_parent.clone(),
};

Self {
Expand All @@ -57,10 +125,15 @@ impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
}

/// Create a tensor that uses a part of its parent tensor such as slice and narrow.
pub fn partial(tensor: tch::Tensor, storage_parent: StorageRef) -> Self {
pub fn partial(tensor: tch::Tensor, storage_parent: Storage) -> Self {
let storage = Storage::View {
buffer_ref: storage_parent.buffer_ref().clone(),
#[allow(clippy::arc_with_non_send_sync)]
view_ref: Arc::new(tensor.data_ptr()),
};
Self {
tensor,
storage: storage_parent,
storage,
phantom: PhantomData,
}
}
Expand Down Expand Up @@ -96,7 +169,7 @@ impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
&mut self,
func: F,
) -> Option<TchTensor<EOut, D_OUT>> {
if Arc::strong_count(&self.storage) > 1 {
if !self.storage.can_mut() {
return None;
}

Expand All @@ -113,7 +186,7 @@ impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
FOwn: Fn(tch::Tensor) -> tch::Tensor,
FRef: Fn(&tch::Tensor) -> tch::Tensor,
{
if Arc::strong_count(&self.storage) > 1 {
if !self.storage.can_mut() {
return TchTensor::from_existing(fref(&self.tensor), self.storage);
}

Expand Down
12 changes: 12 additions & 0 deletions crates/burn-tensor/src/tests/ops/reshape.rs
Expand Up @@ -70,6 +70,18 @@ mod tests {
assert_eq!(reshaped.shape(), [4, 3].into());
}

#[test]
fn should_not_corrupt_after_slice() {
let zeros = Tensor::<TestBackend, 1>::zeros([2], &Default::default());
zeros.clone().slice([1..2]).reshape([1]).exp();

// May lead to zeroes being equal to [0.0, 1.0]
assert_eq!(
zeros.to_data(),
Tensor::<TestBackend, 1>::zeros([2], &Default::default()).to_data()
);
}

#[test]
#[should_panic]
fn multiple_neg_ones() {
Expand Down

0 comments on commit 2de270f

Please sign in to comment.