Skip to content

Commit

Permalink
Message.decode/encode: Add max_recursion_depth option
Browse files Browse the repository at this point in the history
This allows increasing the recursing depth from the default of 64, by
setting the "max_recursion_depth" to the desired integer value. This is
useful to encode or decode complex nested protobuf messages that otherwise
error out with a RuntimeError or "Error occurred during parsing".

Fixes #1493
  • Loading branch information
lfittl committed Nov 14, 2021
1 parent 0ca4c1a commit fab7498
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 19 deletions.
69 changes: 56 additions & 13 deletions ruby/ext/google/protobuf_c/message.c
Expand Up @@ -918,23 +918,47 @@ static VALUE Message_index_set(VALUE _self, VALUE field_name, VALUE value) {

/*
* call-seq:
* MessageClass.decode(data) => message
* MessageClass.decode(data, options) => message
*
* Decodes the given data (as a string containing bytes in protocol buffers wire
* format) under the interpretration given by this message class's definition
* and returns a message object with the corresponding field values.
* @param options [Hash] options for the decoder
* max_recursion_depth: set to maximum decoding depth for message (default is 64)
*/
static VALUE Message_decode(VALUE klass, VALUE data) {
static VALUE Message_decode(int argc, VALUE* argv, VALUE klass) {
VALUE data = argv[0];
int options = 0;

if (argc < 1 || argc > 2) {
rb_raise(rb_eArgError, "Expected 1 or 2 arguments.");
}

if (argc == 2) {
VALUE hash_args = argv[1];
if (TYPE(hash_args) != T_HASH) {
rb_raise(rb_eArgError, "Expected hash arguments.");
}

VALUE depth = rb_hash_lookup(hash_args, ID2SYM(rb_intern("max_recursion_depth")));

if (depth != Qnil && TYPE(depth) == T_FIXNUM) {
options = FIX2INT(depth) << 16;
}
}

if (TYPE(data) != T_STRING) {
rb_raise(rb_eArgError, "Expected string for binary protobuf data.");
}

VALUE msg_rb = initialize_rb_class_with_no_args(klass);
Message* msg = ruby_to_Message(msg_rb);

if (!upb_decode(RSTRING_PTR(data), RSTRING_LEN(data), (upb_msg*)msg->msg,
upb_msgdef_layout(msg->msgdef),
Arena_get(msg->arena))) {
if (!_upb_decode(RSTRING_PTR(data), RSTRING_LEN(data), (upb_msg*)msg->msg,
upb_msgdef_layout(msg->msgdef),
NULL,
options,
Arena_get(msg->arena))) {
rb_raise(cParseError, "Error occurred during parsing");
}

Expand Down Expand Up @@ -1005,24 +1029,43 @@ static VALUE Message_decode_json(int argc, VALUE* argv, VALUE klass) {

/*
* call-seq:
* MessageClass.encode(msg) => bytes
* MessageClass.encode(msg, options) => bytes
*
* Encodes the given message object to its serialized form in protocol buffers
* wire format.
* @param options [Hash] options for the encoder
* max_recursion_depth: set to maximum encoding depth for message (default is 64)
*/
static VALUE Message_encode(VALUE klass, VALUE msg_rb) {
Message* msg = ruby_to_Message(msg_rb);
static VALUE Message_encode(int argc, VALUE* argv, VALUE klass) {
Message* msg = ruby_to_Message(argv[0]);
int options = 0;
const char *data;
size_t size;

if (CLASS_OF(msg_rb) != klass) {
if (CLASS_OF(argv[0]) != klass) {
rb_raise(rb_eArgError, "Message of wrong type.");
}

if (argc < 1 || argc > 2) {
rb_raise(rb_eArgError, "Expected 1 or 2 arguments.");
}

if (argc == 2) {
VALUE hash_args = argv[1];
if (TYPE(hash_args) != T_HASH) {
rb_raise(rb_eArgError, "Expected hash arguments.");
}
VALUE depth = rb_hash_lookup(hash_args, ID2SYM(rb_intern("max_recursion_depth")));

if (depth != Qnil && TYPE(depth) == T_FIXNUM) {
options = FIX2INT(depth) << 16;
}
}

upb_arena *arena = upb_arena_new();

data = upb_encode(msg->msg, upb_msgdef_layout(msg->msgdef), arena,
&size);
data = upb_encode_ex(msg->msg, upb_msgdef_layout(msg->msgdef),
options, arena, &size);

if (data) {
VALUE ret = rb_str_new(data, size);
Expand Down Expand Up @@ -1149,8 +1192,8 @@ VALUE build_class_from_descriptor(VALUE descriptor) {
rb_define_method(klass, "to_s", Message_inspect, 0);
rb_define_method(klass, "[]", Message_index, 1);
rb_define_method(klass, "[]=", Message_index_set, 2);
rb_define_singleton_method(klass, "decode", Message_decode, 1);
rb_define_singleton_method(klass, "encode", Message_encode, 1);
rb_define_singleton_method(klass, "decode", Message_decode, -1);
rb_define_singleton_method(klass, "encode", Message_encode, -1);
rb_define_singleton_method(klass, "decode_json", Message_decode_json, -1);
rb_define_singleton_method(klass, "encode_json", Message_encode_json, -1);
rb_define_singleton_method(klass, "descriptor", Message_descriptor, 0);
Expand Down
8 changes: 4 additions & 4 deletions ruby/lib/google/protobuf.rb
Expand Up @@ -59,16 +59,16 @@ class TypeError < ::TypeError; end
module Google
module Protobuf

def self.encode(msg)
msg.to_proto
def self.encode(msg, options = {})
msg.to_proto(options)
end

def self.encode_json(msg, options = {})
msg.to_json(options)
end

def self.decode(klass, proto)
klass.decode(proto)
def self.decode(klass, proto, options = {})
klass.decode(proto, options)
end

def self.decode_json(klass, json, options = {})
Expand Down
4 changes: 2 additions & 2 deletions ruby/lib/google/protobuf/message_exts.rb
Expand Up @@ -44,8 +44,8 @@ def to_json(options = {})
self.class.encode_json(self, options)
end

def to_proto
self.class.encode(self)
def to_proto(options = {})
self.class.encode(self, options)
end

end
Expand Down
51 changes: 51 additions & 0 deletions ruby/tests/encode_decode_test.rb
Expand Up @@ -101,4 +101,55 @@ def test_json_name
assert_match json, "{\"CustomJsonName\":42}"
end

def test_decode_depth_limit
msg = A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
)
)
)
)
)
)
msg_encoded = A::B::C::TestMessage.encode(msg)
msg_out = A::B::C::TestMessage.decode(msg_encoded)
assert_match msg.to_json, msg_out.to_json

assert_raise Google::Protobuf::ParseError do
A::B::C::TestMessage.decode(msg_encoded, { max_recursion_depth: 4 })
end

msg_out = A::B::C::TestMessage.decode(msg_encoded, { max_recursion_depth: 5 })
assert_match msg.to_json, msg_out.to_json
end

def test_encode_depth_limit
msg = A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
)
)
)
)
)
)
msg_encoded = A::B::C::TestMessage.encode(msg)
msg_out = A::B::C::TestMessage.decode(msg_encoded)
assert_match msg.to_json, msg_out.to_json

assert_raise RuntimeError do
A::B::C::TestMessage.encode(msg, { max_recursion_depth: 5 })
end

msg_encoded = A::B::C::TestMessage.encode(msg, { max_recursion_depth: 6 })
msg_out = A::B::C::TestMessage.decode(msg_encoded)
assert_match msg.to_json, msg_out.to_json
end

end

0 comments on commit fab7498

Please sign in to comment.