Skip to content

Commit acb5d3d

Browse files
(gosec) Apply G115 fixes to x/mongo/driver/topology package
Address gosec G115 integer overflow warnings in topology: - Add SafeConvertNumeric for wire message compression operations - Add SafeConvertNumeric for server description fields (maxBsonObjectSize, etc.) - Add SafeConvertNumeric for connection pool size conversions - Use binaryutil for safe binary operations
1 parent f5bdc5c commit acb5d3d

File tree

11 files changed

+83
-33
lines changed

11 files changed

+83
-33
lines changed

x/mongo/driver/topology/CMAP_spec_test.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -427,10 +427,8 @@ func runOperationInThread(t *testing.T, operation map[string]any, testInfo *test
427427
t.Fatalf("unable to find thread to wait for: %v", threadName)
428428
}
429429

430-
for {
431-
if atomic.LoadInt32(&thread.JobsCompleted) == atomic.LoadInt32(&thread.JobsAssigned) {
432-
break
433-
}
430+
for atomic.LoadInt32(&thread.JobsCompleted) != atomic.LoadInt32(&thread.JobsAssigned) {
431+
434432
}
435433
case "waitForEvent":
436434
var targetCount int

x/mongo/driver/topology/connection.go

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@ import (
1313
"errors"
1414
"fmt"
1515
"io"
16+
"math"
1617
"net"
1718
"strings"
1819
"sync"
1920
"sync/atomic"
2021
"time"
2122

2223
"go.mongodb.org/mongo-driver/v2/internal/driverutil"
24+
"go.mongodb.org/mongo-driver/v2/internal/mathutil"
2325
"go.mongodb.org/mongo-driver/v2/mongo/address"
2426
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
2527
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
@@ -226,7 +228,6 @@ func (c *connection) connect(ctx context.Context) (err error) {
226228
HTTPClient: c.config.httpClient,
227229
}
228230
tlsNc, err := configureTLS(ctx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts)
229-
230231
if err != nil {
231232
return ConnectionError{Wrapped: err, init: true, message: fmt.Sprintf("failed to configure TLS for %s", c.addr)}
232233
}
@@ -427,7 +428,12 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
427428

428429
func (c *connection) parseWmSizeBytes(wmSizeBytes [4]byte) (int32, error) {
429430
// read the length as an int32
430-
size := int32(binary.LittleEndian.Uint32(wmSizeBytes[:]))
431+
rawSize := binary.LittleEndian.Uint32(wmSizeBytes[:])
432+
if rawSize > uint32(math.MaxInt32) {
433+
return 0, fmt.Errorf("message length exceeds int32 max: %d", rawSize)
434+
}
435+
436+
size := int32(rawSize)
431437

432438
if size < 4 {
433439
return 0, fmt.Errorf("malformed message length: %d", size)
@@ -475,7 +481,12 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
475481
// reading messages from an exhaust cursor.
476482
n, err := io.ReadFull(c.nc, sizeBuf[:])
477483
if err != nil {
478-
if l := int32(n); l == 0 && isCSOTTimeout(err) {
484+
nI32, convErr := mathutil.SafeConvertNumeric[int32](n)
485+
if convErr != nil {
486+
return nil, "incomplete read of message header", convErr
487+
}
488+
489+
if l := nI32; l == 0 && isCSOTTimeout(err) {
479490
c.awaitRemainingBytes = &l
480491
}
481492
return nil, "incomplete read of message header", err
@@ -490,7 +501,12 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
490501

491502
n, err = io.ReadFull(c.nc, dst[4:])
492503
if err != nil {
493-
remainingBytes := size - 4 - int32(n)
504+
nI32, convErr := mathutil.SafeConvertNumeric[int32](n)
505+
if convErr != nil {
506+
return dst, "incomplete read of full message", convErr
507+
}
508+
509+
remainingBytes := size - 4 - nI32
494510
if remainingBytes > 0 && isCSOTTimeout(err) {
495511
c.awaitRemainingBytes = &remainingBytes
496512
}
@@ -586,15 +602,17 @@ func (c *connection) SetOIDCTokenGenID(genID uint64) {
586602
// *connection to a Handshaker.
587603
type initConnection struct{ *connection }
588604

589-
var _ mnet.ReadWriteCloser = initConnection{}
590-
var _ mnet.Describer = initConnection{}
591-
var _ mnet.Streamer = initConnection{}
605+
var (
606+
_ mnet.ReadWriteCloser = initConnection{}
607+
_ mnet.Describer = initConnection{}
608+
_ mnet.Streamer = initConnection{}
609+
)
592610

593611
func (c initConnection) Description() description.Server {
594612
if c.connection == nil {
595613
return description.Server{}
596614
}
597-
return c.connection.desc
615+
return c.desc
598616
}
599617
func (c initConnection) Close() error { return nil }
600618
func (c initConnection) ID() string { return c.id }
@@ -606,18 +624,23 @@ func (c initConnection) LocalAddress() address.Address {
606624
}
607625
return address.Address(c.nc.LocalAddr().String())
608626
}
627+
609628
func (c initConnection) Write(ctx context.Context, wm []byte) error {
610629
return c.writeWireMessage(ctx, wm)
611630
}
631+
612632
func (c initConnection) Read(ctx context.Context) ([]byte, error) {
613633
return c.readWireMessage(ctx)
614634
}
635+
615636
func (c initConnection) SetStreaming(streaming bool) {
616637
c.setStreaming(streaming)
617638
}
639+
618640
func (c initConnection) CurrentlyStreaming() bool {
619641
return c.getCurrentlyStreaming()
620642
}
643+
621644
func (c initConnection) SupportsStreaming() bool {
622645
return c.canStream
623646
}
@@ -639,11 +662,13 @@ type Connection struct {
639662
mu sync.RWMutex
640663
}
641664

642-
var _ mnet.ReadWriteCloser = (*Connection)(nil)
643-
var _ mnet.Describer = (*Connection)(nil)
644-
var _ mnet.Compressor = (*Connection)(nil)
645-
var _ mnet.Pinner = (*Connection)(nil)
646-
var _ driver.Expirable = (*Connection)(nil)
665+
var (
666+
_ mnet.ReadWriteCloser = (*Connection)(nil)
667+
_ mnet.Describer = (*Connection)(nil)
668+
_ mnet.Compressor = (*Connection)(nil)
669+
_ mnet.Pinner = (*Connection)(nil)
670+
_ driver.Expirable = (*Connection)(nil)
671+
)
647672

648673
// WriteWireMessage handles writing a wire message to the underlying connection.
649674
func (c *Connection) Write(ctx context.Context, wm []byte) error {
@@ -684,7 +709,13 @@ func (c *Connection) CompressWireMessage(src, dst []byte) ([]byte, error) {
684709
}
685710
idx, dst := wiremessage.AppendHeaderStart(dst, reqid, respto, wiremessage.OpCompressed)
686711
dst = wiremessage.AppendCompressedOriginalOpCode(dst, origcode)
687-
dst = wiremessage.AppendCompressedUncompressedSize(dst, int32(len(rem)))
712+
713+
remI32, err := mathutil.SafeConvertNumeric[int32](len(rem))
714+
if err != nil {
715+
return nil, err
716+
}
717+
718+
dst = wiremessage.AppendCompressedUncompressedSize(dst, remI32)
688719
dst = wiremessage.AppendCompressedCompressorID(dst, c.connection.compressor)
689720
opts := driver.CompressionOpts{
690721
Compressor: c.connection.compressor,
@@ -696,7 +727,11 @@ func (c *Connection) CompressWireMessage(src, dst []byte) ([]byte, error) {
696727
return nil, err
697728
}
698729
dst = wiremessage.AppendCompressedCompressedMessage(dst, compressed)
699-
return bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:]))), nil
730+
length, err := mathutil.SafeConvertNumeric[int32](len(dst[idx:]))
731+
if err != nil {
732+
return nil, err
733+
}
734+
return bsoncore.UpdateLength(dst, idx, length), nil
700735
}
701736

702737
// Description returns the server description of the server this connection is connected to.

x/mongo/driver/topology/connection_errors_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
66

77
//go:build go1.13
8-
// +build go1.13
98

109
package topology
1110

x/mongo/driver/topology/errors.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"strings"
1717
"time"
1818

19+
"go.mongodb.org/mongo-driver/v2/internal/mathutil"
1920
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
2021
)
2122

@@ -118,13 +119,24 @@ func (w WaitQueueTimeoutError) Error() string {
118119

119120
msg := fmt.Sprintf("%s; total connections: %d, maxPoolSize: %d, ", errorMsg, w.totalConnections, w.maxPoolSize)
120121
if pinnedConnections := w.pinnedConnections; pinnedConnections != nil {
121-
openConnectionCount := uint64(w.totalConnections) -
122+
var totcalConnectionsWarning string
123+
124+
totalConnections, err := mathutil.SafeConvertNumeric[uint64](w.totalConnections)
125+
if err != nil {
126+
totcalConnectionsWarning = fmt.Sprintf("[WARNING]: totalConnections is negative (%d); this may indicate a bug in the driver. ",
127+
w.totalConnections)
128+
totalConnections = 0
129+
}
130+
131+
openConnectionCount := totalConnections -
122132
pinnedConnections.cursorConnections -
123133
pinnedConnections.transactionConnections
124-
msg += fmt.Sprintf("connections in use by cursors: %d, connections in use by transactions: %d, connections in use by other operations: %d, ",
134+
135+
msg += fmt.Sprintf("connections in use by cursors: %d, connections in use by transactions: %d, connections in use by other operations: %d%s, ",
125136
pinnedConnections.cursorConnections,
126137
pinnedConnections.transactionConnections,
127138
openConnectionCount,
139+
totcalConnectionsWarning,
128140
)
129141
}
130142
msg += fmt.Sprintf("idle connections: %d, wait duration: %s", w.availableConnections, w.waitDuration.String())

x/mongo/driver/topology/fsm.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ func (f *fsm) apply(s description.Server) (description.Topology, description.Ser
125125
SetName: f.SetName,
126126
}
127127

128-
f.Topology.SessionTimeoutMinutes = serverTimeoutMinutes
128+
f.SessionTimeoutMinutes = serverTimeoutMinutes
129129

130130
if _, ok := f.findServer(s.Addr); !ok {
131131
return f.Topology, s
@@ -157,7 +157,7 @@ func (f *fsm) apply(s description.Server) (description.Topology, description.Ser
157157
SupportedWireVersions.Min,
158158
MinSupportedMongoDBVersion,
159159
)
160-
f.Topology.CompatibilityErr = f.compatibilityErr
160+
f.CompatibilityErr = f.compatibilityErr
161161
return f.Topology, s
162162
}
163163

@@ -169,7 +169,7 @@ func (f *fsm) apply(s description.Server) (description.Topology, description.Ser
169169
server.WireVersion.Min,
170170
SupportedWireVersions.Max,
171171
)
172-
f.Topology.CompatibilityErr = f.compatibilityErr
172+
f.CompatibilityErr = f.compatibilityErr
173173
return f.Topology, s
174174
}
175175
}

x/mongo/driver/topology/pool.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"context"
1111
"fmt"
1212
"io"
13+
"math"
1314
"net"
1415
"sync"
1516
"sync/atomic"
@@ -18,6 +19,7 @@ import (
1819
"go.mongodb.org/mongo-driver/v2/bson"
1920
"go.mongodb.org/mongo-driver/v2/event"
2021
"go.mongodb.org/mongo-driver/v2/internal/logger"
22+
"go.mongodb.org/mongo-driver/v2/internal/mathutil"
2123
"go.mongodb.org/mongo-driver/v2/mongo/address"
2224
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
2325
)
@@ -158,7 +160,6 @@ func logPoolMessage(pool *pool, msg string, keysAndValues ...any) {
158160
ServerHost: host,
159161
ServerPort: port,
160162
}, keysAndValues...)...)
161-
162163
}
163164

164165
type reason struct {
@@ -241,7 +242,7 @@ func newPool(config poolConfig, connOpts ...ConnectionOption) *pool {
241242
var ctx context.Context
242243
ctx, pool.cancelBackgroundCtx = context.WithCancel(context.Background())
243244

244-
for i := 0; i < int(pool.maxConnecting); i++ {
245+
for i := uint64(0); i < pool.maxConnecting; i++ {
245246
pool.backgroundDone.Add(1)
246247
go pool.createConnections(ctx, pool.backgroundDone)
247248
}
@@ -1357,7 +1358,16 @@ func (p *pool) maintain(ctx context.Context, wg *sync.WaitGroup) {
13571358
// the number of connections requested to max 10 at a time to prevent overshooting
13581359
// minPoolSize in case other checkOut() calls are requesting new connections, too.
13591360
total := p.totalConnectionCount()
1360-
n := int(p.minSize) - total - len(wantConns)
1361+
1362+
// Since this is a forced mod 10 operation, we can safely ignore overflows.
1363+
minSize, err := mathutil.SafeConvertNumeric[int](p.minSize)
1364+
if err != nil {
1365+
// Ignore overflow here because this is only used to drive pool growth
1366+
// hints.
1367+
minSize = math.MaxInt
1368+
}
1369+
1370+
n := minSize - total - len(wantConns)
13611371
if n > 10 {
13621372
n = 10
13631373
}

x/mongo/driver/topology/server_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
66

77
//go:build go1.13
8-
// +build go1.13
98

109
package topology
1110

x/mongo/driver/topology/tls_connection_source_1_16.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
66

77
//go:build !go1.17
8-
// +build !go1.17
98

109
package topology
1110

x/mongo/driver/topology/tls_connection_source_1_17.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
66

77
//go:build go1.17
8-
// +build go1.17
98

109
package topology
1110

x/mongo/driver/topology/topology.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -772,7 +772,7 @@ func (t *Topology) pollSRVRecords(hosts string) {
772772
return
773773
}
774774
topoKind := t.Description().Kind
775-
if !(topoKind == description.Unknown || topoKind == description.TopologyKindSharded) {
775+
if topoKind != description.Unknown && topoKind != description.TopologyKindSharded {
776776
break
777777
}
778778

0 commit comments

Comments
 (0)