Skip to content

Commit

Permalink
internal/cmd/generate-types: manual CSE of m.messageInfo()
Browse files Browse the repository at this point in the history
messageInfo() looks like this:

    func (ms *messageState) messageInfo() *MessageInfo {
    	mi := ms.LoadMessageInfo()
    	if mi == nil {
    		panic("invalid nil message info; this suggests memory corruption due to a race or shallow copy on the message struct")
    	}
    	return mi
    }

    func (ms *messageState) LoadMessageInfo() *MessageInfo {
    	return (*MessageInfo)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&ms.atomicMessageInfo))))
    }

Which is an atomic load and a predictable branch. On x86, this 64-bit
load is just a MOV. On other platforms, like ARM64, there's actual
atomics involved (LDAR).

Meaning, it's cheap, but not free. Eliminate redundant copies of this
(Common Subexpression Elimination).

The newly added benchmarks improve by (geomean) 2.5%:

    $ benchstat pre post | head -10
    goarch: amd64
    cpu: AMD Ryzen Threadripper PRO 3995WX 64-Cores
                          │     pre     │                post                │
                          │   sec/op    │   sec/op     vs base               │
    Extension/Has/None-12   106.4n ± 2%   104.0n ± 2%  -2.21% (p=0.020 n=10)
    Extension/Has/Set-12    116.4n ± 1%   114.4n ± 2%  -1.76% (p=0.017 n=10)
    Extension/Get/None-12   184.2n ± 1%   181.0n ± 1%  -1.68% (p=0.003 n=10)
    Extension/Get/Set-12    144.5n ± 3%   140.7n ± 2%  -2.63% (p=0.041 n=10)
    Extension/Set-12        227.2n ± 2%   218.6n ± 2%  -3.81% (p=0.000 n=10)
    geomean                 149.6n        145.9n       -2.42%

I didn't test on ARM64, but the difference should be larger due to the
reduced atomics.

Change-Id: I8eebeb6f753425b743368a7f5c7be4d48537e5c3
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/575036
Reviewed-by: Michael Stapelberg <stapelberg@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Damien Neil <dneil@google.com>
Commit-Queue: Nicolas Hillegeer <aktau@google.com>
Auto-Submit: Nicolas Hillegeer <aktau@google.com>
  • Loading branch information
aktau authored and gopherbot committed Apr 2, 2024
1 parent 55891d7 commit 8a74430
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 87 deletions.
38 changes: 38 additions & 0 deletions internal/benchmarks/micro/micro_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,41 @@ func BenchmarkRequired(b *testing.B) {
})
})
}

func BenchmarkExtension(b *testing.B) {
b.Run("Has/None", func(b *testing.B) {
m := &testpb.TestAllExtensions{}
for i := 0; i < b.N; i++ {
proto.HasExtension(m, testpb.E_OptionalNestedMessage)
}
})
b.Run("Has/Set", func(b *testing.B) {
m := &testpb.TestAllExtensions{}
ext := &testpb.TestAllExtensions_NestedMessage{A: proto.Int32(-32)}
proto.SetExtension(m, testpb.E_OptionalNestedMessage, ext)
for i := 0; i < b.N; i++ {
proto.HasExtension(m, testpb.E_OptionalNestedMessage)
}
})
b.Run("Get/None", func(b *testing.B) {
m := &testpb.TestAllExtensions{}
for i := 0; i < b.N; i++ {
proto.GetExtension(m, testpb.E_OptionalNestedMessage)
}
})
b.Run("Get/Set", func(b *testing.B) {
m := &testpb.TestAllExtensions{}
ext := &testpb.TestAllExtensions_NestedMessage{A: proto.Int32(-32)}
proto.SetExtension(m, testpb.E_OptionalNestedMessage, ext)
for i := 0; i < b.N; i++ {
proto.GetExtension(m, testpb.E_OptionalNestedMessage)
}
})
b.Run("Set", func(b *testing.B) {
m := &testpb.TestAllExtensions{}
ext := &testpb.TestAllExtensions_NestedMessage{A: proto.Int32(-32)}
for i := 0; i < b.N; i++ {
proto.SetExtension(m, testpb.E_OptionalNestedMessage, ext)
}
})
}
69 changes: 40 additions & 29 deletions internal/cmd/generate-types/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -744,8 +744,9 @@ func (m *{{.}}) protoUnwrap() interface{} {
return m.pointer().AsIfaceOf(m.messageInfo().GoReflectType.Elem())
}
func (m *{{.}}) ProtoMethods() *protoiface.Methods {
m.messageInfo().init()
return &m.messageInfo().methods
mi := m.messageInfo()
mi.init()
return &mi.methods
}
// ProtoMessageInfo is a pseudo-internal API for allowing the v1 code
Expand All @@ -758,8 +759,9 @@ func (m *{{.}}) ProtoMessageInfo() *MessageInfo {
}
func (m *{{.}}) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
m.messageInfo().init()
for _, ri := range m.messageInfo().rangeInfos {
mi := m.messageInfo()
mi.init()
for _, ri := range mi.rangeInfos {
switch ri := ri.(type) {
case *fieldInfo:
if ri.has(m.pointer()) {
Expand All @@ -769,77 +771,86 @@ func (m *{{.}}) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) b
}
case *oneofInfo:
if n := ri.which(m.pointer()); n > 0 {
fi := m.messageInfo().fields[n]
fi := mi.fields[n]
if !f(fi.fieldDesc, fi.get(m.pointer())) {
return
}
}
}
}
m.messageInfo().extensionMap(m.pointer()).Range(f)
mi.extensionMap(m.pointer()).Range(f)
}
func (m *{{.}}) Has(fd protoreflect.FieldDescriptor) bool {
m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
mi := m.messageInfo()
mi.init()
if fi, xt := mi.checkField(fd); fi != nil {
return fi.has(m.pointer())
} else {
return m.messageInfo().extensionMap(m.pointer()).Has(xt)
return mi.extensionMap(m.pointer()).Has(xt)
}
}
func (m *{{.}}) Clear(fd protoreflect.FieldDescriptor) {
m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
mi := m.messageInfo()
mi.init()
if fi, xt := mi.checkField(fd); fi != nil {
fi.clear(m.pointer())
} else {
m.messageInfo().extensionMap(m.pointer()).Clear(xt)
mi.extensionMap(m.pointer()).Clear(xt)
}
}
func (m *{{.}}) Get(fd protoreflect.FieldDescriptor) protoreflect.Value {
m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
mi := m.messageInfo()
mi.init()
if fi, xt := mi.checkField(fd); fi != nil {
return fi.get(m.pointer())
} else {
return m.messageInfo().extensionMap(m.pointer()).Get(xt)
return mi.extensionMap(m.pointer()).Get(xt)
}
}
func (m *{{.}}) Set(fd protoreflect.FieldDescriptor, v protoreflect.Value) {
m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
mi := m.messageInfo()
mi.init()
if fi, xt := mi.checkField(fd); fi != nil {
fi.set(m.pointer(), v)
} else {
m.messageInfo().extensionMap(m.pointer()).Set(xt, v)
mi.extensionMap(m.pointer()).Set(xt, v)
}
}
func (m *{{.}}) Mutable(fd protoreflect.FieldDescriptor) protoreflect.Value {
m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
mi := m.messageInfo()
mi.init()
if fi, xt := mi.checkField(fd); fi != nil {
return fi.mutable(m.pointer())
} else {
return m.messageInfo().extensionMap(m.pointer()).Mutable(xt)
return mi.extensionMap(m.pointer()).Mutable(xt)
}
}
func (m *{{.}}) NewField(fd protoreflect.FieldDescriptor) protoreflect.Value {
m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
mi := m.messageInfo()
mi.init()
if fi, xt := mi.checkField(fd); fi != nil {
return fi.newField()
} else {
return xt.New()
}
}
func (m *{{.}}) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor {
m.messageInfo().init()
if oi := m.messageInfo().oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
mi := m.messageInfo()
mi.init()
if oi := mi.oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
return od.Fields().ByNumber(oi.which(m.pointer()))
}
panic("invalid oneof descriptor " + string(od.FullName()) + " for message " + string(m.Descriptor().FullName()))
}
func (m *{{.}}) GetUnknown() protoreflect.RawFields {
m.messageInfo().init()
return m.messageInfo().getUnknown(m.pointer())
mi := m.messageInfo()
mi.init()
return mi.getUnknown(m.pointer())
}
func (m *{{.}}) SetUnknown(b protoreflect.RawFields) {
m.messageInfo().init()
m.messageInfo().setUnknown(m.pointer(), b)
mi := m.messageInfo()
mi.init()
mi.setUnknown(m.pointer(), b)
}
func (m *{{.}}) IsValid() bool {
return !m.pointer().IsNil()
Expand Down

0 comments on commit 8a74430

Please sign in to comment.