Skip to content

Commit

Permalink
[Ruby] Message.decode/encode: Add max_recursion_depth option (#9218)
Browse files Browse the repository at this point in the history
* Message.decode/encode: Add max_recursion_depth option

This allows increasing the recursing depth from the default of 64, by
setting the "max_recursion_depth" to the desired integer value. This is
useful to encode or decode complex nested protobuf messages that otherwise
error out with a RuntimeError or "Error occurred during parsing".

Fixes #1493

* Address review comments

Co-authored-by: Adam Cozzette <acozzette@google.com>
  • Loading branch information
lfittl and acozzette committed Feb 9, 2022
1 parent 4ed3941 commit fbe6ab2
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 46 deletions.
65 changes: 53 additions & 12 deletions ruby/ext/google/protobuf_c/message.c
Expand Up @@ -953,13 +953,35 @@ static VALUE Message_index_set(VALUE _self, VALUE field_name, VALUE value) {

/*
* call-seq:
* MessageClass.decode(data) => message
* MessageClass.decode(data, options) => message
*
* Decodes the given data (as a string containing bytes in protocol buffers wire
* format) under the interpretration given by this message class's definition
* and returns a message object with the corresponding field values.
* @param options [Hash] options for the decoder
* max_recursion_depth: set to maximum decoding depth for message (default is 64)
*/
static VALUE Message_decode(VALUE klass, VALUE data) {
static VALUE Message_decode(int argc, VALUE* argv, VALUE klass) {
VALUE data = argv[0];
int options = 0;

if (argc < 1 || argc > 2) {
rb_raise(rb_eArgError, "Expected 1 or 2 arguments.");
}

if (argc == 2) {
VALUE hash_args = argv[1];
if (TYPE(hash_args) != T_HASH) {
rb_raise(rb_eArgError, "Expected hash arguments.");
}

VALUE depth = rb_hash_lookup(hash_args, ID2SYM(rb_intern("max_recursion_depth")));

if (depth != Qnil && TYPE(depth) == T_FIXNUM) {
options |= UPB_DECODE_MAXDEPTH(FIX2INT(depth));
}
}

if (TYPE(data) != T_STRING) {
rb_raise(rb_eArgError, "Expected string for binary protobuf data.");
}
Expand All @@ -969,7 +991,7 @@ static VALUE Message_decode(VALUE klass, VALUE data) {

upb_DecodeStatus status = upb_Decode(
RSTRING_PTR(data), RSTRING_LEN(data), (upb_Message*)msg->msg,
upb_MessageDef_MiniTable(msg->msgdef), NULL, 0, Arena_get(msg->arena));
upb_MessageDef_MiniTable(msg->msgdef), NULL, options, Arena_get(msg->arena));

if (status != kUpb_DecodeStatus_Ok) {
rb_raise(cParseError, "Error occurred during parsing");
Expand Down Expand Up @@ -1043,24 +1065,43 @@ static VALUE Message_decode_json(int argc, VALUE* argv, VALUE klass) {

/*
* call-seq:
* MessageClass.encode(msg) => bytes
* MessageClass.encode(msg, options) => bytes
*
* Encodes the given message object to its serialized form in protocol buffers
* wire format.
* @param options [Hash] options for the encoder
* max_recursion_depth: set to maximum encoding depth for message (default is 64)
*/
static VALUE Message_encode(VALUE klass, VALUE msg_rb) {
Message* msg = ruby_to_Message(msg_rb);
static VALUE Message_encode(int argc, VALUE* argv, VALUE klass) {
Message* msg = ruby_to_Message(argv[0]);
int options = 0;
const char* data;
size_t size;

if (CLASS_OF(msg_rb) != klass) {
if (CLASS_OF(argv[0]) != klass) {
rb_raise(rb_eArgError, "Message of wrong type.");
}

upb_Arena* arena = upb_Arena_New();
if (argc < 1 || argc > 2) {
rb_raise(rb_eArgError, "Expected 1 or 2 arguments.");
}

data = upb_Encode(msg->msg, upb_MessageDef_MiniTable(msg->msgdef), 0, arena,
&size);
if (argc == 2) {
VALUE hash_args = argv[1];
if (TYPE(hash_args) != T_HASH) {
rb_raise(rb_eArgError, "Expected hash arguments.");
}
VALUE depth = rb_hash_lookup(hash_args, ID2SYM(rb_intern("max_recursion_depth")));

if (depth != Qnil && TYPE(depth) == T_FIXNUM) {
options |= UPB_DECODE_MAXDEPTH(FIX2INT(depth));
}
}

upb_Arena *arena = upb_Arena_New();

data = upb_Encode(msg->msg, upb_MessageDef_MiniTable(msg->msgdef),
options, arena, &size);

if (data) {
VALUE ret = rb_str_new(data, size);
Expand Down Expand Up @@ -1186,8 +1227,8 @@ VALUE build_class_from_descriptor(VALUE descriptor) {
rb_define_method(klass, "to_s", Message_inspect, 0);
rb_define_method(klass, "[]", Message_index, 1);
rb_define_method(klass, "[]=", Message_index_set, 2);
rb_define_singleton_method(klass, "decode", Message_decode, 1);
rb_define_singleton_method(klass, "encode", Message_encode, 1);
rb_define_singleton_method(klass, "decode", Message_decode, -1);
rb_define_singleton_method(klass, "encode", Message_encode, -1);
rb_define_singleton_method(klass, "decode_json", Message_decode_json, -1);
rb_define_singleton_method(klass, "encode_json", Message_encode_json, -1);
rb_define_singleton_method(klass, "descriptor", Message_descriptor, 0);
Expand Down
8 changes: 4 additions & 4 deletions ruby/lib/google/protobuf.rb
Expand Up @@ -59,16 +59,16 @@ class TypeError < ::TypeError; end
module Google
module Protobuf

def self.encode(msg)
msg.to_proto
def self.encode(msg, options = {})
msg.to_proto(options)
end

def self.encode_json(msg, options = {})
msg.to_json(options)
end

def self.decode(klass, proto)
klass.decode(proto)
def self.decode(klass, proto, options = {})
klass.decode(proto, options)
end

def self.decode_json(klass, json, options = {})
Expand Down
4 changes: 2 additions & 2 deletions ruby/lib/google/protobuf/message_exts.rb
Expand Up @@ -44,8 +44,8 @@ def to_json(options = {})
self.class.encode_json(self, options)
end

def to_proto
self.class.encode(self)
def to_proto(options = {})
self.class.encode(self, options)
end

end
Expand Down
4 changes: 2 additions & 2 deletions ruby/src/main/java/com/google/protobuf/jruby/RubyMap.java
Expand Up @@ -389,7 +389,7 @@ protected IRubyObject deepCopy(ThreadContext context) {
return newMap;
}

protected List<DynamicMessage> build(ThreadContext context, RubyDescriptor descriptor, int depth) {
protected List<DynamicMessage> build(ThreadContext context, RubyDescriptor descriptor, int depth, int maxRecursionDepth) {
List<DynamicMessage> list = new ArrayList<DynamicMessage>();
RubyClass rubyClass = (RubyClass) descriptor.msgclass(context);
FieldDescriptor keyField = descriptor.getField("key");
Expand All @@ -398,7 +398,7 @@ protected List<DynamicMessage> build(ThreadContext context, RubyDescriptor descr
RubyMessage mapMessage = (RubyMessage) rubyClass.newInstance(context, Block.NULL_BLOCK);
mapMessage.setField(context, keyField, key);
mapMessage.setField(context, valueField, table.get(key));
list.add(mapMessage.build(context, depth + 1));
list.add(mapMessage.build(context, depth + 1, maxRecursionDepth));
}
return list;
}
Expand Down
78 changes: 52 additions & 26 deletions ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java
Expand Up @@ -39,6 +39,7 @@
import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.Descriptors.OneofDescriptor;
import com.google.protobuf.ByteString;
import com.google.protobuf.CodedInputStream;
import com.google.protobuf.DynamicMessage;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
Expand Down Expand Up @@ -461,35 +462,63 @@ public static IRubyObject getDescriptor(ThreadContext context, IRubyObject recv)

/*
* call-seq:
* MessageClass.encode(msg) => bytes
* MessageClass.encode(msg, options = {}) => bytes
*
* Encodes the given message object to its serialized form in protocol buffers
* wire format.
* @param options [Hash] options for the encoder
* max_recursion_depth: set to maximum encoding depth for message (default is 64)
*/
@JRubyMethod(meta = true)
public static IRubyObject encode(ThreadContext context, IRubyObject recv, IRubyObject value) {
if (recv != value.getMetaClass()) {
throw context.runtime.newArgumentError("Tried to encode a " + value.getMetaClass() + " message with " + recv);
@JRubyMethod(required = 1, optional = 1, meta = true)
public static IRubyObject encode(ThreadContext context, IRubyObject recv, IRubyObject[] args) {
if (recv != args[0].getMetaClass()) {
throw context.runtime.newArgumentError("Tried to encode a " + args[0].getMetaClass() + " message with " + recv);
}
RubyMessage message = (RubyMessage) value;
return context.runtime.newString(new ByteList(message.build(context).toByteArray()));
RubyMessage message = (RubyMessage) args[0];
int maxRecursionDepthInt = SINK_MAXIMUM_NESTING;

if (args.length > 1) {
RubyHash options = (RubyHash) args[1];
IRubyObject maxRecursionDepth = options.fastARef(context.runtime.newSymbol("max_recursion_depth"));

if (maxRecursionDepth != null) {
maxRecursionDepthInt = ((RubyNumeric) maxRecursionDepth).getIntValue();
}
}
return context.runtime.newString(new ByteList(message.build(context, 0, maxRecursionDepthInt).toByteArray()));
}

/*
* call-seq:
* MessageClass.decode(data) => message
* MessageClass.decode(data, options = {}) => message
*
* Decodes the given data (as a string containing bytes in protocol buffers wire
* format) under the interpretation given by this message class's definition
* and returns a message object with the corresponding field values.
* @param options [Hash] options for the decoder
* max_recursion_depth: set to maximum decoding depth for message (default is 100)
*/
@JRubyMethod(meta = true)
public static IRubyObject decode(ThreadContext context, IRubyObject recv, IRubyObject data) {
@JRubyMethod(required = 1, optional = 1, meta = true)
public static IRubyObject decode(ThreadContext context, IRubyObject recv, IRubyObject[] args) {
IRubyObject data = args[0];
byte[] bin = data.convertToString().getBytes();
CodedInputStream input = CodedInputStream.newInstance(bin);
RubyMessage ret = (RubyMessage) ((RubyClass) recv).newInstance(context, Block.NULL_BLOCK);

if (args.length == 2) {
if (!(args[1] instanceof RubyHash)) {
throw context.runtime.newArgumentError("Expected hash arguments.");
}

IRubyObject maxRecursionDepth = ((RubyHash) args[1]).fastARef(context.runtime.newSymbol("max_recursion_depth"));
if (maxRecursionDepth != null) {
input.setRecursionLimit(((RubyNumeric) maxRecursionDepth).getIntValue());
}
}

try {
ret.builder.mergeFrom(bin);
} catch (InvalidProtocolBufferException e) {
ret.builder.mergeFrom(input);
} catch (Exception e) {
throw RaiseException.from(context.runtime, (RubyClass) context.runtime.getClassFromPath("Google::Protobuf::ParseError"), e.getMessage());
}

Expand Down Expand Up @@ -541,7 +570,7 @@ public static IRubyObject encodeJson(ThreadContext context, IRubyObject recv, IR
printer = printer.usingTypeRegistry(JsonFormat.TypeRegistry.newBuilder().add(message.descriptor).build());

try {
result = printer.print(message.build(context));
result = printer.print(message.build(context, 0, SINK_MAXIMUM_NESTING));
} catch (InvalidProtocolBufferException e) {
throw runtime.newRuntimeError(e.getMessage());
} catch (IllegalArgumentException e) {
Expand Down Expand Up @@ -635,12 +664,8 @@ public IRubyObject toHash(ThreadContext context) {
return ret;
}

protected DynamicMessage build(ThreadContext context) {
return build(context, 0);
}

protected DynamicMessage build(ThreadContext context, int depth) {
if (depth > SINK_MAXIMUM_NESTING) {
protected DynamicMessage build(ThreadContext context, int depth, int maxRecursionDepth) {
if (depth >= maxRecursionDepth) {
throw context.runtime.newRuntimeError("Maximum recursion depth exceeded during encoding.");
}

Expand All @@ -651,7 +676,7 @@ protected DynamicMessage build(ThreadContext context, int depth) {
if (value instanceof RubyMap) {
builder.clearField(fieldDescriptor);
RubyDescriptor mapDescriptor = (RubyDescriptor) getDescriptorForField(context, fieldDescriptor);
for (DynamicMessage kv : ((RubyMap) value).build(context, mapDescriptor, depth)) {
for (DynamicMessage kv : ((RubyMap) value).build(context, mapDescriptor, depth, maxRecursionDepth)) {
builder.addRepeatedField(fieldDescriptor, kv);
}

Expand All @@ -660,7 +685,7 @@ protected DynamicMessage build(ThreadContext context, int depth) {

builder.clearField(fieldDescriptor);
for (int i = 0; i < repeatedField.size(); i++) {
Object item = convert(context, fieldDescriptor, repeatedField.get(i), depth,
Object item = convert(context, fieldDescriptor, repeatedField.get(i), depth, maxRecursionDepth,
/*isDefaultValueForBytes*/ false);
builder.addRepeatedField(fieldDescriptor, item);
}
Expand All @@ -682,7 +707,7 @@ protected DynamicMessage build(ThreadContext context, int depth) {
fieldDescriptor.getFullName().equals("google.protobuf.FieldDescriptorProto.default_value")) {
isDefaultStringForBytes = true;
}
builder.setField(fieldDescriptor, convert(context, fieldDescriptor, value, depth, isDefaultStringForBytes));
builder.setField(fieldDescriptor, convert(context, fieldDescriptor, value, depth, maxRecursionDepth, isDefaultStringForBytes));
}
}

Expand All @@ -702,7 +727,7 @@ protected DynamicMessage build(ThreadContext context, int depth) {
builder.clearField(fieldDescriptor);
RubyDescriptor mapDescriptor = (RubyDescriptor) getDescriptorForField(context,
fieldDescriptor);
for (DynamicMessage kv : ((RubyMap) value).build(context, mapDescriptor, depth)) {
for (DynamicMessage kv : ((RubyMap) value).build(context, mapDescriptor, depth, maxRecursionDepth)) {
builder.addRepeatedField(fieldDescriptor, kv);
}
}
Expand Down Expand Up @@ -814,7 +839,8 @@ private FieldDescriptor findField(ThreadContext context, IRubyObject fieldName,
// convert a ruby object to protobuf type, skip type check since it is checked on the way in
private Object convert(ThreadContext context,
FieldDescriptor fieldDescriptor,
IRubyObject value, int depth, boolean isDefaultStringForBytes) {
IRubyObject value, int depth, int maxRecursionDepth,
boolean isDefaultStringForBytes) {
Object val = null;
switch (fieldDescriptor.getType()) {
case INT32:
Expand Down Expand Up @@ -855,7 +881,7 @@ private Object convert(ThreadContext context,
}
break;
case MESSAGE:
val = ((RubyMessage) value).build(context, depth + 1);
val = ((RubyMessage) value).build(context, depth + 1, maxRecursionDepth);
break;
case ENUM:
EnumDescriptor enumDescriptor = fieldDescriptor.getEnumType();
Expand Down Expand Up @@ -1214,7 +1240,7 @@ private void validateMessageType(ThreadContext context, FieldDescriptor fieldDes
private static final String CONST_SUFFIX = "_const";
private static final String HAS_PREFIX = "has_";
private static final String QUESTION_MARK = "?";
private static final int SINK_MAXIMUM_NESTING = 63;
private static final int SINK_MAXIMUM_NESTING = 64;

private Descriptor descriptor;
private DynamicMessage.Builder builder;
Expand Down
51 changes: 51 additions & 0 deletions ruby/tests/encode_decode_test.rb
Expand Up @@ -101,4 +101,55 @@ def test_json_name
assert_match json, "{\"CustomJsonName\":42}"
end

def test_decode_depth_limit
msg = A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
)
)
)
)
)
)
msg_encoded = A::B::C::TestMessage.encode(msg)
msg_out = A::B::C::TestMessage.decode(msg_encoded)
assert_match msg.to_json, msg_out.to_json

assert_raise Google::Protobuf::ParseError do
A::B::C::TestMessage.decode(msg_encoded, { max_recursion_depth: 4 })
end

msg_out = A::B::C::TestMessage.decode(msg_encoded, { max_recursion_depth: 5 })
assert_match msg.to_json, msg_out.to_json
end

def test_encode_depth_limit
msg = A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
)
)
)
)
)
)
msg_encoded = A::B::C::TestMessage.encode(msg)
msg_out = A::B::C::TestMessage.decode(msg_encoded)
assert_match msg.to_json, msg_out.to_json

assert_raise RuntimeError do
A::B::C::TestMessage.encode(msg, { max_recursion_depth: 5 })
end

msg_encoded = A::B::C::TestMessage.encode(msg, { max_recursion_depth: 6 })
msg_out = A::B::C::TestMessage.decode(msg_encoded)
assert_match msg.to_json, msg_out.to_json
end

end

0 comments on commit fbe6ab2

Please sign in to comment.