Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Ruby] Message.decode/encode: Add max_recursion_depth option #9218

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
65 changes: 53 additions & 12 deletions ruby/ext/google/protobuf_c/message.c
Expand Up @@ -953,13 +953,35 @@ static VALUE Message_index_set(VALUE _self, VALUE field_name, VALUE value) {

/*
* call-seq:
* MessageClass.decode(data) => message
* MessageClass.decode(data, options) => message
*
* Decodes the given data (as a string containing bytes in protocol buffers wire
* format) under the interpretration given by this message class's definition
* and returns a message object with the corresponding field values.
* @param options [Hash] options for the decoder
* max_recursion_depth: set to maximum decoding depth for message (default is 64)
*/
static VALUE Message_decode(VALUE klass, VALUE data) {
static VALUE Message_decode(int argc, VALUE* argv, VALUE klass) {
VALUE data = argv[0];
int options = 0;

if (argc < 1 || argc > 2) {
rb_raise(rb_eArgError, "Expected 1 or 2 arguments.");
}

if (argc == 2) {
VALUE hash_args = argv[1];
if (TYPE(hash_args) != T_HASH) {
rb_raise(rb_eArgError, "Expected hash arguments.");
}

VALUE depth = rb_hash_lookup(hash_args, ID2SYM(rb_intern("max_recursion_depth")));

if (depth != Qnil && TYPE(depth) == T_FIXNUM) {
options |= UPB_DECODE_MAXDEPTH(FIX2INT(depth));
}
}

if (TYPE(data) != T_STRING) {
rb_raise(rb_eArgError, "Expected string for binary protobuf data.");
}
Expand All @@ -969,7 +991,7 @@ static VALUE Message_decode(VALUE klass, VALUE data) {

upb_DecodeStatus status = upb_Decode(
RSTRING_PTR(data), RSTRING_LEN(data), (upb_Message*)msg->msg,
upb_MessageDef_MiniTable(msg->msgdef), NULL, 0, Arena_get(msg->arena));
upb_MessageDef_MiniTable(msg->msgdef), NULL, options, Arena_get(msg->arena));

if (status != kUpb_DecodeStatus_Ok) {
rb_raise(cParseError, "Error occurred during parsing");
Expand Down Expand Up @@ -1043,24 +1065,43 @@ static VALUE Message_decode_json(int argc, VALUE* argv, VALUE klass) {

/*
* call-seq:
* MessageClass.encode(msg) => bytes
* MessageClass.encode(msg, options) => bytes
*
* Encodes the given message object to its serialized form in protocol buffers
* wire format.
* @param options [Hash] options for the encoder
* max_recursion_depth: set to maximum encoding depth for message (default is 64)
*/
static VALUE Message_encode(VALUE klass, VALUE msg_rb) {
Message* msg = ruby_to_Message(msg_rb);
static VALUE Message_encode(int argc, VALUE* argv, VALUE klass) {
Message* msg = ruby_to_Message(argv[0]);
int options = 0;
const char* data;
size_t size;

if (CLASS_OF(msg_rb) != klass) {
if (CLASS_OF(argv[0]) != klass) {
rb_raise(rb_eArgError, "Message of wrong type.");
}

upb_Arena* arena = upb_Arena_New();
if (argc < 1 || argc > 2) {
rb_raise(rb_eArgError, "Expected 1 or 2 arguments.");
}

data = upb_Encode(msg->msg, upb_MessageDef_MiniTable(msg->msgdef), 0, arena,
&size);
if (argc == 2) {
VALUE hash_args = argv[1];
if (TYPE(hash_args) != T_HASH) {
rb_raise(rb_eArgError, "Expected hash arguments.");
}
VALUE depth = rb_hash_lookup(hash_args, ID2SYM(rb_intern("max_recursion_depth")));

if (depth != Qnil && TYPE(depth) == T_FIXNUM) {
options |= UPB_DECODE_MAXDEPTH(FIX2INT(depth));
}
}

upb_Arena *arena = upb_Arena_New();

data = upb_Encode(msg->msg, upb_MessageDef_MiniTable(msg->msgdef),
options, arena, &size);

if (data) {
VALUE ret = rb_str_new(data, size);
Expand Down Expand Up @@ -1186,8 +1227,8 @@ VALUE build_class_from_descriptor(VALUE descriptor) {
rb_define_method(klass, "to_s", Message_inspect, 0);
rb_define_method(klass, "[]", Message_index, 1);
rb_define_method(klass, "[]=", Message_index_set, 2);
rb_define_singleton_method(klass, "decode", Message_decode, 1);
rb_define_singleton_method(klass, "encode", Message_encode, 1);
rb_define_singleton_method(klass, "decode", Message_decode, -1);
rb_define_singleton_method(klass, "encode", Message_encode, -1);
rb_define_singleton_method(klass, "decode_json", Message_decode_json, -1);
rb_define_singleton_method(klass, "encode_json", Message_encode_json, -1);
rb_define_singleton_method(klass, "descriptor", Message_descriptor, 0);
Expand Down
8 changes: 4 additions & 4 deletions ruby/lib/google/protobuf.rb
Expand Up @@ -59,16 +59,16 @@ class TypeError < ::TypeError; end
module Google
module Protobuf

def self.encode(msg)
msg.to_proto
def self.encode(msg, options = {})
msg.to_proto(options)
end

def self.encode_json(msg, options = {})
msg.to_json(options)
end

def self.decode(klass, proto)
klass.decode(proto)
def self.decode(klass, proto, options = {})
klass.decode(proto, options)
end

def self.decode_json(klass, json, options = {})
Expand Down
4 changes: 2 additions & 2 deletions ruby/lib/google/protobuf/message_exts.rb
Expand Up @@ -44,8 +44,8 @@ def to_json(options = {})
self.class.encode_json(self, options)
end

def to_proto
self.class.encode(self)
def to_proto(options = {})
self.class.encode(self, options)
end

end
Expand Down
4 changes: 2 additions & 2 deletions ruby/src/main/java/com/google/protobuf/jruby/RubyMap.java
Expand Up @@ -389,7 +389,7 @@ protected IRubyObject deepCopy(ThreadContext context) {
return newMap;
}

protected List<DynamicMessage> build(ThreadContext context, RubyDescriptor descriptor, int depth) {
protected List<DynamicMessage> build(ThreadContext context, RubyDescriptor descriptor, int depth, int maxRecursionDepth) {
List<DynamicMessage> list = new ArrayList<DynamicMessage>();
RubyClass rubyClass = (RubyClass) descriptor.msgclass(context);
FieldDescriptor keyField = descriptor.getField("key");
Expand All @@ -398,7 +398,7 @@ protected List<DynamicMessage> build(ThreadContext context, RubyDescriptor descr
RubyMessage mapMessage = (RubyMessage) rubyClass.newInstance(context, Block.NULL_BLOCK);
mapMessage.setField(context, keyField, key);
mapMessage.setField(context, valueField, table.get(key));
list.add(mapMessage.build(context, depth + 1));
list.add(mapMessage.build(context, depth + 1, maxRecursionDepth));
}
return list;
}
Expand Down
78 changes: 52 additions & 26 deletions ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java
Expand Up @@ -39,6 +39,7 @@
import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.Descriptors.OneofDescriptor;
import com.google.protobuf.ByteString;
import com.google.protobuf.CodedInputStream;
import com.google.protobuf.DynamicMessage;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
Expand Down Expand Up @@ -461,35 +462,63 @@ public static IRubyObject getDescriptor(ThreadContext context, IRubyObject recv)

/*
* call-seq:
* MessageClass.encode(msg) => bytes
* MessageClass.encode(msg, options = {}) => bytes
*
* Encodes the given message object to its serialized form in protocol buffers
* wire format.
* @param options [Hash] options for the encoder
* max_recursion_depth: set to maximum encoding depth for message (default is 64)
*/
@JRubyMethod(meta = true)
public static IRubyObject encode(ThreadContext context, IRubyObject recv, IRubyObject value) {
if (recv != value.getMetaClass()) {
throw context.runtime.newArgumentError("Tried to encode a " + value.getMetaClass() + " message with " + recv);
@JRubyMethod(required = 1, optional = 1, meta = true)
public static IRubyObject encode(ThreadContext context, IRubyObject recv, IRubyObject[] args) {
if (recv != args[0].getMetaClass()) {
throw context.runtime.newArgumentError("Tried to encode a " + args[0].getMetaClass() + " message with " + recv);
}
RubyMessage message = (RubyMessage) value;
return context.runtime.newString(new ByteList(message.build(context).toByteArray()));
RubyMessage message = (RubyMessage) args[0];
int maxRecursionDepthInt = SINK_MAXIMUM_NESTING;

if (args.length > 1) {
RubyHash options = (RubyHash) args[1];
IRubyObject maxRecursionDepth = options.fastARef(context.runtime.newSymbol("max_recursion_depth"));

if (maxRecursionDepth != null) {
maxRecursionDepthInt = ((RubyNumeric) maxRecursionDepth).getIntValue();
}
}
return context.runtime.newString(new ByteList(message.build(context, 0, maxRecursionDepthInt).toByteArray()));
}

/*
* call-seq:
* MessageClass.decode(data) => message
* MessageClass.decode(data, options = {}) => message
*
* Decodes the given data (as a string containing bytes in protocol buffers wire
* format) under the interpretation given by this message class's definition
* and returns a message object with the corresponding field values.
* @param options [Hash] options for the decoder
* max_recursion_depth: set to maximum decoding depth for message (default is 100)
*/
@JRubyMethod(meta = true)
public static IRubyObject decode(ThreadContext context, IRubyObject recv, IRubyObject data) {
@JRubyMethod(required = 1, optional = 1, meta = true)
public static IRubyObject decode(ThreadContext context, IRubyObject recv, IRubyObject[] args) {
IRubyObject data = args[0];
byte[] bin = data.convertToString().getBytes();
CodedInputStream input = CodedInputStream.newInstance(bin);
RubyMessage ret = (RubyMessage) ((RubyClass) recv).newInstance(context, Block.NULL_BLOCK);

if (args.length == 2) {
if (!(args[1] instanceof RubyHash)) {
throw context.runtime.newArgumentError("Expected hash arguments.");
}

IRubyObject maxRecursionDepth = ((RubyHash) args[1]).fastARef(context.runtime.newSymbol("max_recursion_depth"));
if (maxRecursionDepth != null) {
input.setRecursionLimit(((RubyNumeric) maxRecursionDepth).getIntValue());
}
}

try {
ret.builder.mergeFrom(bin);
} catch (InvalidProtocolBufferException e) {
ret.builder.mergeFrom(input);
} catch (Exception e) {
throw RaiseException.from(context.runtime, (RubyClass) context.runtime.getClassFromPath("Google::Protobuf::ParseError"), e.getMessage());
}

Expand Down Expand Up @@ -541,7 +570,7 @@ public static IRubyObject encodeJson(ThreadContext context, IRubyObject recv, IR
printer = printer.usingTypeRegistry(JsonFormat.TypeRegistry.newBuilder().add(message.descriptor).build());

try {
result = printer.print(message.build(context));
result = printer.print(message.build(context, 0, SINK_MAXIMUM_NESTING));
} catch (InvalidProtocolBufferException e) {
throw runtime.newRuntimeError(e.getMessage());
} catch (IllegalArgumentException e) {
Expand Down Expand Up @@ -635,12 +664,8 @@ public IRubyObject toHash(ThreadContext context) {
return ret;
}

protected DynamicMessage build(ThreadContext context) {
return build(context, 0);
}

protected DynamicMessage build(ThreadContext context, int depth) {
if (depth > SINK_MAXIMUM_NESTING) {
protected DynamicMessage build(ThreadContext context, int depth, int maxRecursionDepth) {
if (depth >= maxRecursionDepth) {
Copy link
Contributor Author

@lfittl lfittl Nov 26, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note this is changed to match the UPB library behavior (> => >=), and SINK_MAXIMUM_NESTING is incremented by one accordingly.

throw context.runtime.newRuntimeError("Maximum recursion depth exceeded during encoding.");
}

Expand All @@ -651,7 +676,7 @@ protected DynamicMessage build(ThreadContext context, int depth) {
if (value instanceof RubyMap) {
builder.clearField(fieldDescriptor);
RubyDescriptor mapDescriptor = (RubyDescriptor) getDescriptorForField(context, fieldDescriptor);
for (DynamicMessage kv : ((RubyMap) value).build(context, mapDescriptor, depth)) {
for (DynamicMessage kv : ((RubyMap) value).build(context, mapDescriptor, depth, maxRecursionDepth)) {
builder.addRepeatedField(fieldDescriptor, kv);
}

Expand All @@ -660,7 +685,7 @@ protected DynamicMessage build(ThreadContext context, int depth) {

builder.clearField(fieldDescriptor);
for (int i = 0; i < repeatedField.size(); i++) {
Object item = convert(context, fieldDescriptor, repeatedField.get(i), depth,
Object item = convert(context, fieldDescriptor, repeatedField.get(i), depth, maxRecursionDepth,
/*isDefaultValueForBytes*/ false);
builder.addRepeatedField(fieldDescriptor, item);
}
Expand All @@ -682,7 +707,7 @@ protected DynamicMessage build(ThreadContext context, int depth) {
fieldDescriptor.getFullName().equals("google.protobuf.FieldDescriptorProto.default_value")) {
isDefaultStringForBytes = true;
}
builder.setField(fieldDescriptor, convert(context, fieldDescriptor, value, depth, isDefaultStringForBytes));
builder.setField(fieldDescriptor, convert(context, fieldDescriptor, value, depth, maxRecursionDepth, isDefaultStringForBytes));
}
}

Expand All @@ -702,7 +727,7 @@ protected DynamicMessage build(ThreadContext context, int depth) {
builder.clearField(fieldDescriptor);
RubyDescriptor mapDescriptor = (RubyDescriptor) getDescriptorForField(context,
fieldDescriptor);
for (DynamicMessage kv : ((RubyMap) value).build(context, mapDescriptor, depth)) {
for (DynamicMessage kv : ((RubyMap) value).build(context, mapDescriptor, depth, maxRecursionDepth)) {
builder.addRepeatedField(fieldDescriptor, kv);
}
}
Expand Down Expand Up @@ -814,7 +839,8 @@ private FieldDescriptor findField(ThreadContext context, IRubyObject fieldName,
// convert a ruby object to protobuf type, skip type check since it is checked on the way in
private Object convert(ThreadContext context,
FieldDescriptor fieldDescriptor,
IRubyObject value, int depth, boolean isDefaultStringForBytes) {
IRubyObject value, int depth, int maxRecursionDepth,
boolean isDefaultStringForBytes) {
Object val = null;
switch (fieldDescriptor.getType()) {
case INT32:
Expand Down Expand Up @@ -855,7 +881,7 @@ private Object convert(ThreadContext context,
}
break;
case MESSAGE:
val = ((RubyMessage) value).build(context, depth + 1);
val = ((RubyMessage) value).build(context, depth + 1, maxRecursionDepth);
break;
case ENUM:
EnumDescriptor enumDescriptor = fieldDescriptor.getEnumType();
Expand Down Expand Up @@ -1214,7 +1240,7 @@ private void validateMessageType(ThreadContext context, FieldDescriptor fieldDes
private static final String CONST_SUFFIX = "_const";
private static final String HAS_PREFIX = "has_";
private static final String QUESTION_MARK = "?";
private static final int SINK_MAXIMUM_NESTING = 63;
private static final int SINK_MAXIMUM_NESTING = 64;

private Descriptor descriptor;
private DynamicMessage.Builder builder;
Expand Down
51 changes: 51 additions & 0 deletions ruby/tests/encode_decode_test.rb
Expand Up @@ -101,4 +101,55 @@ def test_json_name
assert_match json, "{\"CustomJsonName\":42}"
end

def test_decode_depth_limit
msg = A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
)
)
)
)
)
)
msg_encoded = A::B::C::TestMessage.encode(msg)
msg_out = A::B::C::TestMessage.decode(msg_encoded)
assert_match msg.to_json, msg_out.to_json

assert_raise Google::Protobuf::ParseError do
A::B::C::TestMessage.decode(msg_encoded, { max_recursion_depth: 4 })
end

msg_out = A::B::C::TestMessage.decode(msg_encoded, { max_recursion_depth: 5 })
assert_match msg.to_json, msg_out.to_json
end

def test_encode_depth_limit
msg = A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
)
)
)
)
)
)
msg_encoded = A::B::C::TestMessage.encode(msg)
msg_out = A::B::C::TestMessage.decode(msg_encoded)
assert_match msg.to_json, msg_out.to_json

assert_raise RuntimeError do
A::B::C::TestMessage.encode(msg, { max_recursion_depth: 5 })
end

msg_encoded = A::B::C::TestMessage.encode(msg, { max_recursion_depth: 6 })
msg_out = A::B::C::TestMessage.decode(msg_encoded)
assert_match msg.to_json, msg_out.to_json
end

end