Skip to content

Commit

Permalink
Retain original error when rewinding after a bad parse
Browse files Browse the repository at this point in the history
  • Loading branch information
csnover committed Aug 9, 2023
1 parent 2be3ff7 commit 6d1bdc3
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 44 deletions.
7 changes: 2 additions & 5 deletions binrw/src/binread/impls.rs
@@ -1,5 +1,5 @@
use crate::{
io::{self, Read, Seek, SeekFrom},
io::{self, Read, Seek},
BinRead, BinResult, Endian, Error, NamedArgs,
};
use alloc::{boxed::Box, vec::Vec};
Expand All @@ -18,10 +18,7 @@ macro_rules! binread_impl {
let mut val = [0; core::mem::size_of::<$type_name>()];
let pos = reader.stream_position()?;

reader.read_exact(&mut val).or_else(|e| {
reader.seek(SeekFrom::Start(pos))?;
Err(e)
})?;
reader.read_exact(&mut val).or_else(crate::__private::restore_position(reader, pos))?;
Ok(match endian {
Endian::Big => {
<$type_name>::from_be_bytes(val)
Expand Down
41 changes: 39 additions & 2 deletions binrw/src/private.rs
@@ -1,6 +1,6 @@
use crate::{
error::CustomError,
io::{Read, Seek, Write},
error::{Backtrace, BacktraceFrame, CustomError},
io::{Read, Seek, SeekFrom, Write},
BinRead, BinResult, BinWrite, Endian, Error,
};
use alloc::{boxed::Box, string::String};
Expand Down Expand Up @@ -166,6 +166,43 @@ where
args
}

pub fn restore_position<E: Into<Error>, S: Seek, T>(
stream: &mut S,
pos: u64,
) -> impl FnOnce(E) -> BinResult<T> + '_ {
move |error| match stream.seek(SeekFrom::Start(pos)) {
Ok(_) => Err(error.into()),
Err(seek_error) => Err(restore_position_err(error.into(), seek_error.into())),
}
}

fn restore_position_err(error: Error, mut seek_error: Error) -> Error {
let reason = BacktraceFrame::Message("rewinding after a failure".into());
match error {
Error::Backtrace(mut bt) => {
core::mem::swap(&mut seek_error, &mut *bt.error);
bt.frames.insert(0, seek_error.into());
bt.frames.insert(0, reason);
Error::Backtrace(bt)
}
error => Error::Backtrace(Backtrace::new(
seek_error,
alloc::vec![reason, error.into()],
)),
}
}

pub fn restore_position_variant<S: Seek>(
stream: &mut S,
pos: u64,
error: Error,
) -> BinResult<Error> {
match stream.seek(SeekFrom::Start(pos)) {
Ok(_) => Ok(error),
Err(seek_error) => Err(restore_position_err(error, seek_error.into())),
}
}

pub fn write_try_map_args_type_hint<Input, Output, Error, MapFn, Args>(
_: &MapFn,
args: Args,
Expand Down
111 changes: 111 additions & 0 deletions binrw/tests/error.rs
Expand Up @@ -158,6 +158,117 @@ fn not_custom_error() {
assert!(err.custom_err::<i32>().is_none());
}

#[test]
fn no_seek_struct() {
use binrw::{
error::BacktraceFrame,
io::{Cursor, NoSeek},
BinRead,
};

#[derive(BinRead, Debug)]
struct Test {
#[br(assert(_a == 1))]
_a: u32,
}

let mut data = NoSeek::new(Cursor::new(b"\0\0\0\0"));
let error = Test::read_le(&mut data).expect_err("accepted bad data");
match error {
Error::Backtrace(bt) => {
assert!(matches!(*bt.error, Error::Io(..)));

match (&bt.frames[0], &bt.frames[1]) {
(BacktraceFrame::Message(m), BacktraceFrame::Custom(e)) => {
assert_eq!(m, "rewinding after a failure");
match e.downcast_ref::<binrw::Error>() {
Some(binrw::Error::AssertFail { pos, .. }) => assert_eq!(*pos, 0),
_ => panic!("unexpected error"),
}
}
_ => panic!("unexpected error frame layout"),
}
}
_ => panic!("expected backtrace"),
}
}

#[test]
fn no_seek_data_enum() {
use binrw::{
error::BacktraceFrame,
io::{Cursor, NoSeek},
BinRead,
};

#[derive(BinRead, Debug)]
enum Test {
#[br(magic(0u8))]
A(#[br(assert(self_0 == 1))] u32),
#[br(magic(1u8))]
B(#[br(assert(self_0 == 2))] u32),
}

let mut data = NoSeek::new(Cursor::new(b"\0\0\0\0\0"));
let error = Test::read_le(&mut data).expect_err("accepted bad data");

match error {
Error::Backtrace(bt) => {
assert!(matches!(*bt.error, Error::Io(..)));

match (&bt.frames[0], &bt.frames[1]) {
(BacktraceFrame::Message(m), BacktraceFrame::Custom(e)) => {
assert_eq!(m, "rewinding after a failure");
match e.downcast_ref::<binrw::Error>() {
Some(binrw::Error::AssertFail { pos, .. }) => assert_eq!(*pos, 0),
e => panic!("unexpected error {:?}", e),
}
}
_ => panic!("unexpected error frame layout"),
}
}
_ => panic!("expected backtrace"),
}
}

#[test]
fn no_seek_unit_enum() {
use binrw::{
error::BacktraceFrame,
io::{Cursor, NoSeek},
BinRead,
};

#[derive(BinRead, Debug)]
#[br(big, repr = u32)]
enum Test {
A = 1,
B = 2,
C = 3,
}

let mut data = NoSeek::new(Cursor::new(b"\0\0\0\0"));
let error = Test::read_le(&mut data).expect_err("accepted bad data");

match error {
Error::Backtrace(bt) => {
assert!(matches!(*bt.error, Error::Io(..)));

match (&bt.frames[0], &bt.frames[1]) {
(BacktraceFrame::Message(m), BacktraceFrame::Custom(e)) => {
assert_eq!(m, "rewinding after a failure");
match e.downcast_ref::<binrw::Error>() {
Some(binrw::Error::NoVariantMatch { pos }) => assert_eq!(*pos, 0),
e => panic!("unexpected error {:?}", e),
}
}
_ => panic!("unexpected error frame layout"),
}
}
_ => panic!("expected backtrace"),
}
}

#[test]
fn show_backtrace() {
use alloc::borrow::Cow;
Expand Down
44 changes: 27 additions & 17 deletions binrw_derive/src/binrw/codegen/read_options.rs
Expand Up @@ -8,7 +8,8 @@ use crate::{
codegen::{
get_endian,
sanitization::{
ARGS, ASSERT_MAGIC, MAP_READER_TYPE_HINT, OPT, POS, READER, SEEK_FROM, SEEK_TRAIT,
ARGS, ASSERT_MAGIC, MAP_READER_TYPE_HINT, OPT, POS, READER, RESTORE_POSITION,
SEEK_TRAIT,
},
},
parser::{Input, Magic, Map},
Expand All @@ -23,36 +24,45 @@ use syn::{spanned::Spanned, Ident};

pub(crate) fn generate(input: &Input, derive_input: &syn::DeriveInput) -> TokenStream {
let name = Some(&derive_input.ident);
let inner = match input.map() {
let (inner, needs_rewind) = match input.map() {
Map::None => match input {
Input::UnitStruct(_) => generate_unit_struct(input, name, None),
Input::Struct(s) => generate_struct(input, name, s),
Input::Enum(e) => generate_data_enum(input, name, e),
Input::UnitOnlyEnum(e) => generate_unit_enum(input, name, e),
Input::UnitStruct(_) => (generate_unit_struct(input, name, None), false),
Input::Struct(s) => (generate_struct(input, name, s), true),
Input::Enum(e) => (generate_data_enum(input, name, e), false),
Input::UnitOnlyEnum(e) => (
generate_unit_enum(input, name, e),
e.map.as_repr().is_some(),
),
},
Map::Try(map) => map::generate_try_map(input, name, map),
Map::Map(map) => map::generate_map(input, name, map),
Map::Try(map) => (map::generate_try_map(input, name, map), true),
Map::Map(map) => (map::generate_map(input, name, map), true),
Map::Repr(ty) => match input {
Input::UnitOnlyEnum(e) => generate_unit_enum(input, name, e),
_ => map::generate_try_map(
input,
name,
&quote! { <#ty as core::convert::TryInto<_>>::try_into },
Input::UnitOnlyEnum(e) => (generate_unit_enum(input, name, e), true),
_ => (
map::generate_try_map(
input,
name,
&quote! { <#ty as core::convert::TryInto<_>>::try_into },
),
true,
),
},
};

let reader_var = input.stream_ident_or(READER);

let rewind = (needs_rewind || input.magic().is_some()).then(|| {
quote! {
.or_else(#RESTORE_POSITION::<binrw::Error, _, _>(#reader_var, #POS))
}
});

quote! {
let #reader_var = #READER;
let #POS = #SEEK_TRAIT::stream_position(#reader_var)?;
(|| {
#inner
})().or_else(|error| {
#SEEK_TRAIT::seek(#reader_var, #SEEK_FROM::Start(#POS))?;
Err(error)
})
})()#rewind
}
}

Expand Down
37 changes: 17 additions & 20 deletions binrw_derive/src/binrw/codegen/read_options/enum.rs
Expand Up @@ -6,8 +6,8 @@ use crate::binrw::{
codegen::{
get_assertions,
sanitization::{
BACKTRACE_FRAME, BIN_ERROR, ERROR_BASKET, OPT, POS, READER, READ_METHOD, SEEK_FROM,
SEEK_TRAIT, TEMP, WITH_CONTEXT,
BACKTRACE_FRAME, BIN_ERROR, ERROR_BASKET, OPT, POS, READER, READ_METHOD,
RESTORE_POSITION_VARIANT, TEMP, WITH_CONTEXT,
},
},
parser::{Enum, EnumErrorMode, EnumVariant, Input, UnitEnumField, UnitOnlyEnum},
Expand Down Expand Up @@ -128,22 +128,19 @@ fn generate_unit_enum_magic(reader_var: &TokenStream, variants: &[UnitEnumField]
};

quote! {
let #TEMP = (|| {
match (|| {
#body
})();

if #TEMP.is_ok() {
return #TEMP;
})() {
v @ Ok(_) => return v,
Err(#TEMP) => { #RESTORE_POSITION_VARIANT(#reader_var, #POS, #TEMP)?; }
}

#SEEK_TRAIT::seek(#reader_var, #SEEK_FROM::Start(#POS))?;
}
});

let return_error = quote! {
Err(#BIN_ERROR::NoVariantMatch {
pos: #POS
})
pos: #POS
})
};

quote! {
Expand Down Expand Up @@ -194,22 +191,22 @@ pub(super) fn generate_data_enum(input: &Input, name: Option<&Ident>, en: &Enum)
let handle_error = if return_all_errors {
let name = variant.ident().to_string();
quote! {
#ERROR_BASKET.push((#name, #TEMP.err().unwrap()));
#ERROR_BASKET.push((#name, #TEMP));
}
} else {
TokenStream::new()
};

quote! {
let #TEMP = (|| {
match (|| {
#body
})();

if #TEMP.is_ok() {
return #TEMP;
} else {
#handle_error
#SEEK_TRAIT::seek(#reader_var, #SEEK_FROM::Start(#POS))?;
})() {
ok @ Ok(_) => return ok,
Err(error) => {
#RESTORE_POSITION_VARIANT(#reader_var, #POS, error).map(|#TEMP| {
#handle_error
})?;
}
}
}
});
Expand Down
2 changes: 2 additions & 0 deletions binrw_derive/src/binrw/codegen/sanitization.rs
Expand Up @@ -57,6 +57,8 @@ ident_str! {
pub(crate) WRITE_MAP_INPUT_TYPE_HINT = from_crate!(__private::write_map_fn_input_type_hint);
pub(crate) WRITE_FN_MAP_OUTPUT_TYPE_HINT = from_crate!(__private::write_fn_map_output_type_hint);
pub(crate) WRITE_FN_TRY_MAP_OUTPUT_TYPE_HINT = from_crate!(__private::write_fn_try_map_output_type_hint);
pub(crate) RESTORE_POSITION = from_crate!(__private::restore_position);
pub(crate) RESTORE_POSITION_VARIANT = from_crate!(__private::restore_position_variant);
pub(crate) WRITE_ZEROES = from_crate!(__private::write_zeroes);
pub(crate) ARGS_MACRO = from_crate!(args);
pub(crate) META_ENDIAN_KIND = from_crate!(meta::EndianKind);
Expand Down

0 comments on commit 6d1bdc3

Please sign in to comment.