Skip to content

Commit

Permalink
Add protection for stack-overflows for nested keys
Browse files Browse the repository at this point in the history
Signed-off-by: Patrick Scheid <p.scheid92@gmail.com>
  • Loading branch information
pscheid92 committed Jan 6, 2023
1 parent 50ec3d4 commit c1a65d5
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 9 deletions.
26 changes: 17 additions & 9 deletions pkg/strvals/literal_parser.go
Expand Up @@ -64,7 +64,7 @@ func newLiteralParser(sc *bytes.Buffer, data map[string]interface{}) *literalPar

func (t *literalParser) parse() error {
for {
err := t.key(t.data)
err := t.key(t.data, 0)
if err == nil {
continue
}
Expand All @@ -89,7 +89,7 @@ func runesUntilLiteral(in io.RuneReader, stop map[rune]bool) ([]rune, rune, erro
}
}

func (t *literalParser) key(data map[string]interface{}) (reterr error) {
func (t *literalParser) key(data map[string]interface{}, nestedNameLevel int) (reterr error) {
defer func() {
if r := recover(); r != nil {
reterr = fmt.Errorf("unable to parse key: %s", r)
Expand All @@ -114,18 +114,26 @@ func (t *literalParser) key(data map[string]interface{}) (reterr error) {
return nil

case lastRune == '.':
// Check value name is within the maximum nested name level
nestedNameLevel++
if nestedNameLevel > MaxNestedNameLevel {
return fmt.Errorf("value name nested level is greater than maximum supported nested level of %d", MaxNestedNameLevel)
}

// first, create or find the target map in the given data
inner := map[string]interface{}{}
if _, ok := data[string(key)]; ok {
inner = data[string(key)].(map[string]interface{})
}

// recurse on sub-tree with remaining data
err := t.key(inner)
if len(inner) == 0 {
err := t.key(inner, nestedNameLevel)
if err == nil && len(inner) == 0 {
return errors.Errorf("key map %q has no value", string(key))
}
set(data, string(key), inner)
if len(inner) != 0 {
set(data, string(key), inner)
}
return err

case lastRune == '[':
Expand All @@ -143,7 +151,7 @@ func (t *literalParser) key(data map[string]interface{}) (reterr error) {
}

// now we need to get the value after the ]
list, err = t.listItem(list, i)
list, err = t.listItem(list, i, nestedNameLevel)
set(data, kk, list)
return err
}
Expand All @@ -162,7 +170,7 @@ func (t *literalParser) keyIndex() (int, error) {
return strconv.Atoi(string(v))
}

func (t *literalParser) listItem(list []interface{}, i int) ([]interface{}, error) {
func (t *literalParser) listItem(list []interface{}, i, nestedNameLevel int) ([]interface{}, error) {
if i < 0 {
return list, fmt.Errorf("negative %d index not allowed", i)
}
Expand Down Expand Up @@ -196,7 +204,7 @@ func (t *literalParser) listItem(list []interface{}, i int) ([]interface{}, erro
}

// recurse
err := t.key(inner)
err := t.key(inner, nestedNameLevel)
if err != nil {
return list, err
}
Expand All @@ -218,7 +226,7 @@ func (t *literalParser) listItem(list []interface{}, i int) ([]interface{}, erro
}

// Now we need to get the value after the ].
list2, err := t.listItem(crtList, nextI)
list2, err := t.listItem(crtList, nextI, nestedNameLevel)
if err != nil {
return list, err
}
Expand Down
65 changes: 65 additions & 0 deletions pkg/strvals/literal_parser_test.go
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
package strvals

import (
"fmt"
"testing"

"sigs.k8s.io/yaml"
Expand Down Expand Up @@ -413,3 +414,67 @@ func TestParseLiteralInto(t *testing.T) {
}
}
}

func TestParseLiteralNestedLevels(t *testing.T) {
var keyMultipleNestedLevels string

for i := 1; i <= MaxNestedNameLevel+2; i++ {
tmpStr := fmt.Sprintf("name%d", i)
if i <= MaxNestedNameLevel+1 {
tmpStr = tmpStr + "."
}
keyMultipleNestedLevels += tmpStr
}

tests := []struct {
str string
expect map[string]interface{}
err bool
errStr string
}{
{
"outer.middle.inner=value",
map[string]interface{}{"outer": map[string]interface{}{"middle": map[string]interface{}{"inner": "value"}}},
false,
"",
},
{
str: keyMultipleNestedLevels + "=value",
err: true,
errStr: fmt.Sprintf("value name nested level is greater than maximum supported nested level of %d", MaxNestedNameLevel),
},
}

for _, tt := range tests {
got, err := ParseLiteral(tt.str)
if err != nil {
if tt.err {
if tt.errStr != "" {
if err.Error() != tt.errStr {
t.Errorf("Expected error: %s. Got error: %s", tt.errStr, err.Error())
}
}
continue
}
t.Fatalf("%s: %s", tt.str, err)
}

if tt.err {
t.Errorf("%s: Expected error. Got nil", tt.str)
}

y1, err := yaml.Marshal(tt.expect)
if err != nil {
t.Fatal(err)
}

y2, err := yaml.Marshal(got)
if err != nil {
t.Fatalf("Error serializing parsed value: %s", err)
}

if string(y1) != string(y2) {
t.Errorf("%s: Expected:\n%s\nGot:\n%s", tt.str, y1, y2)
}
}
}

0 comments on commit c1a65d5

Please sign in to comment.