6565 from array import array as _array
6666 from mmap import mmap as _mmap
6767
68+ import numpy as np
69+ import numpy .typing as npt
70+
6871
6972class UuidRepresentation :
7073 UNSPECIFIED = 0
@@ -234,13 +237,20 @@ class BinaryVector:
234237
235238 __slots__ = ("data" , "dtype" , "padding" )
236239
237- def __init__ (self , data : Sequence [float | int ], dtype : BinaryVectorDtype , padding : int = 0 ):
240+ def __init__ (
241+ self ,
242+ data : Union [Sequence [float | int ], npt .NDArray [np .number ]],
243+ dtype : BinaryVectorDtype ,
244+ padding : int = 0 ,
245+ ):
238246 """
239247 :param data: Sequence of numbers representing the mathematical vector.
240248 :param dtype: The data type stored in binary
241249 :param padding: The number of bits in the final byte that are to be ignored
242250 when a vector element's size is less than a byte
243251 and the length of the vector is not a multiple of 8.
252+ (Padding is equivalent to a negative value of `count` in
253+ `numpy.unpackbits <https://numpy.org/doc/stable/reference/generated/numpy.unpackbits.html>`_)
244254 """
245255 self .data = data
246256 self .dtype = dtype
@@ -424,10 +434,20 @@ def from_vector(
424434 ) -> Binary :
425435 ...
426436
@classmethod
@overload
def from_vector(
    cls: Type[Binary], vector: npt.NDArray[np.number], dtype: BinaryVectorDtype, padding: int = 0
) -> Binary:
    # Typing-only overload: a numpy array passed together with an explicit
    # dtype and an optional padding. The body is never executed.
    ...
427447 @classmethod
428448 def from_vector (
429449 cls : Type [Binary ],
430- vector : Union [BinaryVector , list [int ], list [float ]],
450+ vector : Union [BinaryVector , list [int ], list [float ], npt . NDArray [ np . number ] ],
431451 dtype : Optional [BinaryVectorDtype ] = None ,
432452 padding : Optional [int ] = None ,
433453 ) -> Binary :
@@ -459,34 +479,72 @@ def from_vector(
459479 vector = vector .data # type: ignore
460480
461481 padding = 0 if padding is None else padding
462- if dtype == BinaryVectorDtype .INT8 : # pack ints in [-128, 127] as signed int8
463- format_str = "b"
464- if padding :
465- raise ValueError (f"padding does not apply to { dtype = } " )
466- elif dtype == BinaryVectorDtype .PACKED_BIT : # pack ints in [0, 255] as unsigned uint8
467- format_str = "B"
468- if 0 <= padding > 7 :
469- raise ValueError (f"{ padding = } . It must be in [0,1, ..7]." )
470- if padding and not vector :
471- raise ValueError ("Empty vector with non-zero padding." )
472- elif dtype == BinaryVectorDtype .FLOAT32 : # pack floats as float32
473- format_str = "f"
474- if padding :
475- raise ValueError (f"padding does not apply to { dtype = } " )
476- else :
477- raise NotImplementedError ("%s not yet supported" % dtype )
478-
482+ if not isinstance (dtype , BinaryVectorDtype ):
483+ raise TypeError (
484+ "dtype must be a bson.BinaryVectorDtype of BinaryVectorDType.INT8, PACKED_BIT, FLOAT32"
485+ )
479486 metadata = struct .pack ("<sB" , dtype .value , padding )
480- data = struct .pack (f"<{ len (vector )} { format_str } " , * vector ) # type: ignore
487+
488+ if isinstance (vector , list ):
489+ if dtype == BinaryVectorDtype .INT8 : # pack ints in [-128, 127] as signed int8
490+ format_str = "b"
491+ if padding :
492+ raise ValueError (f"padding does not apply to { dtype = } " )
493+ elif dtype == BinaryVectorDtype .PACKED_BIT : # pack ints in [0, 255] as unsigned uint8
494+ format_str = "B"
495+ if 0 <= padding > 7 :
496+ raise ValueError (f"{ padding = } . It must be in [0,1, ..7]." )
497+ if padding and not vector :
498+ raise ValueError ("Empty vector with non-zero padding." )
499+ elif dtype == BinaryVectorDtype .FLOAT32 : # pack floats as float32
500+ format_str = "f"
501+ if padding :
502+ raise ValueError (f"padding does not apply to { dtype = } " )
503+ else :
504+ raise NotImplementedError ("%s not yet supported" % dtype )
505+ data = struct .pack (f"<{ len (vector )} { format_str } " , * vector )
506+ else : # vector is numpy array or incorrect type.
507+ try :
508+ import numpy as np
509+ except ImportError as exc :
510+ raise ImportError (
511+ "Failed to create binary from vector. Check type. If numpy array, numpy must be installed."
512+ ) from exc
513+ if not isinstance (vector , np .ndarray ):
514+ raise TypeError (
515+ "Could not create Binary. Vector must be a BinaryVector, list[int], list[float] or numpy ndarray."
516+ )
517+ if vector .ndim != 1 :
518+ raise ValueError (
519+ "from_numpy_vector only supports 1D arrays as it creates a single vector."
520+ )
521+
522+ if dtype == BinaryVectorDtype .FLOAT32 :
523+ vector = vector .astype (np .dtype ("float32" ), copy = False )
524+ elif dtype == BinaryVectorDtype .INT8 :
525+ if vector .min () >= - 128 and vector .max () <= 127 :
526+ vector = vector .astype (np .dtype ("int8" ), copy = False )
527+ else :
528+ raise ValueError ("Values found outside INT8 range." )
529+ elif dtype == BinaryVectorDtype .PACKED_BIT :
530+ if vector .min () >= 0 and vector .max () <= 127 :
531+ vector = vector .astype (np .dtype ("uint8" ), copy = False )
532+ else :
533+ raise ValueError ("Values found outside UINT8 range." )
534+ else :
535+ raise NotImplementedError ("%s not yet supported" % dtype )
536+ data = vector .tobytes ()
537+
481538 if padding and len (vector ) and not (data [- 1 ] & ((1 << padding ) - 1 )) == 0 :
482539 raise ValueError (
483540 "Vector has a padding P, but bits in the final byte lower than P are non-zero. They must be zero."
484541 )
485542 return cls (metadata + data , subtype = VECTOR_SUBTYPE )
486543
def as_vector(self, return_numpy: bool = False) -> BinaryVector:
    """From the Binary, create a list or 1-d numpy array of numbers, along with dtype and padding.

    The payload is read as a two-byte header (dtype, padding) followed by the
    packed elements. Elements are always decoded as little-endian, for both
    the list and the numpy representation.

    :param return_numpy: If True, BinaryVector.data will be a one-dimensional
        numpy array. By default, it is a list.
    :return: BinaryVector
    :raises ValueError: If the subtype is not the vector subtype, if padding
        is set for a dtype other than PACKED_BIT, if a FLOAT32 payload is not
        a multiple of 4 bytes, or if a PACKED_BIT payload has padding but no
        data or a padding greater than 7.
    :raises NotImplementedError: If the dtype is not INT8, FLOAT32 or PACKED_BIT.
    :raises ImportError: If return_numpy is True and numpy is not installed.

    .. versionadded:: 4.10
    """
    if self.subtype != VECTOR_SUBTYPE:
        raise ValueError(f"Cannot decode subtype {self.subtype} as a vector")

    # Header: one dtype byte followed by one padding byte.
    raw_dtype, padding = struct.unpack_from("<sB", self)
    dtype = BinaryVectorDtype(raw_dtype)
    offset = 2
    n_bytes = len(self) - offset

    if padding and dtype != BinaryVectorDtype.PACKED_BIT:
        raise ValueError(
            f"Corrupt data. Padding ({padding}) must be 0 for all but PACKED_BIT dtypes. ({dtype=})"
        )

    if return_numpy:
        # numpy is optional; import it only when the caller asks for an array.
        try:
            import numpy as np
        except ImportError as exc:
            raise ImportError(
                "Converting binary to numpy.ndarray requires numpy to be installed."
            ) from exc

    # Validate once, and pick matching struct / numpy element codes so that
    # both output representations go through exactly the same checks.
    if dtype == BinaryVectorDtype.INT8:
        struct_code, numpy_code, item_size = "b", "int8", 1
    elif dtype == BinaryVectorDtype.FLOAT32:
        if n_bytes % 4:
            raise ValueError("Corrupt data. N bytes for a float32 vector must be a multiple of 4.")
        # "<f4" pins little-endian, matching the "<f" struct format; a bare
        # "float32" would be native-endian and mis-decode on big-endian hosts.
        struct_code, numpy_code, item_size = "f", "<f4", 4
    elif dtype == BinaryVectorDtype.PACKED_BIT:
        # data packed as uint8
        if padding and not n_bytes:
            raise ValueError("Corrupt data. Vector has a padding P, but no data.")
        if padding > 7 or padding < 0:
            raise ValueError(f"Corrupt data. Padding ({padding}) must be between 0 and 7.")
        struct_code, numpy_code, item_size = "B", "uint8", 1
    else:
        raise NotImplementedError("Binary Vector dtype %s not yet supported" % dtype.name)

    if return_numpy:
        values = np.frombuffer(self[offset:], dtype=numpy_code)
    else:
        values = list(struct.unpack_from(f"<{n_bytes // item_size}{struct_code}", self, offset))

    # A non-zero padding at this point implies PACKED_BIT with at least one
    # byte, so values[-1] exists. The low `padding` bits of the final byte are
    # the ignored ones; int() makes the mask test identical for list items
    # and numpy scalars.
    if padding and int(values[-1]) & ((1 << padding) - 1):
        warnings.warn(
            "Vector has a padding P, but bits in the final byte lower than P are non-zero. For pymongo>=5.0, they must be zero.",
            DeprecationWarning,
            stacklevel=2,
        )
    return BinaryVector(values, dtype, padding)
546634
547635 @property
548636 def subtype (self ) -> int :
0 commit comments