diff --git a/ruby/ext/google/protobuf_c/message.c b/ruby/ext/google/protobuf_c/message.c index 7feee75db87b..8d25b7928c7e 100644 --- a/ruby/ext/google/protobuf_c/message.c +++ b/ruby/ext/google/protobuf_c/message.c @@ -953,13 +953,35 @@ static VALUE Message_index_set(VALUE _self, VALUE field_name, VALUE value) { /* * call-seq: - * MessageClass.decode(data) => message + * MessageClass.decode(data, options = {}) => message * * Decodes the given data (as a string containing bytes in protocol buffers wire * format) under the interpretration given by this message class's definition * and returns a message object with the corresponding field values. + * @param options [Hash] options for the decoder + * max_recursion_depth: set to maximum decoding depth for message (default is 64); ignored unless it is a Fixnum */ -static VALUE Message_decode(VALUE klass, VALUE data) { +static VALUE Message_decode(int argc, VALUE* argv, VALUE klass) { + int options = 0; + + if (argc < 1 || argc > 2) { + rb_raise(rb_eArgError, "Expected 1 or 2 arguments."); + } + + VALUE data = argv[0]; /* safe to read: argc >= 1 was verified above */ + if (argc == 2) { + VALUE hash_args = argv[1]; + if (TYPE(hash_args) != T_HASH) { + rb_raise(rb_eArgError, "Expected hash arguments."); + } + + VALUE depth = rb_hash_lookup(hash_args, ID2SYM(rb_intern("max_recursion_depth"))); + + if (depth != Qnil && TYPE(depth) == T_FIXNUM) { + options |= UPB_DECODE_MAXDEPTH(FIX2INT(depth)); + } + } + if (TYPE(data) != T_STRING) { rb_raise(rb_eArgError, "Expected string for binary protobuf data."); } @@ -969,7 +991,7 @@ static VALUE Message_decode(VALUE klass, VALUE data) { upb_DecodeStatus status = upb_Decode( RSTRING_PTR(data), RSTRING_LEN(data), (upb_Message*)msg->msg, - upb_MessageDef_MiniTable(msg->msgdef), NULL, 0, Arena_get(msg->arena)); + upb_MessageDef_MiniTable(msg->msgdef), NULL, options, Arena_get(msg->arena)); if (status != kUpb_DecodeStatus_Ok) { rb_raise(cParseError, "Error occurred during parsing"); @@ -1043,24 +1065,43 @@ static VALUE Message_decode_json(int argc, VALUE* argv, VALUE klass) 
{ /* * call-seq: - * MessageClass.encode(msg) => bytes + * MessageClass.encode(msg, options = {}) => bytes * * Encodes the given message object to its serialized form in protocol buffers * wire format. + * @param options [Hash] options for the encoder + * max_recursion_depth: set to maximum encoding depth for message (default is 64); ignored unless it is a Fixnum */ -static VALUE Message_encode(VALUE klass, VALUE msg_rb) { - Message* msg = ruby_to_Message(msg_rb); +static VALUE Message_encode(int argc, VALUE* argv, VALUE klass) { + int options = 0; const char* data; size_t size; - if (CLASS_OF(msg_rb) != klass) { + if (argc < 1 || argc > 2) { + rb_raise(rb_eArgError, "Expected 1 or 2 arguments."); + } + + Message* msg = ruby_to_Message(argv[0]); /* safe: argc >= 1 was verified above */ + if (CLASS_OF(argv[0]) != klass) { rb_raise(rb_eArgError, "Message of wrong type."); } - upb_Arena* arena = upb_Arena_New(); + if (argc == 2) { + VALUE hash_args = argv[1]; + if (TYPE(hash_args) != T_HASH) { + rb_raise(rb_eArgError, "Expected hash arguments."); + } + VALUE depth = rb_hash_lookup(hash_args, ID2SYM(rb_intern("max_recursion_depth"))); + + if (depth != Qnil && TYPE(depth) == T_FIXNUM) { + options |= UPB_ENCODE_MAXDEPTH(FIX2INT(depth)); + } + } - data = upb_Encode(msg->msg, upb_MessageDef_MiniTable(msg->msgdef), 0, arena, - &size); + upb_Arena *arena = upb_Arena_New(); + + data = upb_Encode(msg->msg, upb_MessageDef_MiniTable(msg->msgdef), + options, arena, &size); if (data) { VALUE ret = rb_str_new(data, size); @@ -1186,8 +1227,8 @@ VALUE build_class_from_descriptor(VALUE descriptor) { rb_define_method(klass, "to_s", Message_inspect, 0); rb_define_method(klass, "[]", Message_index, 1); rb_define_method(klass, "[]=", Message_index_set, 2); - rb_define_singleton_method(klass, "decode", Message_decode, 1); - rb_define_singleton_method(klass, "encode", Message_encode, 1); + rb_define_singleton_method(klass, "decode", Message_decode, -1); + rb_define_singleton_method(klass, "encode", Message_encode, -1); rb_define_singleton_method(klass, 
"decode_json", Message_decode_json, -1); rb_define_singleton_method(klass, "encode_json", Message_encode_json, -1); rb_define_singleton_method(klass, "descriptor", Message_descriptor, 0); diff --git a/ruby/lib/google/protobuf.rb b/ruby/lib/google/protobuf.rb index f939a4c7dcd7..b7a671105158 100644 --- a/ruby/lib/google/protobuf.rb +++ b/ruby/lib/google/protobuf.rb @@ -59,16 +59,16 @@ class TypeError < ::TypeError; end module Google module Protobuf - def self.encode(msg) - msg.to_proto + def self.encode(msg, options = {}) + msg.to_proto(options) end def self.encode_json(msg, options = {}) msg.to_json(options) end - def self.decode(klass, proto) - klass.decode(proto) + def self.decode(klass, proto, options = {}) + klass.decode(proto, options) end def self.decode_json(klass, json, options = {}) diff --git a/ruby/lib/google/protobuf/message_exts.rb b/ruby/lib/google/protobuf/message_exts.rb index f432f89fed0c..660852172896 100644 --- a/ruby/lib/google/protobuf/message_exts.rb +++ b/ruby/lib/google/protobuf/message_exts.rb @@ -44,8 +44,8 @@ def to_json(options = {}) self.class.encode_json(self, options) end - def to_proto - self.class.encode(self) + def to_proto(options = {}) + self.class.encode(self, options) end end diff --git a/ruby/src/main/java/com/google/protobuf/jruby/RubyMap.java b/ruby/src/main/java/com/google/protobuf/jruby/RubyMap.java index f7379b148edf..b5e4903e2e61 100644 --- a/ruby/src/main/java/com/google/protobuf/jruby/RubyMap.java +++ b/ruby/src/main/java/com/google/protobuf/jruby/RubyMap.java @@ -389,7 +389,7 @@ protected IRubyObject deepCopy(ThreadContext context) { return newMap; } - protected List build(ThreadContext context, RubyDescriptor descriptor, int depth) { + protected List build(ThreadContext context, RubyDescriptor descriptor, int depth, int maxRecursionDepth) { List list = new ArrayList(); RubyClass rubyClass = (RubyClass) descriptor.msgclass(context); FieldDescriptor keyField = descriptor.getField("key"); @@ -398,7 +398,7 @@ protected 
List build(ThreadContext context, RubyDescriptor descr RubyMessage mapMessage = (RubyMessage) rubyClass.newInstance(context, Block.NULL_BLOCK); mapMessage.setField(context, keyField, key); mapMessage.setField(context, valueField, table.get(key)); - list.add(mapMessage.build(context, depth + 1)); + list.add(mapMessage.build(context, depth + 1, maxRecursionDepth)); } return list; } diff --git a/ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java b/ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java index cf59f625972c..2ba132e62831 100644 --- a/ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java +++ b/ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java @@ -39,6 +39,7 @@ import com.google.protobuf.Descriptors.FileDescriptor; import com.google.protobuf.Descriptors.OneofDescriptor; import com.google.protobuf.ByteString; +import com.google.protobuf.CodedInputStream; import com.google.protobuf.DynamicMessage; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; @@ -461,35 +462,63 @@ public static IRubyObject getDescriptor(ThreadContext context, IRubyObject recv) /* * call-seq: - * MessageClass.encode(msg) => bytes + * MessageClass.encode(msg, options = {}) => bytes * * Encodes the given message object to its serialized form in protocol buffers * wire format. 
+ * @param options [Hash] options for the encoder; raises ArgumentError when it is not a Hash + * max_recursion_depth: set to maximum encoding depth for message (default is 64) */ - @JRubyMethod(meta = true) - public static IRubyObject encode(ThreadContext context, IRubyObject recv, IRubyObject value) { - if (recv != value.getMetaClass()) { - throw context.runtime.newArgumentError("Tried to encode a " + value.getMetaClass() + " message with " + recv); + @JRubyMethod(required = 1, optional = 1, meta = true) + public static IRubyObject encode(ThreadContext context, IRubyObject recv, IRubyObject[] args) { + if (recv != args[0].getMetaClass()) { + throw context.runtime.newArgumentError("Tried to encode a " + args[0].getMetaClass() + " message with " + recv); } - RubyMessage message = (RubyMessage) value; - return context.runtime.newString(new ByteList(message.build(context).toByteArray())); + RubyMessage message = (RubyMessage) args[0]; + int maxRecursionDepthInt = SINK_MAXIMUM_NESTING; + if (args.length > 1) { + if (!(args[1] instanceof RubyHash)) { + throw context.runtime.newArgumentError("Expected hash arguments."); + } + IRubyObject maxRecursionDepth = ((RubyHash) args[1]).fastARef(context.runtime.newSymbol("max_recursion_depth")); + if (maxRecursionDepth != null) { + maxRecursionDepthInt = ((RubyNumeric) maxRecursionDepth).getIntValue(); + } + } + return context.runtime.newString(new ByteList(message.build(context, 0, maxRecursionDepthInt).toByteArray())); } /* * call-seq: - * MessageClass.decode(data) => message + * MessageClass.decode(data, options = {}) => message * * Decodes the given data (as a string containing bytes in protocol buffers wire * format) under the interpretation given by this message class's definition * and returns a message object with the corresponding field values. 
+ * @param options [Hash] options for the decoder + * max_recursion_depth: set to maximum decoding depth for message (default is 100) */ - @JRubyMethod(meta = true) - public static IRubyObject decode(ThreadContext context, IRubyObject recv, IRubyObject data) { + @JRubyMethod(required = 1, optional = 1, meta = true) + public static IRubyObject decode(ThreadContext context, IRubyObject recv, IRubyObject[] args) { + IRubyObject data = args[0]; byte[] bin = data.convertToString().getBytes(); + CodedInputStream input = CodedInputStream.newInstance(bin); RubyMessage ret = (RubyMessage) ((RubyClass) recv).newInstance(context, Block.NULL_BLOCK); + + if (args.length == 2) { + if (!(args[1] instanceof RubyHash)) { + throw context.runtime.newArgumentError("Expected hash arguments."); + } + + IRubyObject maxRecursionDepth = ((RubyHash) args[1]).fastARef(context.runtime.newSymbol("max_recursion_depth")); + if (maxRecursionDepth != null) { + input.setRecursionLimit(((RubyNumeric) maxRecursionDepth).getIntValue()); + } + } + try { - ret.builder.mergeFrom(bin); - } catch (InvalidProtocolBufferException e) { + ret.builder.mergeFrom(input); + } catch (Exception e) { throw RaiseException.from(context.runtime, (RubyClass) context.runtime.getClassFromPath("Google::Protobuf::ParseError"), e.getMessage()); } @@ -541,7 +570,7 @@ public static IRubyObject encodeJson(ThreadContext context, IRubyObject recv, IR printer = printer.usingTypeRegistry(JsonFormat.TypeRegistry.newBuilder().add(message.descriptor).build()); try { - result = printer.print(message.build(context)); + result = printer.print(message.build(context, 0, SINK_MAXIMUM_NESTING)); } catch (InvalidProtocolBufferException e) { throw runtime.newRuntimeError(e.getMessage()); } catch (IllegalArgumentException e) { @@ -635,12 +664,8 @@ public IRubyObject toHash(ThreadContext context) { return ret; } - protected DynamicMessage build(ThreadContext context) { - return build(context, 0); - } - - protected DynamicMessage 
build(ThreadContext context, int depth) { - if (depth > SINK_MAXIMUM_NESTING) { + protected DynamicMessage build(ThreadContext context, int depth, int maxRecursionDepth) { + if (depth >= maxRecursionDepth) { throw context.runtime.newRuntimeError("Maximum recursion depth exceeded during encoding."); } @@ -651,7 +676,7 @@ protected DynamicMessage build(ThreadContext context, int depth) { if (value instanceof RubyMap) { builder.clearField(fieldDescriptor); RubyDescriptor mapDescriptor = (RubyDescriptor) getDescriptorForField(context, fieldDescriptor); - for (DynamicMessage kv : ((RubyMap) value).build(context, mapDescriptor, depth)) { + for (DynamicMessage kv : ((RubyMap) value).build(context, mapDescriptor, depth, maxRecursionDepth)) { builder.addRepeatedField(fieldDescriptor, kv); } @@ -660,7 +685,7 @@ protected DynamicMessage build(ThreadContext context, int depth) { builder.clearField(fieldDescriptor); for (int i = 0; i < repeatedField.size(); i++) { - Object item = convert(context, fieldDescriptor, repeatedField.get(i), depth, + Object item = convert(context, fieldDescriptor, repeatedField.get(i), depth, maxRecursionDepth, /*isDefaultValueForBytes*/ false); builder.addRepeatedField(fieldDescriptor, item); } @@ -682,7 +707,7 @@ protected DynamicMessage build(ThreadContext context, int depth) { fieldDescriptor.getFullName().equals("google.protobuf.FieldDescriptorProto.default_value")) { isDefaultStringForBytes = true; } - builder.setField(fieldDescriptor, convert(context, fieldDescriptor, value, depth, isDefaultStringForBytes)); + builder.setField(fieldDescriptor, convert(context, fieldDescriptor, value, depth, maxRecursionDepth, isDefaultStringForBytes)); } } @@ -702,7 +727,7 @@ protected DynamicMessage build(ThreadContext context, int depth) { builder.clearField(fieldDescriptor); RubyDescriptor mapDescriptor = (RubyDescriptor) getDescriptorForField(context, fieldDescriptor); - for (DynamicMessage kv : ((RubyMap) value).build(context, mapDescriptor, depth)) { + 
for (DynamicMessage kv : ((RubyMap) value).build(context, mapDescriptor, depth, maxRecursionDepth)) { builder.addRepeatedField(fieldDescriptor, kv); } } @@ -814,7 +839,8 @@ private FieldDescriptor findField(ThreadContext context, IRubyObject fieldName, // convert a ruby object to protobuf type, skip type check since it is checked on the way in private Object convert(ThreadContext context, FieldDescriptor fieldDescriptor, - IRubyObject value, int depth, boolean isDefaultStringForBytes) { + IRubyObject value, int depth, int maxRecursionDepth, + boolean isDefaultStringForBytes) { Object val = null; switch (fieldDescriptor.getType()) { case INT32: @@ -855,7 +881,7 @@ private Object convert(ThreadContext context, } break; case MESSAGE: - val = ((RubyMessage) value).build(context, depth + 1); + val = ((RubyMessage) value).build(context, depth + 1, maxRecursionDepth); break; case ENUM: EnumDescriptor enumDescriptor = fieldDescriptor.getEnumType(); @@ -1214,7 +1240,7 @@ private void validateMessageType(ThreadContext context, FieldDescriptor fieldDes private static final String CONST_SUFFIX = "_const"; private static final String HAS_PREFIX = "has_"; private static final String QUESTION_MARK = "?"; - private static final int SINK_MAXIMUM_NESTING = 63; + private static final int SINK_MAXIMUM_NESTING = 64; private Descriptor descriptor; private DynamicMessage.Builder builder; diff --git a/ruby/tests/encode_decode_test.rb b/ruby/tests/encode_decode_test.rb index 429ac4332216..9513cc37d9dc 100755 --- a/ruby/tests/encode_decode_test.rb +++ b/ruby/tests/encode_decode_test.rb @@ -101,4 +101,55 @@ def test_json_name assert_match json, "{\"CustomJsonName\":42}" end + def test_decode_depth_limit + msg = A::B::C::TestMessage.new( + optional_msg: A::B::C::TestMessage.new( + optional_msg: A::B::C::TestMessage.new( + optional_msg: A::B::C::TestMessage.new( + optional_msg: A::B::C::TestMessage.new( + optional_msg: A::B::C::TestMessage.new( + ) + ) + ) + ) + ) + ) + msg_encoded = 
A::B::C::TestMessage.encode(msg) + msg_out = A::B::C::TestMessage.decode(msg_encoded) + assert_match msg.to_json, msg_out.to_json + + assert_raise Google::Protobuf::ParseError do + A::B::C::TestMessage.decode(msg_encoded, { max_recursion_depth: 4 }) + end + + msg_out = A::B::C::TestMessage.decode(msg_encoded, { max_recursion_depth: 5 }) + assert_match msg.to_json, msg_out.to_json + end + + def test_encode_depth_limit + msg = A::B::C::TestMessage.new( + optional_msg: A::B::C::TestMessage.new( + optional_msg: A::B::C::TestMessage.new( + optional_msg: A::B::C::TestMessage.new( + optional_msg: A::B::C::TestMessage.new( + optional_msg: A::B::C::TestMessage.new( + ) + ) + ) + ) + ) + ) + msg_encoded = A::B::C::TestMessage.encode(msg) + msg_out = A::B::C::TestMessage.decode(msg_encoded) + assert_match msg.to_json, msg_out.to_json + + assert_raise RuntimeError do + A::B::C::TestMessage.encode(msg, { max_recursion_depth: 5 }) + end + + msg_encoded = A::B::C::TestMessage.encode(msg, { max_recursion_depth: 6 }) + msg_out = A::B::C::TestMessage.decode(msg_encoded) + assert_match msg.to_json, msg_out.to_json + end + end