diff --git a/CHANGELOG.md b/CHANGELOG.md index 4498ef93c..b3516ee52 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ * Add TimeoutHandler for the HTTP API server. (#420, @miry) * Set Write and Read timeouts for HTTP API server connections. (#423, @miry) * Show uniq request id in API HTTP response. (#425, @miry) +* Add method to parse `stream.Direction` from string. + Allow to convert `stream.Direction` to string. (#430, @miry) # [2.4.0] - 2022-03-07 diff --git a/link.go b/link.go index f808a80e2..28820b19e 100644 --- a/link.go +++ b/link.go @@ -241,8 +241,5 @@ func (link *ToxicLink) RemoveToxic(toxic *toxics.ToxicWrapper) { // Direction returns the direction of the link (upstream or downstream). func (link *ToxicLink) Direction() string { - if link.direction == stream.Upstream { - return "upstream" - } - return "downstream" + return link.direction.String() } diff --git a/stream/direction.go b/stream/direction.go new file mode 100644 index 000000000..754d2a86b --- /dev/null +++ b/stream/direction.go @@ -0,0 +1,34 @@ +package stream + +import ( + "errors" + "strings" +) + +type Direction uint8 + +var ErrInvalidDirectionParameter error = errors.New("stream: invalid direction") + +const ( + Upstream Direction = iota + Downstream + NumDirections +) + +func (d Direction) String() string { + if d >= NumDirections { + return "num_directions" + } + return [...]string{"upstream", "downstream"}[d] +} + +func ParseDirection(value string) (Direction, error) { + switch strings.ToLower(value) { + case "downstream": + return Downstream, nil + case "upstream": + return Upstream, nil + } + + return NumDirections, ErrInvalidDirectionParameter +} diff --git a/stream/direction_test.go b/stream/direction_test.go new file mode 100644 index 000000000..1a875cfc5 --- /dev/null +++ b/stream/direction_test.go @@ -0,0 +1,67 @@ +package stream_test + +import ( + "testing" + + "github.com/Shopify/toxiproxy/v2/stream" +) + +func TestDirection_String(t *testing.T) { + testCases := []struct { + name string + direction stream.Direction + expected string + }{ + {"Downstream to string", stream.Downstream, "downstream"}, + {"Upstream to string", stream.Upstream, "upstream"}, + {"NumDirections to string", stream.NumDirections, "num_directions"}, + {"Upstream via number direction to string", stream.Direction(0), "upstream"}, + {"Downstream via number direction to string", stream.Direction(1), "downstream"}, + {"High number direction to string", stream.Direction(5), "num_directions"}, + } + + for _, tc := range testCases { + tc := tc // capture range variable + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + actual := tc.direction.String() + if actual != tc.expected { + t.Errorf("got \"%s\"; expected \"%s\"", actual, tc.expected) + } + }) + } +} + +func TestParseDirection(t *testing.T) { + testCases := []struct { + name string + input string + expected stream.Direction + err error + }{ + {"parse empty", "", stream.NumDirections, stream.ErrInvalidDirectionParameter}, + {"parse upstream", "upstream", stream.Upstream, nil}, + {"parse downstream", "downstream", stream.Downstream, nil}, + {"parse unknown", "unknown", stream.NumDirections, stream.ErrInvalidDirectionParameter}, + {"parse number", "-123", stream.NumDirections, stream.ErrInvalidDirectionParameter}, + {"parse upper case", "DOWNSTREAM", stream.Downstream, nil}, + {"parse camel case", "UpStream", stream.Upstream, nil}, + } + + for _, tc := range testCases { + tc := tc // capture range variable + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + actual, err := stream.ParseDirection(tc.input) + if actual != tc.expected { + t.Errorf("got \"%s\"; expected \"%s\"", actual, tc.expected) + } + + if err != tc.err { + t.Errorf("got \"%s\"; expected \"%s\"", err, tc.err) + } + }) + } +} diff --git a/stream/io_chan.go b/stream/io_chan.go index bb6349b8f..d7bf56b36 100644 --- a/stream/io_chan.go +++ b/stream/io_chan.go @@ -6,14 +6,6 @@ import ( "time" ) -type Direction uint8 - -const ( - Upstream Direction = iota - Downstream - NumDirections -) - // Stores a slice of bytes with its receive timestamp. type StreamChunk struct { Data []byte diff --git a/toxic_collection.go b/toxic_collection.go index 68034f24a..050232361 100644 --- a/toxic_collection.go +++ b/toxic_collection.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "io" - "strings" "sync" "github.com/rs/zerolog" @@ -98,14 +97,11 @@ func (c *ToxicCollection) AddToxicJson(data io.Reader) (*toxics.ToxicWrapper, er return nil, joinError(err, ErrBadRequestBody) } - switch strings.ToLower(wrapper.Stream) { - case "downstream": - wrapper.Direction = stream.Downstream - case "upstream": - wrapper.Direction = stream.Upstream - default: + wrapper.Direction, err = stream.ParseDirection(wrapper.Stream) + if err != nil { return nil, ErrInvalidStream } + if wrapper.Name == "" { wrapper.Name = fmt.Sprintf("%s_%s", wrapper.Type, wrapper.Stream) }