Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow self in non-unit struct and enum top-level assert #219

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 9 additions & 0 deletions binrw/doc/attribute.md
Expand Up @@ -502,6 +502,15 @@ Any <span class="brw">(earlier only, when reading)</span><span class="br">earlie
field or [import](#arguments) can be referenced by expressions
in the directive.

<div class="br">

For `#[br]`, when using `map`, a non-unit `struct`, or an `enum`, a special variable
named `self` can be referenced by expressions in the directive. It contains the result
of the `map` function or the result of constructing the `struct` or `enum`. Note that
you cannot refer to the `enum` fields directly, as an `enum` variant is not its own type.

</div>

## Examples

### Formatted error
Expand Down
37 changes: 37 additions & 0 deletions binrw/tests/derive/enum.rs
Expand Up @@ -28,6 +28,43 @@ fn enum_assert() {
Test::read_le(&mut Cursor::new(b"\0\0\x01")).expect_err("accepted bad data");
}

#[test]
fn enum_assert_with_self() {
#[derive(BinRead, Debug, PartialEq)]
#[br(assert(self.verify()))]
enum Test {
A {
a: u8,
b: u8,
},
#[br(assert(self.verify_only_b()))]
B {
a: i16,
b: u8,
},
}

impl Test {
fn verify(&self) -> bool {
match self {
Test::A { b, .. } => *b == 1,
Test::B { a, b } => *a == -1 && *b == 1,
}
}

fn verify_only_b(&self) -> bool {
matches!(self, Test::B { .. })
}
}

assert_eq!(
Test::read_le(&mut Cursor::new(b"\xff\xff\x01")).unwrap(),
Test::B { a: -1, b: 1 }
);
Test::read_le(&mut Cursor::new(b"\xff\xff\0")).expect_err("accepted bad data");
Test::read_le(&mut Cursor::new(b"\0\0\x01")).expect_err("accepted bad data");
}

#[test]
fn enum_non_copy_args() {
#[derive(BinRead, Debug)]
Expand Down
24 changes: 24 additions & 0 deletions binrw/tests/derive/map_args.rs
Expand Up @@ -48,3 +48,27 @@ fn map_field_assert_access_fields() {

Test::read(&mut Cursor::new(b"a")).unwrap();
}

#[test]
#[should_panic]
fn map_top_assert_legacy_this() {
#[derive(BinRead, Debug, Eq, PartialEq)]
#[br(assert(this.x == 2), map(|_: u8| Test { x: 3 }))]
struct Test {
x: u8,
}

Test::read(&mut Cursor::new(b"a")).unwrap();
}

#[test]
#[should_panic]
fn map_top_assert_via_self() {
#[derive(BinRead, Debug, Eq, PartialEq)]
#[br(assert(self.x == 2), map(|_: u8| Test { x: 3 }))]
struct Test {
x: u8,
}

Test::read(&mut Cursor::new(b"a")).unwrap();
}
52 changes: 52 additions & 0 deletions binrw/tests/derive/struct.rs
Expand Up @@ -704,6 +704,58 @@ fn reader_var() {
);
}

#[test]
fn top_level_assert_has_self() {
#[allow(dead_code)]
#[derive(BinRead, Debug)]
#[br(assert(self.verify(), "verify failed"))]
struct Test {
a: u8,
b: u8,
}

impl Test {
fn verify(&self) -> bool {
self.a == self.b
}
}

let mut data = Cursor::new(b"\x01\x01");
Test::read_le(&mut data).expect("a == b passed");
let mut data = Cursor::new(b"\x01\x02");
let err = Test::read_le(&mut data).expect_err("a == b failed");
assert!(matches!(err, binrw::Error::AssertFail {
message,
..
} if message == "verify failed"));
}

#[test]
fn top_level_assert_self_weird() {
#[allow(dead_code)]
#[derive(BinRead, Debug)]
#[br(assert(Test::verify(&self), "verify failed"))]
struct Test {
a: u8,
b: u8,
}

impl Test {
fn verify(&self) -> bool {
self.a == self.b
}
}

let mut data = Cursor::new(b"\x01\x01");
Test::read_le(&mut data).expect("a == b passed");
let mut data = Cursor::new(b"\x01\x02");
let err = Test::read_le(&mut data).expect_err("a == b failed");
assert!(matches!(err, binrw::Error::AssertFail {
message,
..
} if message == "verify failed"));
}

#[test]
fn rewind_on_assert() {
#[allow(dead_code)]
Expand Down
1 change: 1 addition & 0 deletions binrw_derive/src/binrw/codegen/mod.rs
Expand Up @@ -204,6 +204,7 @@ fn get_assertions(assertions: &[Assert]) -> impl Iterator<Item = TokenStream> +
kw_span,
condition,
consequent,
..
}| {
let error_fn = match &consequent {
Some(AssertionError::Message(message)) => {
Expand Down
13 changes: 5 additions & 8 deletions binrw_derive/src/binrw/codegen/read_options/enum.rs
Expand Up @@ -3,12 +3,9 @@ use super::{
PreludeGenerator,
};
use crate::binrw::{
codegen::{
get_assertions,
sanitization::{
BACKTRACE_FRAME, BIN_ERROR, ERROR_BASKET, OPT, POS, READER, READ_METHOD,
RESTORE_POSITION_VARIANT, TEMP, WITH_CONTEXT,
},
codegen::sanitization::{
BACKTRACE_FRAME, BIN_ERROR, ERROR_BASKET, OPT, POS, READER, READ_METHOD,
RESTORE_POSITION_VARIANT, TEMP, WITH_CONTEXT,
},
parser::{Enum, EnumErrorMode, EnumVariant, Input, UnitEnumField, UnitOnlyEnum},
};
Expand Down Expand Up @@ -228,8 +225,8 @@ fn generate_variant_impl(en: &Enum, variant: &EnumVariant) -> TokenStream {
None,
Some(&format!("{}::{}", en.ident.as_ref().unwrap(), &ident)),
)
.add_assertions(get_assertions(&en.assertions))
.return_value(Some(ident))
.initialize_value_with_assertions(Some(ident), &en.assertions)
.return_value()
.finish(),

EnumVariant::Unit(options) => generate_unit_struct(&input, None, Some(&options.ident)),
Expand Down
48 changes: 40 additions & 8 deletions binrw_derive/src/binrw/codegen/read_options/struct.rs
@@ -1,6 +1,7 @@
use super::{get_magic, PreludeGenerator};
#[cfg(feature = "verbose-backtrace")]
use crate::binrw::backtrace::BacktraceFrame;
use crate::binrw::parser::Assert;
use crate::{
binrw::{
codegen::{
Expand Down Expand Up @@ -37,8 +38,8 @@ pub(super) fn generate_unit_struct(
pub(super) fn generate_struct(input: &Input, name: Option<&Ident>, st: &Struct) -> TokenStream {
StructGenerator::new(input, st)
.read_fields(name, None)
.add_assertions(core::iter::empty())
.return_value(None)
.initialize_value_with_assertions(None, &[])
.return_value()
.finish()
}

Expand All @@ -61,11 +62,31 @@ impl<'input> StructGenerator<'input> {
self.out
}

pub(super) fn add_assertions(
mut self,
extra_assertions: impl Iterator<Item = TokenStream>,
pub(super) fn initialize_value_with_assertions(
self,
variant_ident: Option<&Ident>,
extra_assertions: &[Assert],
) -> Self {
let assertions = get_assertions(&self.st.assertions).chain(extra_assertions);
if self.has_self_assertions(extra_assertions) {
self.init_value(variant_ident)
.add_assertions(extra_assertions)
} else {
self.add_assertions(extra_assertions)
.init_value(variant_ident)
}
}

fn has_self_assertions(&self, extra_assertions: &[Assert]) -> bool {
self.st
.assertions
.iter()
.chain(extra_assertions)
.any(|assert| assert.condition_uses_self)
}

fn add_assertions(mut self, extra_assertions: &[Assert]) -> Self {
let assertions =
get_assertions(&self.st.assertions).chain(get_assertions(extra_assertions));
let head = self.out;
self.out = quote! {
#head
Expand Down Expand Up @@ -102,7 +123,7 @@ impl<'input> StructGenerator<'input> {
self
}

pub(super) fn return_value(mut self, variant_ident: Option<&Ident>) -> Self {
fn init_value(mut self, variant_ident: Option<&Ident>) -> Self {
let out_names = self.st.iter_permanent_idents();
let return_type = get_return_type(variant_ident);
let return_value = if self.st.is_tuple() {
Expand All @@ -114,7 +135,18 @@ impl<'input> StructGenerator<'input> {
let head = self.out;
self.out = quote! {
#head
Ok(#return_value)
let this = #return_value;
};

self
}

pub(super) fn return_value(mut self) -> Self {
let head = self.out;

self.out = quote! {
#head
Ok(this)
};

self
Expand Down
33 changes: 28 additions & 5 deletions binrw_derive/src/binrw/parser/types/assert.rs
@@ -1,6 +1,7 @@
use crate::{binrw::parser::attrs, meta_types::KeywordToken};
use proc_macro2::{Span, TokenStream};
use proc_macro2::{Ident, Span, TokenStream};
use quote::{quote, ToTokens};
use syn::fold::Fold;
use syn::{parse::Parse, spanned::Spanned, token::Token, Expr, ExprLit, Lit};

#[derive(Debug, Clone)]
Expand All @@ -13,6 +14,9 @@ pub(crate) enum Error {
pub(crate) struct Assert {
pub(crate) kw_span: Span,
pub(crate) condition: TokenStream,
/// `true` if the condition was written with `self`, in the [`condition`] it is replaced with
/// `this`. This enables backwards compatibility with asserts that did not use `self`.
pub(crate) condition_uses_self: bool,
pub(crate) consequent: Option<Error>,
}

Expand All @@ -23,9 +27,7 @@ impl<K: Parse + Spanned + Token> TryFrom<attrs::AssertLike<K>> for Assert {
let kw_span = value.keyword_span();
let mut args = value.fields.iter();

let condition = if let Some(cond) = args.next() {
cond.into_token_stream()
} else {
let Some(cond_expr) = args.next() else {
return Err(Self::Error::new(
kw_span,
format!(
Expand All @@ -35,6 +37,11 @@ impl<K: Parse + Spanned + Token> TryFrom<attrs::AssertLike<K>> for Assert {
));
};

// ignores any alternative declaration of `self` in the condition, but asserts should be
// simple so that shouldn't be a problem
let mut self_replacer = ReplaceSelfWithThis { uses_self: false };
let cond_expr = self_replacer.fold_expr(cond_expr.clone());

let consequent = match args.next() {
Some(Expr::Lit(ExprLit {
lit: Lit::Str(message),
Expand All @@ -52,8 +59,24 @@ impl<K: Parse + Spanned + Token> TryFrom<attrs::AssertLike<K>> for Assert {

Ok(Self {
kw_span,
condition,
condition: cond_expr.into_token_stream(),
condition_uses_self: self_replacer.uses_self,
consequent,
})
}
}

struct ReplaceSelfWithThis {
uses_self: bool,
}

impl Fold for ReplaceSelfWithThis {
fn fold_ident(&mut self, i: Ident) -> Ident {
if i == "self" {
self.uses_self = true;
Ident::new("this", i.span())
} else {
i
}
}
}