diff --git a/ruby/ext/google/protobuf_c/message.c b/ruby/ext/google/protobuf_c/message.c index d07eba760d23..6f6ba2d34f10 100644 --- a/ruby/ext/google/protobuf_c/message.c +++ b/ruby/ext/google/protobuf_c/message.c @@ -918,13 +918,35 @@ static VALUE Message_index_set(VALUE _self, VALUE field_name, VALUE value) { /* * call-seq: - * MessageClass.decode(data) => message + * MessageClass.decode(data, options) => message * * Decodes the given data (as a string containing bytes in protocol buffers wire * format) under the interpretration given by this message class's definition * and returns a message object with the corresponding field values. + * @param options [Hash] options for the decoder + * max_recursion_depth: set to maximum decoding depth for message (default is 64) */ -static VALUE Message_decode(VALUE klass, VALUE data) { +static VALUE Message_decode(int argc, VALUE* argv, VALUE klass) { + VALUE data = argv[0]; + int options = 0; + + if (argc < 1 || argc > 2) { + rb_raise(rb_eArgError, "Expected 1 or 2 arguments."); + } + + if (argc == 2) { + VALUE hash_args = argv[1]; + if (TYPE(hash_args) != T_HASH) { + rb_raise(rb_eArgError, "Expected hash arguments."); + } + + VALUE depth = rb_hash_lookup(hash_args, ID2SYM(rb_intern("max_recursion_depth"))); + + if (depth != Qnil && TYPE(depth) == T_FIXNUM) { + options = FIX2INT(depth) << 16; + } + } + if (TYPE(data) != T_STRING) { rb_raise(rb_eArgError, "Expected string for binary protobuf data."); } @@ -932,9 +954,11 @@ static VALUE Message_decode(VALUE klass, VALUE data) { VALUE msg_rb = initialize_rb_class_with_no_args(klass); Message* msg = ruby_to_Message(msg_rb); - if (!upb_decode(RSTRING_PTR(data), RSTRING_LEN(data), (upb_msg*)msg->msg, - upb_msgdef_layout(msg->msgdef), - Arena_get(msg->arena))) { + if (!_upb_decode(RSTRING_PTR(data), RSTRING_LEN(data), (upb_msg*)msg->msg, + upb_msgdef_layout(msg->msgdef), + NULL, + options, + Arena_get(msg->arena))) { rb_raise(cParseError, "Error occurred during parsing"); } @@ -1005,24 +1029,43 @@ static VALUE Message_decode_json(int argc, VALUE* argv, VALUE klass) { /* * call-seq: - * MessageClass.encode(msg) => bytes + * MessageClass.encode(msg, options) => bytes * * Encodes the given message object to its serialized form in protocol buffers * wire format. + * @param options [Hash] options for the encoder + * max_recursion_depth: set to maximum encoding depth for message (default is 64) */ -static VALUE Message_encode(VALUE klass, VALUE msg_rb) { - Message* msg = ruby_to_Message(msg_rb); +static VALUE Message_encode(int argc, VALUE* argv, VALUE klass) { + Message* msg = ruby_to_Message(argv[0]); + int options = 0; const char *data; size_t size; - if (CLASS_OF(msg_rb) != klass) { + if (CLASS_OF(argv[0]) != klass) { rb_raise(rb_eArgError, "Message of wrong type."); } + if (argc < 1 || argc > 2) { + rb_raise(rb_eArgError, "Expected 1 or 2 arguments."); + } + + if (argc == 2) { + VALUE hash_args = argv[1]; + if (TYPE(hash_args) != T_HASH) { + rb_raise(rb_eArgError, "Expected hash arguments."); + } + VALUE depth = rb_hash_lookup(hash_args, ID2SYM(rb_intern("max_recursion_depth"))); + + if (depth != Qnil && TYPE(depth) == T_FIXNUM) { + options = FIX2INT(depth) << 16; + } + } + upb_arena *arena = upb_arena_new(); - data = upb_encode(msg->msg, upb_msgdef_layout(msg->msgdef), arena, - &size); + data = upb_encode_ex(msg->msg, upb_msgdef_layout(msg->msgdef), + options, arena, &size); if (data) { VALUE ret = rb_str_new(data, size); @@ -1149,8 +1192,8 @@ VALUE build_class_from_descriptor(VALUE descriptor) { rb_define_method(klass, "to_s", Message_inspect, 0); rb_define_method(klass, "[]", Message_index, 1); rb_define_method(klass, "[]=", Message_index_set, 2); - rb_define_singleton_method(klass, "decode", Message_decode, 1); - rb_define_singleton_method(klass, "encode", Message_encode, 1); + rb_define_singleton_method(klass, "decode", Message_decode, -1); + rb_define_singleton_method(klass, "encode", Message_encode, -1); rb_define_singleton_method(klass, "decode_json", Message_decode_json, -1); rb_define_singleton_method(klass, "encode_json", Message_encode_json, -1); rb_define_singleton_method(klass, "descriptor", Message_descriptor, 0); diff --git a/ruby/lib/google/protobuf.rb b/ruby/lib/google/protobuf.rb index f939a4c7dcd7..b7a671105158 100644 --- a/ruby/lib/google/protobuf.rb +++ b/ruby/lib/google/protobuf.rb @@ -59,16 +59,16 @@ class TypeError < ::TypeError; end module Google module Protobuf - def self.encode(msg) - msg.to_proto + def self.encode(msg, options = {}) + msg.to_proto(options) end def self.encode_json(msg, options = {}) msg.to_json(options) end - def self.decode(klass, proto) - klass.decode(proto) + def self.decode(klass, proto, options = {}) + klass.decode(proto, options) end def self.decode_json(klass, json, options = {}) diff --git a/ruby/lib/google/protobuf/message_exts.rb b/ruby/lib/google/protobuf/message_exts.rb index f432f89fed0c..660852172896 100644 --- a/ruby/lib/google/protobuf/message_exts.rb +++ b/ruby/lib/google/protobuf/message_exts.rb @@ -44,8 +44,8 @@ def to_json(options = {}) self.class.encode_json(self, options) end - def to_proto - self.class.encode(self) + def to_proto(options = {}) + self.class.encode(self, options) end end diff --git a/ruby/tests/encode_decode_test.rb b/ruby/tests/encode_decode_test.rb index 429ac4332216..9513cc37d9dc 100755 --- a/ruby/tests/encode_decode_test.rb +++ b/ruby/tests/encode_decode_test.rb @@ -101,4 +101,55 @@ def test_json_name assert_match json, "{\"CustomJsonName\":42}" end + def test_decode_depth_limit + msg = A::B::C::TestMessage.new( + optional_msg: A::B::C::TestMessage.new( + optional_msg: A::B::C::TestMessage.new( + optional_msg: A::B::C::TestMessage.new( + optional_msg: A::B::C::TestMessage.new( + optional_msg: A::B::C::TestMessage.new( + ) + ) + ) + ) + ) + ) + msg_encoded = A::B::C::TestMessage.encode(msg) + msg_out = A::B::C::TestMessage.decode(msg_encoded) + assert_match msg.to_json, msg_out.to_json + + assert_raise Google::Protobuf::ParseError do + A::B::C::TestMessage.decode(msg_encoded, { max_recursion_depth: 4 }) + end + + msg_out = A::B::C::TestMessage.decode(msg_encoded, { max_recursion_depth: 5 }) + assert_match msg.to_json, msg_out.to_json + end + + def test_encode_depth_limit + msg = A::B::C::TestMessage.new( + optional_msg: A::B::C::TestMessage.new( + optional_msg: A::B::C::TestMessage.new( + optional_msg: A::B::C::TestMessage.new( + optional_msg: A::B::C::TestMessage.new( + optional_msg: A::B::C::TestMessage.new( + ) + ) + ) + ) + ) + ) + msg_encoded = A::B::C::TestMessage.encode(msg) + msg_out = A::B::C::TestMessage.decode(msg_encoded) + assert_match msg.to_json, msg_out.to_json + + assert_raise RuntimeError do + A::B::C::TestMessage.encode(msg, { max_recursion_depth: 5 }) + end + + msg_encoded = A::B::C::TestMessage.encode(msg, { max_recursion_depth: 6 }) + msg_out = A::B::C::TestMessage.decode(msg_encoded) + assert_match msg.to_json, msg_out.to_json + end + end