@@ -13,13 +13,15 @@ import (
1313 "errors"
1414 "fmt"
1515 "io"
16+ "math"
1617 "net"
1718 "strings"
1819 "sync"
1920 "sync/atomic"
2021 "time"
2122
2223 "go.mongodb.org/mongo-driver/v2/internal/driverutil"
24+ "go.mongodb.org/mongo-driver/v2/internal/mathutil"
2325 "go.mongodb.org/mongo-driver/v2/mongo/address"
2426 "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
2527 "go.mongodb.org/mongo-driver/v2/x/mongo/driver"
@@ -226,7 +228,6 @@ func (c *connection) connect(ctx context.Context) (err error) {
226228 HTTPClient : c .config .httpClient ,
227229 }
228230 tlsNc , err := configureTLS (ctx , c .config .tlsConnectionSource , c .nc , c .addr , tlsConfig , ocspOpts )
229-
230231 if err != nil {
231232 return ConnectionError {Wrapped : err , init : true , message : fmt .Sprintf ("failed to configure TLS for %s" , c .addr )}
232233 }
@@ -427,7 +428,12 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
427428
428429func (c * connection ) parseWmSizeBytes (wmSizeBytes [4 ]byte ) (int32 , error ) {
429430 // read the length as an int32
430- size := int32 (binary .LittleEndian .Uint32 (wmSizeBytes [:]))
431+ rawSize := binary .LittleEndian .Uint32 (wmSizeBytes [:])
432+ if rawSize > uint32 (math .MaxInt32 ) {
433+ return 0 , fmt .Errorf ("message length exceeds int32 max: %d" , rawSize )
434+ }
435+
436+ size := int32 (rawSize )
431437
432438 if size < 4 {
433439 return 0 , fmt .Errorf ("malformed message length: %d" , size )
@@ -475,7 +481,12 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
475481 // reading messages from an exhaust cursor.
476482 n , err := io .ReadFull (c .nc , sizeBuf [:])
477483 if err != nil {
478- if l := int32 (n ); l == 0 && isCSOTTimeout (err ) {
484+ nI32 , convErr := mathutil.SafeConvertNumeric [int32 ](n )
485+ if convErr != nil {
486+ return nil , "incomplete read of message header" , convErr
487+ }
488+
489+ if l := nI32 ; l == 0 && isCSOTTimeout (err ) {
479490 c .awaitRemainingBytes = & l
480491 }
481492 return nil , "incomplete read of message header" , err
@@ -490,7 +501,12 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
490501
491502 n , err = io .ReadFull (c .nc , dst [4 :])
492503 if err != nil {
493- remainingBytes := size - 4 - int32 (n )
504+ nI32 , convErr := mathutil.SafeConvertNumeric [int32 ](n )
505+ if convErr != nil {
506+ return dst , "incomplete read of full message" , convErr
507+ }
508+
509+ remainingBytes := size - 4 - nI32
494510 if remainingBytes > 0 && isCSOTTimeout (err ) {
495511 c .awaitRemainingBytes = & remainingBytes
496512 }
@@ -586,15 +602,17 @@ func (c *connection) SetOIDCTokenGenID(genID uint64) {
586602// *connection to a Handshaker.
587603type initConnection struct { * connection }
588604
589- var _ mnet.ReadWriteCloser = initConnection {}
590- var _ mnet.Describer = initConnection {}
591- var _ mnet.Streamer = initConnection {}
605+ var (
606+ _ mnet.ReadWriteCloser = initConnection {}
607+ _ mnet.Describer = initConnection {}
608+ _ mnet.Streamer = initConnection {}
609+ )
592610
593611func (c initConnection ) Description () description.Server {
594612 if c .connection == nil {
595613 return description.Server {}
596614 }
597- return c .connection . desc
615+ return c .desc
598616}
599617func (c initConnection ) Close () error { return nil }
600618func (c initConnection ) ID () string { return c .id }
@@ -606,18 +624,23 @@ func (c initConnection) LocalAddress() address.Address {
606624 }
607625 return address .Address (c .nc .LocalAddr ().String ())
608626}
627+
609628func (c initConnection ) Write (ctx context.Context , wm []byte ) error {
610629 return c .writeWireMessage (ctx , wm )
611630}
631+
612632func (c initConnection ) Read (ctx context.Context ) ([]byte , error ) {
613633 return c .readWireMessage (ctx )
614634}
635+
615636func (c initConnection ) SetStreaming (streaming bool ) {
616637 c .setStreaming (streaming )
617638}
639+
618640func (c initConnection ) CurrentlyStreaming () bool {
619641 return c .getCurrentlyStreaming ()
620642}
643+
621644func (c initConnection ) SupportsStreaming () bool {
622645 return c .canStream
623646}
@@ -639,11 +662,13 @@ type Connection struct {
639662 mu sync.RWMutex
640663}
641664
642- var _ mnet.ReadWriteCloser = (* Connection )(nil )
643- var _ mnet.Describer = (* Connection )(nil )
644- var _ mnet.Compressor = (* Connection )(nil )
645- var _ mnet.Pinner = (* Connection )(nil )
646- var _ driver.Expirable = (* Connection )(nil )
665+ var (
666+ _ mnet.ReadWriteCloser = (* Connection )(nil )
667+ _ mnet.Describer = (* Connection )(nil )
668+ _ mnet.Compressor = (* Connection )(nil )
669+ _ mnet.Pinner = (* Connection )(nil )
670+ _ driver.Expirable = (* Connection )(nil )
671+ )
647672
648673// WriteWireMessage handles writing a wire message to the underlying connection.
649674func (c * Connection ) Write (ctx context.Context , wm []byte ) error {
@@ -684,7 +709,13 @@ func (c *Connection) CompressWireMessage(src, dst []byte) ([]byte, error) {
684709 }
685710 idx , dst := wiremessage .AppendHeaderStart (dst , reqid , respto , wiremessage .OpCompressed )
686711 dst = wiremessage .AppendCompressedOriginalOpCode (dst , origcode )
687- dst = wiremessage .AppendCompressedUncompressedSize (dst , int32 (len (rem )))
712+
713+ remI32 , err := mathutil.SafeConvertNumeric [int32 ](len (rem ))
714+ if err != nil {
715+ return nil , err
716+ }
717+
718+ dst = wiremessage .AppendCompressedUncompressedSize (dst , remI32 )
688719 dst = wiremessage .AppendCompressedCompressorID (dst , c .connection .compressor )
689720 opts := driver.CompressionOpts {
690721 Compressor : c .connection .compressor ,
@@ -696,7 +727,11 @@ func (c *Connection) CompressWireMessage(src, dst []byte) ([]byte, error) {
696727 return nil , err
697728 }
698729 dst = wiremessage .AppendCompressedCompressedMessage (dst , compressed )
699- return bsoncore .UpdateLength (dst , idx , int32 (len (dst [idx :]))), nil
730+ length , err := mathutil.SafeConvertNumeric [int32 ](len (dst [idx :]))
731+ if err != nil {
732+ return nil , err
733+ }
734+ return bsoncore .UpdateLength (dst , idx , length ), nil
700735}
701736
702737// Description returns the server description of the server this connection is connected to.
0 commit comments