Skip to content

Commit

Permalink
Merge pull request #741 from libp2p/feat/refactor-get-values
Browse files Browse the repository at this point in the history
Extract validation from ProtocolMessenger
  • Loading branch information
aschmahmann committed Aug 17, 2021
2 parents f509e77 + 394a152 commit 7a8aeb6
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 105 deletions.
2 changes: 1 addition & 1 deletion dht.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ func New(ctx context.Context, h host.Host, options ...Option) (*IpfsDHT, error)

dht.Validator = cfg.Validator
dht.msgSender = net.NewMessageSenderImpl(h, dht.protocols)
dht.protoMessenger, err = pb.NewProtocolMessenger(dht.msgSender, pb.WithValidator(dht.Validator))
dht.protoMessenger, err = pb.NewProtocolMessenger(dht.msgSender)
if err != nil {
return nil, err
}
Expand Down
77 changes: 27 additions & 50 deletions fullrt/dht.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func NewFullRT(h host.Host, protocolPrefix protocol.ID, options ...Option) (*Ful
}

ms := net.NewMessageSenderImpl(h, []protocol.ID{dhtcfg.ProtocolPrefix + "/kad/1.0.0"})
protoMessenger, err := dht_pb.NewProtocolMessenger(ms, dht_pb.WithValidator(dhtcfg.Validator))
protoMessenger, err := dht_pb.NewProtocolMessenger(ms)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -596,27 +596,6 @@ func (dht *FullRT) searchValueQuorum(ctx context.Context, key string, valCh <-ch
})
}

// GetValues gets nvals values corresponding to the given key.
func (dht *FullRT) GetValues(ctx context.Context, key string, nvals int) (_ []RecvdVal, err error) {
if !dht.enableValues {
return nil, routing.ErrNotSupported
}

queryCtx, cancel := context.WithCancel(ctx)
defer cancel()
valCh, _ := dht.getValues(queryCtx, key, nil)

out := make([]RecvdVal, 0, nvals)
for val := range valCh {
out = append(out, val)
if len(out) == nvals {
cancel()
}
}

return out, ctx.Err()
}

func (dht *FullRT) processValues(ctx context.Context, key string, vals <-chan RecvdVal,
newVal func(ctx context.Context, v RecvdVal, better bool) bool) (best []byte, peersWithBest map[peer.ID]struct{}, aborted bool) {
loop:
Expand Down Expand Up @@ -720,44 +699,42 @@ func (dht *FullRT) getValues(ctx context.Context, key string, stopQuery chan str
})

rec, peers, err := dht.protoMessenger.GetValue(ctx, p, key)
switch err {
case routing.ErrNotFound:
// in this case, they responded with nothing,
// still send a notification so listeners can know the
// request has completed 'successfully'
routing.PublishQueryEvent(ctx, &routing.QueryEvent{
Type: routing.PeerResponse,
ID: p,
})
return nil
case nil, internal.ErrInvalidRecord:
// in either of these cases, we want to keep going
default:
if err != nil {
return err
}

// TODO: What should happen if the record is invalid?
// Pre-existing code counted it towards the quorum, but should it?
if rec != nil && rec.GetValue() != nil {
rv := RecvdVal{
Val: rec.GetValue(),
From: p,
}

select {
case valCh <- rv:
case <-ctx.Done():
return ctx.Err()
}
}

// For DHT query command
routing.PublishQueryEvent(ctx, &routing.QueryEvent{
Type: routing.PeerResponse,
ID: p,
Responses: peers,
})

if rec == nil {
return nil
}

val := rec.GetValue()
if val == nil {
logger.Debug("received a nil record value")
return nil
}
if err := dht.Validator.Validate(key, val); err != nil {
// make sure record is valid
logger.Debugw("received invalid record (discarded)", "error", err)
return nil
}

// the record is present and valid, send it out for processing
select {
case valCh <- RecvdVal{
Val: val,
From: p,
}:
case <-ctx.Done():
return ctx.Err()
}

return nil
}

Expand Down
2 changes: 1 addition & 1 deletion internal/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ package internal

import "errors"

var ErrInvalidRecord = errors.New("received invalid record")
var ErrIncorrectRecord = errors.New("received incorrect record")
32 changes: 8 additions & 24 deletions pb/protocol_messenger.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,9 @@ import (
"errors"
"fmt"

logging "github.com/ipfs/go-log"
"github.com/libp2p/go-libp2p-core/host"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/routing"

logging "github.com/ipfs/go-log"
record "github.com/libp2p/go-libp2p-record"
recpb "github.com/libp2p/go-libp2p-record/pb"
"github.com/multiformats/go-multihash"

Expand All @@ -27,19 +24,11 @@ var logger = logging.Logger("dht")
// Note: the ProtocolMessenger's MessageSender still needs to deal with some wire protocol details such as using
// varint-delineated protobufs
type ProtocolMessenger struct {
m MessageSender
validator record.Validator
m MessageSender
}

type ProtocolMessengerOption func(*ProtocolMessenger) error

func WithValidator(validator record.Validator) ProtocolMessengerOption {
return func(messenger *ProtocolMessenger) error {
messenger.validator = validator
return nil
}
}

// NewProtocolMessenger creates a new ProtocolMessenger that is used for sending DHT messages to peers and processing
// their responses.
func NewProtocolMessenger(msgSender MessageSender, opts ...ProtocolMessengerOption) (*ProtocolMessenger, error) {
Expand Down Expand Up @@ -99,21 +88,16 @@ func (pm *ProtocolMessenger) GetValue(ctx context.Context, p peer.ID, key string
// Success! We were given the value
logger.Debug("got value")

// make sure record is valid.
err = pm.validator.Validate(string(rec.GetKey()), rec.GetValue())
if err != nil {
logger.Debug("received invalid record (discarded)")
// return a sentinel to signify an invalid record was received
return nil, peers, internal.ErrInvalidRecord
// Check that record matches the one we are looking for (validation of the record does not happen here)
if !bytes.Equal([]byte(key), rec.GetKey()) {
logger.Debug("received incorrect record")
return nil, nil, internal.ErrIncorrectRecord
}
return rec, peers, err
}

if len(peers) > 0 {
return nil, peers, nil
return rec, peers, err
}

return nil, nil, routing.ErrNotFound
return nil, peers, nil
}

// GetClosestPeers asks a peer to return the K (a DHT-wide parameter) DHT server peers closest in XOR space to the id
Expand Down
56 changes: 27 additions & 29 deletions routing.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,45 +297,43 @@ func (dht *IpfsDHT) getValues(ctx context.Context, key string, stopQuery chan st
})

rec, peers, err := dht.protoMessenger.GetValue(ctx, p, key)
switch err {
case routing.ErrNotFound:
// in this case, they responded with nothing,
// still send a notification so listeners can know the
// request has completed 'successfully'
routing.PublishQueryEvent(ctx, &routing.QueryEvent{
Type: routing.PeerResponse,
ID: p,
})
return nil, err
case nil, internal.ErrInvalidRecord:
// in either of these cases, we want to keep going
default:
if err != nil {
return nil, err
}

// TODO: What should happen if the record is invalid?
// Pre-existing code counted it towards the quorum, but should it?
if rec != nil && rec.GetValue() != nil {
rv := recvdVal{
Val: rec.GetValue(),
From: p,
}

select {
case valCh <- rv:
case <-ctx.Done():
return nil, ctx.Err()
}
}

// For DHT query command
routing.PublishQueryEvent(ctx, &routing.QueryEvent{
Type: routing.PeerResponse,
ID: p,
Responses: peers,
})

return peers, err
if rec == nil {
return peers, nil
}

val := rec.GetValue()
if val == nil {
logger.Debug("received a nil record value")
return peers, nil
}
if err := dht.Validator.Validate(key, val); err != nil {
// make sure record is valid
logger.Debugw("received invalid record (discarded)", "error", err)
return peers, nil
}

// the record is present and valid, send it out for processing
select {
case valCh <- recvdVal{
Val: val,
From: p,
}:
case <-ctx.Done():
return nil, ctx.Err()
}

return peers, nil
},
func() bool {
select {
Expand Down

0 comments on commit 7a8aeb6

Please sign in to comment.