From 2b1ed8ee915d0d6eb6543aa4269e3361aa957672 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Fri, 14 Nov 2025 15:41:31 +0100 Subject: [PATCH 1/2] Implement limit directive for msgp array/map size bounds checking Add `//msgp:limit` directive to prevent DoS attacks by limiting array/slice and map sizes during msgp serialization/deserialization operations. Features: - Unified directive syntax: //msgp:limit arrays:n maps:n marshal:true/false - Applies only to dynamic slices/maps, excludes fixed arrays (compile-time sized) - Optional marshal-time enforcement with marshal:true parameter - Default limits of math.MaxUint32 when not specified - Returns msgp.ErrLimitExceeded when limits exceeded - Per-file unique constant generation using CRC32 hash to avoid collisions Usage examples: `//msgp:limit arrays:100 maps:50` (unmarshal limits only) `//msgp:limit arrays:100 maps:50 marshal:true` (both marshal and unmarshal limits) --- _generated/limits.go | 24 +++ _generated/limits2.go | 10 + _generated/limits_test.go | 396 +++++++++++++++++++++++++++++++++++ _generated/marshal_limits.go | 12 ++ gen/decode.go | 37 +++- gen/encode.go | 25 ++- gen/marshal.go | 23 +- gen/spec.go | 30 ++- gen/unmarshal.go | 35 +++- parse/directives.go | 35 ++++ parse/getast.go | 10 + printer/print.go | 34 ++- 12 files changed, 649 insertions(+), 22 deletions(-) create mode 100644 _generated/limits.go create mode 100644 _generated/limits2.go create mode 100644 _generated/limits_test.go create mode 100644 _generated/marshal_limits.go diff --git a/_generated/limits.go b/_generated/limits.go new file mode 100644 index 00000000..463976d4 --- /dev/null +++ b/_generated/limits.go @@ -0,0 +1,24 @@ +//msgp:limit arrays:100 maps:50 + +package _generated + +//go:generate msgp + +// Test structures for limit directive +type LimitedData struct { + SmallArray [10]int `msg:"small_array"` + LargeSlice []byte `msg:"large_slice"` + SmallMap map[string]int `msg:"small_map"` +} + +type UnlimitedData struct { + BigArray [1000]int `msg:"big_array"` + BigSlice []string `msg:"big_slice"` + BigMap map[string][]int `msg:"big_map"` +} + +type LimitTestData struct { + SmallArray [10]int `msg:"small_array"` + LargeSlice []byte `msg:"large_slice"` + SmallMap map[string]int `msg:"small_map"` +} diff --git a/_generated/limits2.go b/_generated/limits2.go new file mode 100644 index 00000000..e25f105f --- /dev/null +++ b/_generated/limits2.go @@ -0,0 +1,10 @@ +package _generated + +//go:generate msgp + +//msgp:limit arrays:200 maps:100 + +type LimitTestData2 struct { + BigArray [20]int `msg:"big_array"` + BigMap map[string]int `msg:"big_map"` +} diff --git a/_generated/limits_test.go b/_generated/limits_test.go new file mode 100644 index 00000000..11be1e1d --- /dev/null +++ b/_generated/limits_test.go @@ -0,0 +1,396 @@ +package _generated + +import ( + "bytes" + "fmt" + "testing" + + "github.com/tinylib/msgp/msgp" +) + +func TestSliceLimitEnforcement(t *testing.T) { + data := UnlimitedData{} + + // Test slice limit with DecodeMsg (using big_slice which is []string) + t.Run("DecodeMsg_SliceLimit", func(t *testing.T) { + buf := msgp.AppendMapHeader(nil, 1) + buf = msgp.AppendString(buf, "big_slice") + buf = msgp.AppendArrayHeader(buf, 150) // Exceeds limit of 100 + + reader := msgp.NewReader(bytes.NewReader(buf)) + err := data.DecodeMsg(reader) + if err != msgp.ErrLimitExceeded { + t.Errorf("Expected ErrLimitExceeded, got %v", err) + } + }) + + // Test slice limit with UnmarshalMsg + t.Run("UnmarshalMsg_SliceLimit", func(t *testing.T) { + buf := 
msgp.AppendMapHeader(nil, 1) + buf = msgp.AppendString(buf, "big_slice") + buf = msgp.AppendArrayHeader(buf, 150) // Exceeds limit of 100 + + _, err := data.UnmarshalMsg(buf) + if err != msgp.ErrLimitExceeded { + t.Errorf("Expected ErrLimitExceeded, got %v", err) + } + }) + + // Test that slices within limit work fine + t.Run("SliceWithinLimit", func(t *testing.T) { + buf := msgp.AppendMapHeader(nil, 1) + buf = msgp.AppendString(buf, "big_slice") + buf = msgp.AppendArrayHeader(buf, 50) // Within limit + for i := 0; i < 50; i++ { + buf = msgp.AppendString(buf, "test") + } + + _, err := data.UnmarshalMsg(buf) + if err != nil { + t.Errorf("Unexpected error for slice within limit: %v", err) + } + }) +} + +func TestMapLimitEnforcement(t *testing.T) { + data := LimitTestData{} + + // Test map limit with DecodeMsg + t.Run("DecodeMsg_MapLimit", func(t *testing.T) { + buf := msgp.AppendMapHeader(nil, 1) + buf = msgp.AppendString(buf, "small_map") + buf = msgp.AppendMapHeader(buf, 60) // Exceeds limit of 50 + + reader := msgp.NewReader(bytes.NewReader(buf)) + err := data.DecodeMsg(reader) + if err != msgp.ErrLimitExceeded { + t.Errorf("Expected ErrLimitExceeded, got %v", err) + } + }) + + // Test map limit with UnmarshalMsg + t.Run("UnmarshalMsg_MapLimit", func(t *testing.T) { + buf := msgp.AppendMapHeader(nil, 1) + buf = msgp.AppendString(buf, "small_map") + buf = msgp.AppendMapHeader(buf, 60) // Exceeds limit of 50 + + _, err := data.UnmarshalMsg(buf) + if err != msgp.ErrLimitExceeded { + t.Errorf("Expected ErrLimitExceeded, got %v", err) + } + }) + + // Test that maps within limit work fine + t.Run("MapWithinLimit", func(t *testing.T) { + buf := msgp.AppendMapHeader(nil, 1) + buf = msgp.AppendString(buf, "small_map") + buf = msgp.AppendMapHeader(buf, 3) // Within limit + buf = msgp.AppendString(buf, "a") + buf = msgp.AppendInt(buf, 1) + buf = msgp.AppendString(buf, "b") + buf = msgp.AppendInt(buf, 2) + buf = msgp.AppendString(buf, "c") + buf = msgp.AppendInt(buf, 3) + + _, err := data.UnmarshalMsg(buf) + if err != nil { + t.Errorf("Unexpected error for map within limit: %v", err) + } + }) +} + +func TestFixedArraysNotLimited(t *testing.T) { + // Test that fixed arrays are not subject to limits + // BigArray [1000]int should work even though 1000 > 100 (array limit) + data := UnlimitedData{} + + t.Run("FixedArray_DecodeMsg", func(t *testing.T) { + buf := msgp.AppendMapHeader(nil, 1) + buf = msgp.AppendString(buf, "big_array") + buf = msgp.AppendArrayHeader(buf, 1000) // Fixed array size, should not be limited + for i := 0; i < 1000; i++ { + buf = msgp.AppendInt(buf, i) + } + + reader := msgp.NewReader(bytes.NewReader(buf)) + err := data.DecodeMsg(reader) + if err != nil { + t.Errorf("Fixed arrays should not be limited, got error: %v", err) + } + }) + + t.Run("FixedArray_UnmarshalMsg", func(t *testing.T) { + buf := msgp.AppendMapHeader(nil, 1) + buf = msgp.AppendString(buf, "big_array") + buf = msgp.AppendArrayHeader(buf, 1000) // Fixed array size, should not be limited + for i := 0; i < 1000; i++ { + buf = msgp.AppendInt(buf, i) + } + + _, err := data.UnmarshalMsg(buf) + if err != nil { + t.Errorf("Fixed arrays should not be limited, got error: %v", err) + } + }) +} + +func TestSliceLimitsApplied(t *testing.T) { + // Test that dynamic slices are subject to limits + data := UnlimitedData{} + + t.Run("Slice_ExceedsLimit", func(t *testing.T) { + buf := msgp.AppendMapHeader(nil, 1) + buf = msgp.AppendString(buf, "big_slice") + buf = msgp.AppendArrayHeader(buf, 150) // Exceeds array limit of 100 + + _, 
err := data.UnmarshalMsg(buf) + if err != msgp.ErrLimitExceeded { + t.Errorf("Expected ErrLimitExceeded for slice, got %v", err) + } + }) + + t.Run("Slice_WithinLimit", func(t *testing.T) { + buf := msgp.AppendMapHeader(nil, 1) + buf = msgp.AppendString(buf, "big_slice") + buf = msgp.AppendArrayHeader(buf, 50) // Within array limit of 100 + for i := 0; i < 50; i++ { + buf = msgp.AppendString(buf, "test") + } + + _, err := data.UnmarshalMsg(buf) + if err != nil { + t.Errorf("Unexpected error for slice within limit: %v", err) + } + }) +} + +func TestNestedArrayLimits(t *testing.T) { + // Test limits on nested arrays within maps + data := UnlimitedData{} + + t.Run("NestedArray_ExceedsLimit", func(t *testing.T) { + buf := msgp.AppendMapHeader(nil, 1) + buf = msgp.AppendString(buf, "big_map") + buf = msgp.AppendMapHeader(buf, 1) // Within map limit + buf = msgp.AppendString(buf, "key") + buf = msgp.AppendArrayHeader(buf, 150) // Nested array exceeds limit of 100 + + _, err := data.UnmarshalMsg(buf) + if err != msgp.ErrLimitExceeded { + t.Errorf("Expected ErrLimitExceeded for nested array, got %v", err) + } + }) + + t.Run("NestedArray_WithinLimit", func(t *testing.T) { + buf := msgp.AppendMapHeader(nil, 1) + buf = msgp.AppendString(buf, "big_map") + buf = msgp.AppendMapHeader(buf, 1) // Within map limit + buf = msgp.AppendString(buf, "key") + buf = msgp.AppendArrayHeader(buf, 50) // Nested array within limit + for i := 0; i < 50; i++ { + buf = msgp.AppendInt(buf, i) + } + + _, err := data.UnmarshalMsg(buf) + if err != nil { + t.Errorf("Unexpected error for nested array within limit: %v", err) + } + }) +} + +func TestMapExceedsLimit(t *testing.T) { + data := UnlimitedData{} + + t.Run("Map_ExceedsLimit", func(t *testing.T) { + buf := msgp.AppendMapHeader(nil, 1) + buf = msgp.AppendString(buf, "big_map") + buf = msgp.AppendMapHeader(buf, 60) // Exceeds map limit of 50 + + _, err := data.UnmarshalMsg(buf) + if err != msgp.ErrLimitExceeded { + t.Errorf("Expected ErrLimitExceeded for map, got %v", err) + } + }) +} + +func TestStructLevelLimits(t *testing.T) { + // Test that the struct-level map limits are enforced + data := LimitTestData{} + + t.Run("StructMap_ExceedsLimit", func(t *testing.T) { + // Create a struct with too many fields + buf := msgp.AppendMapHeader(nil, 60) // Exceeds map limit of 50 + + _, err := data.UnmarshalMsg(buf) + if err != msgp.ErrLimitExceeded { + t.Errorf("Expected ErrLimitExceeded for struct map, got %v", err) + } + }) +} + +func TestNormalOperationWithinLimits(t *testing.T) { + // Test that normal operation works when everything is within limits + data := LimitTestData{} + + // Create valid data + data.SmallArray = [10]int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + data.LargeSlice = []byte("test data") + data.SmallMap = map[string]int{"a": 1, "b": 2, "c": 3} + + t.Run("RoundTrip_Marshal_Unmarshal", func(t *testing.T) { + // Test MarshalMsg -> UnmarshalMsg + buf, err := data.MarshalMsg(nil) + if err != nil { + t.Fatalf("MarshalMsg failed: %v", err) + } + + var result LimitTestData + _, err = result.UnmarshalMsg(buf) + if err != nil { + t.Fatalf("UnmarshalMsg failed: %v", err) + } + + // Verify data integrity + if result.SmallArray != data.SmallArray { + t.Errorf("SmallArray mismatch: got %v, want %v", result.SmallArray, data.SmallArray) + } + if !bytes.Equal(result.LargeSlice, data.LargeSlice) { + t.Errorf("LargeSlice mismatch: got %v, want %v", result.LargeSlice, data.LargeSlice) + } + if len(result.SmallMap) != len(data.SmallMap) { + t.Errorf("SmallMap length mismatch: got %d, 
want %d", len(result.SmallMap), len(data.SmallMap)) + } + }) + + t.Run("RoundTrip_Encode_Decode", func(t *testing.T) { + // Test EncodeMsg -> DecodeMsg + var buf bytes.Buffer + writer := msgp.NewWriter(&buf) + err := data.EncodeMsg(writer) + if err != nil { + t.Fatalf("EncodeMsg failed: %v", err) + } + writer.Flush() + + var result LimitTestData + reader := msgp.NewReader(&buf) + err = result.DecodeMsg(reader) + if err != nil { + t.Fatalf("DecodeMsg failed: %v", err) + } + + // Verify data integrity + if result.SmallArray != data.SmallArray { + t.Errorf("SmallArray mismatch: got %v, want %v", result.SmallArray, data.SmallArray) + } + if !bytes.Equal(result.LargeSlice, data.LargeSlice) { + t.Errorf("LargeSlice mismatch: got %v, want %v", result.LargeSlice, data.LargeSlice) + } + if len(result.SmallMap) != len(data.SmallMap) { + t.Errorf("SmallMap length mismatch: got %d, want %d", len(result.SmallMap), len(data.SmallMap)) + } + }) +} + +func TestMarshalLimitEnforcement(t *testing.T) { + // Test marshal-time limit enforcement with MarshalLimitTestData + // This struct has marshal:true with arrays:30 maps:20 + + t.Run("MarshalMsg_SliceLimit", func(t *testing.T) { + data := MarshalLimitTestData{ + TestSlice: make([]string, 40), // Exceeds array limit of 30 + } + // Fill the slice + for i := range data.TestSlice { + data.TestSlice[i] = "test" + } + + _, err := data.MarshalMsg(nil) + if err == nil { + t.Error("Expected error for slice exceeding marshal limit, got nil") + } + }) + + t.Run("MarshalMsg_MapLimit", func(t *testing.T) { + data := MarshalLimitTestData{ + TestMap: make(map[string]int, 25), // Exceeds map limit of 20 + } + // Fill the map + for i := 0; i < 25; i++ { + data.TestMap[fmt.Sprintf("key%d", i)] = i + } + + _, err := data.MarshalMsg(nil) + if err == nil { + t.Error("Expected error for map exceeding marshal limit, got nil") + } + }) + + t.Run("EncodeMsg_SliceLimit", func(t *testing.T) { + data := MarshalLimitTestData{ + TestSlice: make([]string, 40), // Exceeds array limit of 30 + } + // Fill the slice + for i := range data.TestSlice { + data.TestSlice[i] = "test" + } + + var buf bytes.Buffer + writer := msgp.NewWriter(&buf) + err := data.EncodeMsg(writer) + if err == nil { + t.Error("Expected error for slice exceeding marshal limit, got nil") + } + }) + + t.Run("EncodeMsg_MapLimit", func(t *testing.T) { + data := MarshalLimitTestData{ + TestMap: make(map[string]int, 25), // Exceeds map limit of 20 + } + // Fill the map + for i := 0; i < 25; i++ { + data.TestMap[fmt.Sprintf("key%d", i)] = i + } + + var buf bytes.Buffer + writer := msgp.NewWriter(&buf) + err := data.EncodeMsg(writer) + if err == nil { + t.Error("Expected error for map exceeding marshal limit, got nil") + } + }) + + t.Run("MarshalWithinLimits", func(t *testing.T) { + data := MarshalLimitTestData{ + SmallArray: [5]int{1, 2, 3, 4, 5}, + TestSlice: []string{"a", "b", "c"}, // Within limit of 30 + TestMap: map[string]int{"x": 1, "y": 2}, // Within limit of 20 + } + + // Test MarshalMsg + _, err := data.MarshalMsg(nil) + if err != nil { + t.Errorf("Unexpected error for data within marshal limits: %v", err) + } + + // Test EncodeMsg + var buf bytes.Buffer + writer := msgp.NewWriter(&buf) + err = data.EncodeMsg(writer) + if err != nil { + t.Errorf("Unexpected error for data within marshal limits: %v", err) + } + }) + + t.Run("FixedArraysNotLimited_Marshal", func(t *testing.T) { + // Fixed arrays should not be subject to marshal limits + data := MarshalLimitTestData{ + SmallArray: [5]int{1, 2, 3, 4, 5}, // Fixed array 
size + } + + _, err := data.MarshalMsg(nil) + if err != nil { + t.Errorf("Fixed arrays should not be limited during marshal, got error: %v", err) + } + }) +} diff --git a/_generated/marshal_limits.go b/_generated/marshal_limits.go new file mode 100644 index 00000000..17be7061 --- /dev/null +++ b/_generated/marshal_limits.go @@ -0,0 +1,12 @@ +//msgp:limit arrays:30 maps:20 marshal:true + +package _generated + +//go:generate msgp + +// Test structures for marshal-time limit enforcement +type MarshalLimitTestData struct { + SmallArray [5]int `msg:"small_array"` + TestSlice []string `msg:"test_slice"` + TestMap map[string]int `msg:"test_map"` +} diff --git a/gen/decode.go b/gen/decode.go index 07352c6e..f9f9049a 100644 --- a/gen/decode.go +++ b/gen/decode.go @@ -3,6 +3,7 @@ package gen import ( "fmt" "io" + "math" "strconv" "strings" ) @@ -74,10 +75,38 @@ func (d *decodeGen) assignAndCheck(name string, typ string) { d.p.wrapErrCheck(d.ctx.ArgsStr()) } +func (d *decodeGen) assignAndCheckWithArrayLimit(name string, typ string) { + if !d.p.ok() { + return + } + d.p.printf("\n%s, err = dc.Read%s()", name, typ) + d.p.wrapErrCheck(d.ctx.ArgsStr()) + if d.ctx.arrayLimit != math.MaxUint32 { + d.p.printf("\nif %s > %slimitArrays {", name, d.ctx.limitPrefix) + d.p.printf("\nerr = msgp.ErrLimitExceeded") + d.p.printf("\nreturn") + d.p.printf("\n}") + } +} + +func (d *decodeGen) assignAndCheckWithMapLimit(name string, typ string) { + if !d.p.ok() { + return + } + d.p.printf("\n%s, err = dc.Read%s()", name, typ) + d.p.wrapErrCheck(d.ctx.ArgsStr()) + if d.ctx.mapLimit != math.MaxUint32 { + d.p.printf("\nif %s > %slimitMaps {", name, d.ctx.limitPrefix) + d.p.printf("\nerr = msgp.ErrLimitExceeded") + d.p.printf("\nreturn") + d.p.printf("\n}") + } +} + func (d *decodeGen) structAsTuple(s *Struct) { sz := randIdent() d.p.declare(sz, u32) - d.assignAndCheck(sz, arrayHeader) + d.assignAndCheckWithArrayLimit(sz, arrayHeader) if s.AsVarTuple { d.p.printf("\nif %[1]s == 0 { return }", sz) } else { @@ -116,7 +145,7 @@ func (d *decodeGen) structAsMap(s *Struct) { d.needsField() sz := randIdent() d.p.declare(sz, u32) - d.assignAndCheck(sz, mapHeader) + d.assignAndCheckWithMapLimit(sz, mapHeader) oeCount := s.CountFieldTagPart("omitempty") + s.CountFieldTagPart("omitzero") if !d.ctx.clearOmitted { @@ -281,7 +310,7 @@ func (d *decodeGen) gMap(m *Map) { // resize or allocate map d.p.declare(sz, u32) - d.assignAndCheck(sz, mapHeader) + d.assignAndCheckWithMapLimit(sz, mapHeader) d.p.resizeMap(sz, m) // for element in map, read string/value @@ -305,7 +334,7 @@ func (d *decodeGen) gSlice(s *Slice) { } sz := randIdent() d.p.declare(sz, u32) - d.assignAndCheck(sz, arrayHeader) + d.assignAndCheckWithArrayLimit(sz, arrayHeader) if s.isAllowNil { d.p.resizeSliceNoNil(sz, s) } else { diff --git a/gen/encode.go b/gen/encode.go index 2b8bd7e8..5d42cdec 100644 --- a/gen/encode.go +++ b/gen/encode.go @@ -3,6 +3,7 @@ package gen import ( "fmt" "io" + "math" "strings" "github.com/tinylib/msgp/msgp" @@ -39,6 +40,26 @@ func (e *encodeGen) writeAndCheck(typ string, argfmt string, arg any) { e.p.wrapErrCheck(e.ctx.ArgsStr()) } +func (e *encodeGen) writeAndCheckWithArrayLimit(typ string, argfmt string, arg any) { + e.writeAndCheck(typ, argfmt, arg) + if e.ctx.marshalLimits && e.ctx.arrayLimit != math.MaxUint32 { + e.p.printf("\nif %s > %slimitArrays {", fmt.Sprintf(argfmt, arg), e.ctx.limitPrefix) + e.p.printf("\nerr = msgp.ErrLimitExceeded") + e.p.printf("\nreturn") + e.p.printf("\n}") + } +} + +func (e *encodeGen) 
writeAndCheckWithMapLimit(typ string, argfmt string, arg any) { + e.writeAndCheck(typ, argfmt, arg) + if e.ctx.marshalLimits && e.ctx.mapLimit != math.MaxUint32 { + e.p.printf("\nif %s > %slimitMaps {", fmt.Sprintf(argfmt, arg), e.ctx.limitPrefix) + e.p.printf("\nerr = msgp.ErrLimitExceeded") + e.p.printf("\nreturn") + e.p.printf("\n}") + } +} + func (e *encodeGen) fuseHook() { if len(e.fuse) > 0 { e.appendraw(e.fuse) @@ -244,7 +265,7 @@ func (e *encodeGen) gMap(m *Map) { } e.fuseHook() vname := m.Varname() - e.writeAndCheck(mapHeader, lenAsUint32, vname) + e.writeAndCheckWithMapLimit(mapHeader, lenAsUint32, vname) e.p.printf("\nfor %s, %s := range %s {", m.Keyidx, m.Validx, vname) if m.Key != nil { @@ -297,7 +318,7 @@ func (e *encodeGen) gSlice(s *Slice) { return } e.fuseHook() - e.writeAndCheck(arrayHeader, lenAsUint32, s.Varname()) + e.writeAndCheckWithArrayLimit(arrayHeader, lenAsUint32, s.Varname()) setTypeParams(s.Els, s.typeParams) e.p.rangeBlock(e.ctx, s.Index, s.Varname(), e, s.Els) } diff --git a/gen/marshal.go b/gen/marshal.go index e68464ad..41fe111d 100644 --- a/gen/marshal.go +++ b/gen/marshal.go @@ -3,6 +3,7 @@ package gen import ( "fmt" "io" + "math" "strings" "github.com/tinylib/msgp/msgp" @@ -73,6 +74,24 @@ func (m *marshalGen) rawAppend(typ string, argfmt string, arg any) { m.p.printf("\no = msgp.Append%s(o, %s)", typ, fmt.Sprintf(argfmt, arg)) } +func (m *marshalGen) rawAppendWithArrayLimit(typ string, argfmt string, arg any) { + m.rawAppend(typ, argfmt, arg) + if m.ctx.marshalLimits && m.ctx.arrayLimit != math.MaxUint32 { + m.p.printf("\nif %s > %slimitArrays {", fmt.Sprintf(argfmt, arg), m.ctx.limitPrefix) + m.p.printf("\nreturn nil, msgp.ErrLimitExceeded") + m.p.printf("\n}") + } +} + +func (m *marshalGen) rawAppendWithMapLimit(typ string, argfmt string, arg any) { + m.rawAppend(typ, argfmt, arg) + if m.ctx.marshalLimits && m.ctx.mapLimit != math.MaxUint32 { + m.p.printf("\nif %s > %slimitMaps {", fmt.Sprintf(argfmt, arg), m.ctx.limitPrefix) + m.p.printf("\nreturn nil, msgp.ErrLimitExceeded") + m.p.printf("\n}") + } +} + func (m *marshalGen) fuseHook() { if len(m.fuse) > 0 { m.rawbytes(m.fuse) @@ -250,7 +269,7 @@ func (m *marshalGen) gMap(s *Map) { } m.fuseHook() vname := s.Varname() - m.rawAppend(mapHeader, lenAsUint32, vname) + m.rawAppendWithMapLimit(mapHeader, lenAsUint32, vname) m.p.printf("\nfor %s, %s := range %s {", s.Keyidx, s.Validx, vname) // Shim key to base type if necessary. if s.Key != nil { @@ -292,7 +311,7 @@ func (m *marshalGen) gSlice(s *Slice) { vname := s.Varname() setTypeParams(s.Els, s.typeParams) - m.rawAppend(arrayHeader, lenAsUint32, vname) + m.rawAppendWithArrayLimit(arrayHeader, lenAsUint32, vname) m.p.rangeBlock(m.ctx, s.Index, vname, m, s.Els) } diff --git a/gen/spec.go b/gen/spec.go index 4ff4fbee..83c2fcd3 100644 --- a/gen/spec.go +++ b/gen/spec.go @@ -81,6 +81,10 @@ type Printer struct { ClearOmitted bool NewTime bool AsUTC bool + ArrayLimit uint32 + MapLimit uint32 + MarshalLimits bool + LimitPrefix string } func NewPrinter(m Method, out io.Writer, tests io.Writer) *Printer { @@ -151,10 +155,14 @@ func (p *Printer) Print(e Elem) error { // hence the separate prefixes. 
resetIdent("zb") err := g.Execute(e, Context{ - compFloats: p.CompactFloats, - clearOmitted: p.ClearOmitted, - newTime: p.NewTime, - asUTC: p.AsUTC, + compFloats: p.CompactFloats, + clearOmitted: p.ClearOmitted, + newTime: p.NewTime, + asUTC: p.AsUTC, + arrayLimit: p.ArrayLimit, + mapLimit: p.MapLimit, + marshalLimits: p.MarshalLimits, + limitPrefix: p.LimitPrefix, }) resetIdent("za") @@ -182,11 +190,15 @@ func (c contextVar) Arg() string { } type Context struct { - path []contextItem - compFloats bool - clearOmitted bool - newTime bool - asUTC bool + path []contextItem + compFloats bool + clearOmitted bool + newTime bool + asUTC bool + arrayLimit uint32 + mapLimit uint32 + marshalLimits bool + limitPrefix string } func (c *Context) PushString(s string) { diff --git a/gen/unmarshal.go b/gen/unmarshal.go index f1d6bfce..bf338866 100644 --- a/gen/unmarshal.go +++ b/gen/unmarshal.go @@ -3,6 +3,7 @@ package gen import ( "fmt" "io" + "math" "strconv" "strings" ) @@ -63,6 +64,34 @@ func (u *unmarshalGen) assignAndCheck(name string, base string) { u.p.wrapErrCheck(u.ctx.ArgsStr()) } +func (u *unmarshalGen) assignAndCheckWithArrayLimit(name string, base string) { + if !u.p.ok() { + return + } + u.p.printf("\n%s, bts, err = msgp.Read%sBytes(bts)", name, base) + u.p.wrapErrCheck(u.ctx.ArgsStr()) + if u.ctx.arrayLimit != math.MaxUint32 { + u.p.printf("\nif %s > %slimitArrays {", name, u.ctx.limitPrefix) + u.p.printf("\nerr = msgp.ErrLimitExceeded") + u.p.printf("\nreturn") + u.p.printf("\n}") + } +} + +func (u *unmarshalGen) assignAndCheckWithMapLimit(name string, base string) { + if !u.p.ok() { + return + } + u.p.printf("\n%s, bts, err = msgp.Read%sBytes(bts)", name, base) + u.p.wrapErrCheck(u.ctx.ArgsStr()) + if u.ctx.mapLimit != math.MaxUint32 { + u.p.printf("\nif %s > %slimitMaps {", name, u.ctx.limitPrefix) + u.p.printf("\nerr = msgp.ErrLimitExceeded") + u.p.printf("\nreturn") + u.p.printf("\n}") + } +} + func (u *unmarshalGen) gStruct(s *Struct) { if !u.p.ok() { return @@ -137,7 +166,7 @@ func (u *unmarshalGen) mapstruct(s *Struct) { u.needsField() sz := randIdent() u.p.declare(sz, u32) - u.assignAndCheck(sz, mapHeader) + u.assignAndCheckWithMapLimit(sz, mapHeader) oeCount := s.CountFieldTagPart("omitempty") + s.CountFieldTagPart("omitzero") if !u.ctx.clearOmitted { @@ -314,7 +343,7 @@ func (u *unmarshalGen) gSlice(s *Slice) { } sz := randIdent() u.p.declare(sz, u32) - u.assignAndCheck(sz, arrayHeader) + u.assignAndCheckWithArrayLimit(sz, arrayHeader) if s.isAllowNil { u.p.resizeSliceNoNil(sz, s) } else { @@ -330,7 +359,7 @@ func (u *unmarshalGen) gMap(m *Map) { } sz := randIdent() u.p.declare(sz, u32) - u.assignAndCheck(sz, mapHeader) + u.assignAndCheckWithMapLimit(sz, mapHeader) // allocate or clear map u.p.resizeMap(sz, m) diff --git a/parse/directives.go b/parse/directives.go index dfe56099..09474a6f 100644 --- a/parse/directives.go +++ b/parse/directives.go @@ -32,6 +32,7 @@ var directives = map[string]directive{ "clearomitted": clearomitted, "newtime": newtime, "timezone": newtimezone, + "limit": limit, } // map of all recognized directives which will be applied @@ -315,3 +316,37 @@ func newtimezone(text []string, f *FileSet) error { infof("using timezone %q\n", text[1]) return nil } + +//msgp:limit arrays:n maps:n marshal:true/false +func limit(text []string, f *FileSet) (err error) { + for _, arg := range text[1:] { + arg = strings.ToLower(strings.TrimSpace(arg)) + switch { + case strings.HasPrefix(arg, "arrays:"): + limitStr := strings.TrimPrefix(arg, "arrays:") + limit, err := 
strconv.ParseUint(limitStr, 10, 32) + if err != nil { + return fmt.Errorf("invalid arrays limit; found %s, expected positive integer", limitStr) + } + f.ArrayLimit = uint32(limit) + case strings.HasPrefix(arg, "maps:"): + limitStr := strings.TrimPrefix(arg, "maps:") + limit, err := strconv.ParseUint(limitStr, 10, 32) + if err != nil { + return fmt.Errorf("invalid maps limit; found %s, expected positive integer", limitStr) + } + f.MapLimit = uint32(limit) + case strings.HasPrefix(arg, "marshal:"): + marshalStr := strings.TrimPrefix(arg, "marshal:") + marshal, err := strconv.ParseBool(marshalStr) + if err != nil { + return fmt.Errorf("invalid marshal option; found %s, expected 'true' or 'false'", marshalStr) + } + f.MarshalLimits = marshal + default: + return fmt.Errorf("invalid limit directive; found %s, expected 'arrays:n', 'maps:n', or 'marshal:true/false'", arg) + } + } + infof("limits - arrays:%d maps:%d marshal:%t\n", f.ArrayLimit, f.MapLimit, f.MarshalLimits) + return nil +} diff --git a/parse/getast.go b/parse/getast.go index cdc91699..e7b5443d 100644 --- a/parse/getast.go +++ b/parse/getast.go @@ -36,6 +36,10 @@ type FileSet struct { AllowMapShims bool // Allow map keys to be shimmed (default true) AllowBinMaps bool // Allow maps with binary keys to be used (default false) AutoMapShims bool // Automatically shim map keys of builtin types(default false) + ArrayLimit uint32 // Maximum array/slice size allowed during deserialization + MapLimit uint32 // Maximum map size allowed during deserialization + MarshalLimits bool // Whether to enforce limits during marshaling + LimitPrefix string // Unique prefix for limit constants to avoid collisions tagName string // tag to read field names from pointerRcv bool // generate with pointer receivers. @@ -55,6 +59,8 @@ func File(name string, unexported bool, directives []string) (*FileSet, error) { TypeInfos: make(map[string]*TypeInfo), Identities: make(map[string]gen.Elem), Directives: append([]string{}, directives...), + ArrayLimit: 4294967295, // math.MaxUint32 + MapLimit: 4294967295, // math.MaxUint32 } fset := token.NewFileSet() @@ -410,6 +416,10 @@ loop: p.ClearOmitted = fs.ClearOmitted p.NewTime = fs.NewTime p.AsUTC = fs.AsUTC + p.ArrayLimit = fs.ArrayLimit + p.MapLimit = fs.MapLimit + p.MarshalLimits = fs.MarshalLimits + p.LimitPrefix = fs.LimitPrefix } func (fs *FileSet) PrintTo(p *gen.Printer) error { diff --git a/printer/print.go b/printer/print.go index a2b48472..16971bf9 100644 --- a/printer/print.go +++ b/printer/print.go @@ -3,8 +3,11 @@ package printer import ( "bytes" "fmt" + "hash/crc32" "io" + "math" "os" + "path/filepath" "strings" "github.com/tinylib/msgp/gen" @@ -18,7 +21,7 @@ var Logf func(s string, v ...any) // of elements to the given file name and canonical // package path. func PrintFile(file string, f *parse.FileSet, mode gen.Method) error { - out, tests, err := generate(f, mode) + out, tests, err := generate(file, f, mode) if err != nil { return err } @@ -83,7 +86,7 @@ func dedupImports(imp []string) []string { return r } -func generate(f *parse.FileSet, mode gen.Method) (*bytes.Buffer, *bytes.Buffer, error) { +func generate(file string, f *parse.FileSet, mode gen.Method) (*bytes.Buffer, *bytes.Buffer, error) { outbuf := bytes.NewBuffer(make([]byte, 0, 4096)) writePkgHeader(outbuf, f.Package) @@ -99,6 +102,8 @@ func generate(f *parse.FileSet, mode gen.Method) (*bytes.Buffer, *bytes.Buffer, dedup := dedupImports(myImports) writeImportHeader(outbuf, dedup...) 
+ writeLimitConstants(outbuf, file, f) + var testbuf *bytes.Buffer var testwr io.Writer if mode&gen.Test == gen.Test { @@ -136,3 +141,28 @@ func writeImportHeader(b *bytes.Buffer, imports ...string) { } b.WriteString(")\n\n") } + +// generateFilePrefix creates a deterministic, unique prefix for constants based on the file name +func generateFilePrefix(filename string) string { + base := filepath.Base(filename) + hash := crc32.ChecksumIEEE([]byte(base)) + return fmt.Sprintf("z%08x", hash) +} + +func writeLimitConstants(b *bytes.Buffer, file string, f *parse.FileSet) { + if f.ArrayLimit != math.MaxUint32 || f.MapLimit != math.MaxUint32 { + prefix := generateFilePrefix(file) + b.WriteString("// Size limits for msgp deserialization\n") + b.WriteString("const (\n") + if f.ArrayLimit != math.MaxUint32 { + fmt.Fprintf(b, "\t%slimitArrays = %d\n", prefix, f.ArrayLimit) + } + if f.MapLimit != math.MaxUint32 { + fmt.Fprintf(b, "\t%slimitMaps = %d\n", prefix, f.MapLimit) + } + b.WriteString(")\n\n") + + // Store the prefix in FileSet so generators can use it + f.LimitPrefix = prefix + } +} From da9b137d1a90cb900bd74d8e7e004fe3eb61d697 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Sat, 15 Nov 2025 12:10:58 +0100 Subject: [PATCH 2/2] Just use math.MaxUint32 --- parse/getast.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/parse/getast.go b/parse/getast.go index e7b5443d..7a77f755 100644 --- a/parse/getast.go +++ b/parse/getast.go @@ -5,6 +5,7 @@ import ( "go/ast" "go/parser" "go/token" + "math" "os" "reflect" "sort" @@ -59,8 +60,8 @@ func File(name string, unexported bool, directives []string) (*FileSet, error) { TypeInfos: make(map[string]*TypeInfo), Identities: make(map[string]gen.Elem), Directives: append([]string{}, directives...), - ArrayLimit: 4294967295, // math.MaxUint32 - MapLimit: 4294967295, // math.MaxUint32 + ArrayLimit: math.MaxUint32, + MapLimit: math.MaxUint32, } fset := token.NewFileSet()
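
Note for reviewers: below is a rough, illustrative sketch of the shape of the code the generator emits once a file declares `//msgp:limit arrays:100 maps:50`. It is not taken from this PR's generated output. The `zd05a1765` constant prefix and the `za0001` identifier are hypothetical placeholders (the real prefix comes from a CRC32 of the source file name via generateFilePrefix, and identifiers come from the generator's randIdent()); struct field dispatch and error wrapping are elided for brevity.

package example

import "github.com/tinylib/msgp/msgp"

// Size limits for msgp deserialization (per-file constants written by
// writeLimitConstants; the zd05a1765 prefix is a placeholder here).
const (
	zd05a1765limitArrays = 100
	zd05a1765limitMaps   = 50
)

type Example struct {
	Tags []string
}

// UnmarshalMsg sketches only the slice-header path produced by
// assignAndCheckWithArrayLimit; the surrounding struct/map handling that
// the real generator emits is omitted.
func (z *Example) UnmarshalMsg(bts []byte) (o []byte, err error) {
	var za0001 uint32
	za0001, bts, err = msgp.ReadArrayHeaderBytes(bts)
	if err != nil {
		return
	}
	// Limit check inserted because the file-level //msgp:limit directive
	// set an array bound below math.MaxUint32.
	if za0001 > zd05a1765limitArrays {
		err = msgp.ErrLimitExceeded
		return
	}
	if cap(z.Tags) >= int(za0001) {
		z.Tags = z.Tags[:za0001]
	} else {
		z.Tags = make([]string, za0001)
	}
	for i := range z.Tags {
		z.Tags[i], bts, err = msgp.ReadStringBytes(bts)
		if err != nil {
			return
		}
	}
	o = bts
	return
}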