diff --git a/wire.go b/wire.go index 1fa4087..72d7ab8 100644 --- a/wire.go +++ b/wire.go @@ -58,6 +58,15 @@ func ParseRequest(line string) (Request, error) { cmd := strings.ToUpper(fields[0]) rt := strings.ToUpper(fields[1]) + // Cap fields to the maximum valid count for each command+type. + // This prevents unbounded allocation from strings.Fields on + // pathological input (in-process protection; the transport layer + // caps message size independently). + maxFields := maxRequestFieldsFor(cmd, rt) + if maxFields > 0 && len(fields) > maxFields { + return Request{}, fmt.Errorf("too many fields (%d > %d)", len(fields), maxFields) + } + switch cmd { case "QUERY": switch rt { @@ -126,6 +135,36 @@ func ParseRequest(line string) (Request, error) { } } +// maxRequestFieldsFor returns the maximum number of whitespace-separated +// fields that are valid for the given command and record type. Returns 0 +// for unknown commands/types (which will be rejected later by the switch). +func maxRequestFieldsFor(cmd, rt string) int { + switch cmd { + case "QUERY": + switch rt { + case "A", "N": + return 3 // QUERY A/N + case "S": + return 4 // QUERY S + default: + return 3 // unknown type, at most "QUERY X ..." + } + case "REGISTER": + switch rt { + case "A": + return 4 // REGISTER A
+ case "N": + return 4 // REGISTER N + case "S": + return 6 // REGISTER S
+ default: + return 3 // unknown type, at most "REGISTER X ..." + } + default: + return 0 // unknown command, check happens later + } +} + // FormatRequest serializes a request to wire format. func FormatRequest(r Request) string { switch r.RecordType { diff --git a/zz_fuzz_nameserver_test.go b/zz_fuzz_nameserver_test.go index 9356b47..3fd64c7 100644 --- a/zz_fuzz_nameserver_test.go +++ b/zz_fuzz_nameserver_test.go @@ -146,13 +146,19 @@ func TestParseRequestUnknownRecordType(t *testing.T) { func TestParseRequestExtraFields(t *testing.T) { t.Parallel() - // Extra fields should be ignored (or at least not crash) - req, err := nameserver.ParseRequest("QUERY A myhost extra1 extra2") - if err != nil { - t.Fatalf("ParseRequest with extra fields: %v", err) - } - if req.Name != "myhost" { - t.Fatalf("expected name 'myhost', got %q", req.Name) + // Extra fields exceeding max for the command/type should be rejected. + cases := []string{ + "QUERY A myhost extra1", // 4 fields, QUERY A max=3 + "QUERY N mynet extra", // 4 fields, QUERY N max=3 + "QUERY S 1 80 extra", // 5 fields, QUERY S max=4 + "REGISTER A myhost 0:1 extra", // 5 fields, REGISTER A max=4 + "REGISTER N mynet 1 extra", // 5 fields, REGISTER N max=4 + "REGISTER S svc 0:1 1 80 extra", // 7 fields, REGISTER S max=6 + } + for _, tc := range cases { + if _, err := nameserver.ParseRequest(tc); err == nil { + t.Errorf("ParseRequest(%q): expected error, got nil", tc) + } } }