From b5ff6b796ea8d485410c9a036a954a7acf9300cb Mon Sep 17 00:00:00 2001 From: Maksim Zaikin Date: Wed, 11 Feb 2026 11:00:22 +0000 Subject: [PATCH] Implement quantization for Decimal type when encode --- docs/supported-types.rst | 36 +++++ src/msgspec/_core.c | 276 ++++++++++++++++++++++++++++++++++++-- src/msgspec/json.pyi | 25 ++++ src/msgspec/msgpack.pyi | 26 ++++ tests/unit/test_common.py | 149 ++++++++++++++++++++ 5 files changed, 502 insertions(+), 10 deletions(-) diff --git a/docs/supported-types.rst b/docs/supported-types.rst index 59f0fefd..d6a36b49 100644 --- a/docs/supported-types.rst +++ b/docs/supported-types.rst @@ -602,6 +602,42 @@ numbers by creating a ``Encoder`` and specifying ``decimal_format='number'``. >>> encoder.encode(x) b'1.2345' +You may also optionally specify a target precision for rounding +Decimal values before encoding using the ``decimal_quantize`` parameter. +If provided, all `decimal.Decimal`` values will be quantized (rounded) +to the same scale as this value using the `Decimal.quantize()` method. +This is useful for ensuring consistent precision of decimal values +during serialization, particularly in financial or monetary contexts. + +.. code-block:: python + + >>> import decimal + + >>> encoder = msgspec.json.Encoder(decimal_quantize=decimal.Decimal("0.00")) + + >>> encoder.encode(decimal.Decimal("1.23456789")) + b'"1.23"' + +The optional ``decimal_rounding`` parameter allows you to specify +the rounding mode to use when quantizing Decimal values. It accepts +one of the standard rounding mode strings from Python's decimal module, +such as ``'ROUND_DOWN'`` , ``'ROUND_HALF_UP'`` , etc. If not specified, +the default rounding mode is used (``'ROUND_HALF_EVEN'``). + +.. code-block:: python + + >>> encoder = msgspec.json.Encoder( + ... decimal_quantize=decimal.Decimal("0.00"), + ... decimal_rounding=decimal.ROUND_UP, + ... ) + + >>> encoder.encode(decimal.Decimal("1.235")) # Rounded up to two decimal places + b'"1.24"' + +.. note:: + + This parameter has no effect unless ``decimal_quantize`` is also specified. + This setting is not yet supported for YAML or TOML - if this option is important for you please `open an issue`_. diff --git a/src/msgspec/_core.c b/src/msgspec/_core.c index 320ebca7..1cef87d2 100644 --- a/src/msgspec/_core.c +++ b/src/msgspec/_core.c @@ -9470,12 +9470,26 @@ enum uuid_format { UUID_FORMAT_BYTES = 2, }; +enum decimal_rounding { + ROUND_DEFAULT = 0, + ROUND_DOWN = 1, + ROUND_HALF_UP = 2, + ROUND_HALF_EVEN = 3, + ROUND_CEILING = 4, + ROUND_FLOOR = 5, + ROUND_UP = 6, + ROUND_HALF_DOWN = 7, + ROUND_05UP = 8, +}; + typedef struct EncoderState { MsgspecState *mod; /* module reference */ PyObject *enc_hook; /* `enc_hook` callback */ enum decimal_format decimal_format; enum uuid_format uuid_format; enum order_mode order; + PyObject *decimal_quantize; + enum decimal_rounding decimal_rounding; char* (*resize_buffer)(PyObject**, Py_ssize_t); /* callback for resizing buffer */ char *output_buffer_raw; /* raw pointer to output_buffer internal buffer */ @@ -9491,6 +9505,8 @@ typedef struct Encoder { enum decimal_format decimal_format; enum uuid_format uuid_format; enum order_mode order; + PyObject *decimal_quantize; + enum decimal_rounding decimal_rounding; } Encoder; static PyTypeObject Encoder_Type; @@ -9545,13 +9561,18 @@ ms_write(EncoderState *self, const char *s, Py_ssize_t n) static int Encoder_init(Encoder *self, PyObject *args, PyObject *kwds) { - char *kwlist[] = {"enc_hook", "decimal_format", "uuid_format", "order", NULL}; - PyObject *enc_hook = NULL, *decimal_format = NULL, *uuid_format = NULL, *order = NULL; + char *kwlist[] = { + "enc_hook", "decimal_format", "uuid_format", "order", + "decimal_quantize", "decimal_rounding", NULL, + }; + PyObject *enc_hook = NULL, *decimal_format = NULL, *uuid_format = NULL, *order = NULL, \ + *decimal_quantize = NULL, *decimal_rounding = NULL; if ( !PyArg_ParseTupleAndKeywords( - args, kwds, "|$OOOO", kwlist, - &enc_hook, &decimal_format, &uuid_format, &order + args, kwds, "|$OOOOOO", kwlist, + &enc_hook, &decimal_format, &uuid_format, &order, + &decimal_quantize, &decimal_rounding ) ) { return -1; @@ -9631,7 +9652,74 @@ Encoder_init(Encoder *self, PyObject *args, PyObject *kwds) if (self->order == ORDER_INVALID) return -1; self->mod = msgspec_get_global_state(); + + /* Process decimal quantize */ + if (decimal_quantize == Py_None) { + decimal_quantize = NULL; + } + if (decimal_quantize != NULL) { + PyTypeObject *type = Py_TYPE(decimal_quantize); + if (type != (PyTypeObject *)(self->mod->DecimalType)) { + PyErr_SetString(PyExc_TypeError, "`decimal_quantize` must be a Decimal"); + return -1; + } + Py_XDECREF(type); + Py_INCREF(decimal_quantize); + } + + /* Process decimal rounding */ + if (decimal_rounding == NULL || decimal_rounding == Py_None) { + self->decimal_rounding = ROUND_DEFAULT; + } + else { + bool ok = false; + if (PyUnicode_CheckExact(decimal_rounding)) { + if (PyUnicode_CompareWithASCIIString(decimal_rounding, "ROUND_DOWN") == 0) { + self->decimal_rounding = ROUND_DOWN; + ok = true; + } + else if (PyUnicode_CompareWithASCIIString(decimal_rounding, "ROUND_HALF_UP") == 0) { + self->decimal_rounding = ROUND_HALF_UP; + ok = true; + } + else if (PyUnicode_CompareWithASCIIString(decimal_rounding, "ROUND_HALF_EVEN") == 0) { + self->decimal_rounding = ROUND_HALF_EVEN; + ok = true; + } + else if (PyUnicode_CompareWithASCIIString(decimal_rounding, "ROUND_CEILING") == 0) { + self->decimal_rounding = ROUND_CEILING; + ok = true; + } + else if (PyUnicode_CompareWithASCIIString(decimal_rounding, "ROUND_FLOOR") == 0) { + self->decimal_rounding = ROUND_FLOOR; + ok = true; + } + else if (PyUnicode_CompareWithASCIIString(decimal_rounding, "ROUND_UP") == 0) { + self->decimal_rounding = ROUND_UP; + ok = true; + } + else if (PyUnicode_CompareWithASCIIString(decimal_rounding, "ROUND_HALF_DOWN") == 0) { + self->decimal_rounding = ROUND_HALF_DOWN; + ok = true; + } + else if (PyUnicode_CompareWithASCIIString(decimal_rounding, "ROUND_05UP") == 0) { + self->decimal_rounding = ROUND_05UP; + ok = true; + } + } + if (!ok) { + PyErr_Format( + PyExc_ValueError, + "`decimal_rounding` must be 'ROUND_DOWN', 'ROUND_HALF_UP', 'ROUND_HALF_EVEN', 'ROUND_CEILING', \ + 'ROUND_FLOOR', 'ROUND_UP', 'ROUND_HALF_DOWN', 'ROUND_05UP', got %R", + decimal_rounding + ); + return -1; + } + } + self->enc_hook = enc_hook; + self->decimal_quantize = decimal_quantize; return 0; } @@ -9639,6 +9727,7 @@ static int Encoder_traverse(Encoder *self, visitproc visit, void *arg) { Py_VISIT(self->enc_hook); + Py_VISIT(self->decimal_quantize); return 0; } @@ -9646,6 +9735,7 @@ static int Encoder_clear(Encoder *self) { Py_CLEAR(self->enc_hook); + Py_CLEAR(self->decimal_quantize); return 0; } @@ -9722,6 +9812,8 @@ encoder_encode_into_common( .decimal_format = self->decimal_format, .uuid_format = self->uuid_format, .order = self->order, + .decimal_quantize = self->decimal_quantize, + .decimal_rounding = self->decimal_rounding, .output_buffer = buf, .output_buffer_raw = PyByteArray_AS_STRING(buf), .output_len = offset, @@ -9769,6 +9861,8 @@ encoder_encode_common( .decimal_format = self->decimal_format, .uuid_format = self->uuid_format, .order = self->order, + .decimal_quantize = self->decimal_quantize, + .decimal_rounding = self->decimal_rounding, .output_len = 0, .max_output_len = ENC_INIT_BUFSIZE, .resize_buffer = &ms_resize_bytes @@ -9826,6 +9920,8 @@ encode_common( .decimal_format = DECIMAL_FORMAT_STRING, .uuid_format = UUID_FORMAT_CANONICAL, .output_len = 0, + .decimal_quantize = NULL, + .decimal_rounding = ROUND_DEFAULT, .max_output_len = ENC_INIT_BUFSIZE, .resize_buffer = &ms_resize_bytes }; @@ -9884,8 +9980,52 @@ Encoder_order(Encoder *self, void *closure) { } } +static PyObject* +Encoder_decimal_rounding(Encoder *self, void *closure) { + if (self->decimal_rounding == ROUND_DOWN) { + return PyUnicode_InternFromString("ROUND_DOWN"); + } + else if (self->decimal_rounding == ROUND_HALF_UP) { + return PyUnicode_InternFromString("ROUND_HALF_UP"); + } + else if (self->decimal_rounding == ROUND_HALF_EVEN) { + return PyUnicode_InternFromString("ROUND_HALF_EVEN"); + } + else if (self->decimal_rounding == ROUND_CEILING) { + return PyUnicode_InternFromString("ROUND_CEILING"); + } + else if (self->decimal_rounding == ROUND_FLOOR) { + return PyUnicode_InternFromString("ROUND_FLOOR"); + } + else if (self->decimal_rounding == ROUND_UP) { + return PyUnicode_InternFromString("ROUND_UP"); + } + else if (self->decimal_rounding == ROUND_HALF_DOWN) { + return PyUnicode_InternFromString("ROUND_HALF_DOWN"); + } + else if (self->decimal_rounding == ROUND_05UP) { + return PyUnicode_InternFromString("ROUND_05UP"); + } + else { + Py_RETURN_NONE; + } +} + +static PyObject* +Encoder_decimal_quantize(Encoder *self, void *closure) { + if (self->decimal_quantize != NULL) { + Py_INCREF(self->decimal_quantize); + return self->decimal_quantize; + } + else { + Py_RETURN_NONE; + } +} + static PyGetSetDef Encoder_getset[] = { {"decimal_format", (getter) Encoder_decimal_format, NULL, NULL, NULL}, + {"decimal_rounding", (getter) Encoder_decimal_rounding, NULL, NULL, NULL}, + {"decimal_quantize", (getter) Encoder_decimal_quantize, NULL, NULL, NULL}, {"uuid_format", (getter) Encoder_uuid_format, NULL, NULL, NULL}, {"order", (getter) Encoder_order, NULL, NULL, NULL}, {NULL}, @@ -11717,6 +11857,76 @@ ms_decode_uuid_from_bytes(const char *buf, Py_ssize_t size, PathNode *path) { * Decimal Utilities * *************************************************************************/ + +static PyObject* +quantize_decimal_obj(EncoderState *self, PyObject *obj) +{ + if (self->decimal_quantize == NULL) { + return NULL; + } + + PyObject *quantize_method = PyObject_GetAttrString(obj, "quantize"); + if (quantize_method == NULL) { + return NULL; + } + + // Create rounding constant as string object + PyObject *rounding_const_obj = NULL; + switch (self->decimal_rounding) { + case ROUND_DEFAULT: + rounding_const_obj = Py_None; + break; + case ROUND_DOWN: + rounding_const_obj = PyUnicode_InternFromString("ROUND_DOWN"); + break; + case ROUND_HALF_UP: + rounding_const_obj = PyUnicode_InternFromString("ROUND_HALF_UP"); + break; + case ROUND_HALF_EVEN: + rounding_const_obj = PyUnicode_InternFromString("ROUND_HALF_EVEN"); + break; + case ROUND_CEILING: + rounding_const_obj = PyUnicode_InternFromString("ROUND_CEILING"); + break; + case ROUND_FLOOR: + rounding_const_obj = PyUnicode_InternFromString("ROUND_FLOOR"); + break; + case ROUND_UP: + rounding_const_obj = PyUnicode_InternFromString("ROUND_UP"); + break; + case ROUND_HALF_DOWN: + rounding_const_obj = PyUnicode_InternFromString("ROUND_HALF_DOWN"); + break; + case ROUND_05UP: + rounding_const_obj = PyUnicode_InternFromString("ROUND_05UP"); + break; + default: + rounding_const_obj = NULL; + break; + } + + if (rounding_const_obj == NULL) { + Py_DECREF(quantize_method); + return NULL; + } + + // Call the quantize method with decimal_quantize as Decimal object + PyObject *quantized_obj = PyObject_CallFunctionObjArgs( + quantize_method, + self->decimal_quantize, + rounding_const_obj, + NULL + ); + + Py_DECREF(quantize_method); + if (rounding_const_obj != Py_None) { + Py_DECREF(rounding_const_obj); + } + + return quantized_obj; +} + + static PyObject * ms_decode_decimal_from_pyobj(PyObject *str, PathNode *path, MsgspecState *mod) { if (mod == NULL) { @@ -12414,7 +12624,8 @@ maybe_parse_number( *************************************************************************/ PyDoc_STRVAR(Encoder__doc__, -"Encoder(*, enc_hook=None, decimal_format='string', uuid_format='canonical', order=None)\n" +"Encoder(*, enc_hook=None, decimal_format='string', uuid_format='canonical', order=None,\n" +" decimal_quantize=None, decimal_rounding=None)\n" "--\n" "\n" "A MessagePack encoder.\n" @@ -12446,7 +12657,20 @@ PyDoc_STRVAR(Encoder__doc__, " of the encoded binary output is necessary.\n" " - `'sorted'`: Like `'deterministic'`, but *all* object-like types (structs,\n" " dataclasses, ...) are also sorted by field name before encoding. This is\n" -" slower than `'deterministic'`, but may produce more human-readable output." +" slower than `'deterministic'`, but may produce more human-readable output.\n" +"\n" +"decimal_quantize : decimal.Decimal, optional\n" +" An optional Decimal value specifying the quantize target for rounding\n" +" Decimal numbers before encoding. If provided, all Decimal values will be\n" +" quantized (rounded) to the same precision, scale, and exponent as this value.\n" +" This is useful for ensuring consistent precision of decimal numbers during\n" +" serialization.\n" +"decimal_rounding : {None, 'ROUND_DOWN', 'ROUND_HALF_UP', 'ROUND_HALF_EVEN',\n" +" 'ROUND_CEILING', 'ROUND_FLOOR', 'ROUND_UP', 'ROUND_HALF_DOWN', 'ROUND_05UP'}, optional\n" +" The rounding mode to use when quantizing Decimal values. Must be one of the\n" +" standard decimal module rounding modes. If not specified, the default\n" +" rounding mode is used (ROUND_HALF_EVEN). This parameter has no\n" +" effect unless `decimal_quantize` is also specified.\n" ); enum mpack_code { @@ -13303,16 +13527,23 @@ mpack_encode_uuid(EncoderState *self, PyObject *obj) static int mpack_encode_decimal(EncoderState *self, PyObject *obj) { - PyObject *temp; + PyObject *temp = NULL, *quantized_obj = NULL; int out; + if (self->decimal_quantize != NULL) { + quantized_obj = quantize_decimal_obj(self, obj); + if (quantized_obj == NULL) return -1; + } + if (MS_LIKELY(self->decimal_format == DECIMAL_FORMAT_STRING)) { - temp = PyObject_Str(obj); + temp = MS_LIKELY(quantized_obj == NULL) ? PyObject_Str(obj) : PyObject_Str(quantized_obj); + Py_XDECREF(quantized_obj); if (temp == NULL) return -1; out = mpack_encode_str(self, temp); } else { - temp = PyNumber_Float(obj); + temp = MS_LIKELY(quantized_obj == NULL) ? PyNumber_Float(obj) : PyNumber_Float(quantized_obj); + Py_XDECREF(quantized_obj); if (temp == NULL) return -1; out = mpack_encode_float(self, temp); } @@ -13660,6 +13891,19 @@ PyDoc_STRVAR(JSONEncoder__doc__, " - `'sorted'`: Like `'deterministic'`, but *all* object-like types (structs,\n" " dataclasses, ...) are also sorted by field name before encoding. This is\n" " slower than `'deterministic'`, but may produce more human-readable output." +"\n" +"decimal_quantize : decimal.Decimal, optional\n" +" An optional Decimal value specifying the quantize target for rounding\n" +" Decimal numbers before encoding. If provided, all Decimal values will be\n" +" quantized (rounded) to the same precision, scale, and exponent as this value.\n" +" This is useful for ensuring consistent precision of decimal numbers during\n" +" serialization.\n" +"decimal_rounding : {None, 'ROUND_DOWN', 'ROUND_HALF_UP', 'ROUND_HALF_EVEN',\n" +" 'ROUND_CEILING', 'ROUND_FLOOR', 'ROUND_UP', 'ROUND_HALF_DOWN', 'ROUND_05UP'}, optional\n" +" The rounding mode to use when quantizing Decimal values. Must be one of the\n" +" standard decimal module rounding modes. If not specified, the default\n" +" rounding mode is used (ROUND_HALF_EVEN). This parameter has no\n" +" effect unless `decimal_quantize` is also specified.\n" ); static int json_encode_inline(EncoderState*, PyObject*); @@ -13951,7 +14195,17 @@ json_encode_uuid(EncoderState *self, PyObject *obj) static int json_encode_decimal(EncoderState *self, PyObject *obj) { - PyObject *temp = PyObject_Str(obj); + PyObject *temp = NULL; + + if (MS_LIKELY(self->decimal_quantize == NULL)) { + temp = PyObject_Str(obj); + } else { + PyObject *quantized_obj = quantize_decimal_obj(self, obj); + if (quantized_obj == NULL) return -1; + temp = PyObject_Str(quantized_obj); + Py_DECREF(quantized_obj); + } + if (temp == NULL) return -1; Py_ssize_t size; @@ -14669,6 +14923,8 @@ JSONEncoder_encode_lines(Encoder *self, PyObject *const *args, Py_ssize_t nargs) .decimal_format = self->decimal_format, .uuid_format = self->uuid_format, .order = self->order, + .decimal_quantize = self->decimal_quantize, + .decimal_rounding = self->decimal_rounding, .output_len = 0, .max_output_len = ENC_LINES_INIT_BUFSIZE, .resize_buffer = &ms_resize_bytes diff --git a/src/msgspec/json.pyi b/src/msgspec/json.pyi index dc9172ed..af139d6f 100644 --- a/src/msgspec/json.pyi +++ b/src/msgspec/json.pyi @@ -1,3 +1,4 @@ +import decimal from collections.abc import Callable, Iterable from typing import ( Any, @@ -26,6 +27,18 @@ class Encoder: decimal_format: Literal["string", "number"] uuid_format: Literal["canonical", "hex"] order: Literal[None, "deterministic", "sorted"] + decimal_quantize: Optional[decimal.Decimal] + decimal_rounding: Literal[ + None, + "ROUND_DOWN", + "ROUND_HALF_UP", + "ROUND_HALF_EVEN", + "ROUND_CEILING", + "ROUND_FLOOR", + "ROUND_UP", + "ROUND_HALF_DOWN", + "ROUND_05UP", + ] def __init__( self, @@ -34,6 +47,18 @@ class Encoder: decimal_format: Literal["string", "number"] = "string", uuid_format: Literal["canonical", "hex"] = "canonical", order: Literal[None, "deterministic", "sorted"] = None, + decimal_quantize: Optional[decimal.Decimal] = None, + decimal_rounding: Literal[ + None, + "ROUND_DOWN", + "ROUND_HALF_UP", + "ROUND_HALF_EVEN", + "ROUND_CEILING", + "ROUND_FLOOR", + "ROUND_UP", + "ROUND_HALF_DOWN", + "ROUND_05UP", + ] = None, ): ... def encode(self, obj: Any, /) -> bytes: ... def encode_lines(self, items: Iterable, /) -> bytes: ... diff --git a/src/msgspec/msgpack.pyi b/src/msgspec/msgpack.pyi index 4eb5e38d..b9b3c000 100644 --- a/src/msgspec/msgpack.pyi +++ b/src/msgspec/msgpack.pyi @@ -1,3 +1,4 @@ +import decimal from typing import ( Any, Callable, @@ -63,6 +64,19 @@ class Encoder: decimal_format: Literal["string", "number"] uuid_format: Literal["canonical", "hex", "bytes"] order: Literal[None, "deterministic", "sorted"] + decimal_quantize: Optional[decimal.Decimal] + decimal_rounding: Literal[ + None, + "ROUND_DOWN", + "ROUND_HALF_UP", + "ROUND_HALF_EVEN", + "ROUND_CEILING", + "ROUND_FLOOR", + "ROUND_UP", + "ROUND_HALF_DOWN", + "ROUND_05UP", + ] + def __init__( self, *, @@ -70,6 +84,18 @@ class Encoder: decimal_format: Literal["string", "number"] = "string", uuid_format: Literal["canonical", "hex", "bytes"] = "canonical", order: Literal[None, "deterministic", "sorted"] = None, + decimal_quantize: Optional[decimal.Decimal] = None, + decimal_rounding: Literal[ + None, + "ROUND_DOWN", + "ROUND_HALF_UP", + "ROUND_HALF_EVEN", + "ROUND_CEILING", + "ROUND_FLOOR", + "ROUND_UP", + "ROUND_HALF_DOWN", + "ROUND_05UP", + ] = None, ): ... def encode(self, obj: Any, /) -> bytes: ... def encode_into( diff --git a/tests/unit/test_common.py b/tests/unit/test_common.py index 956c0d3a..890b0e3f 100644 --- a/tests/unit/test_common.py +++ b/tests/unit/test_common.py @@ -3934,6 +3934,155 @@ def test_encode_decimal(self, proto): s = str(d) assert proto.encode(d) == proto.encode(s) + def test_encoder_decimal_quantize(self, proto): + assert proto.Encoder().decimal_quantize is None + assert proto.Encoder( + decimal_quantize=decimal.Decimal("1.00"), + ).decimal_quantize == decimal.Decimal("1.00") + + def test_encoder_decimal_rounding(self, proto): + assert proto.Encoder().decimal_rounding is None + assert ( + proto.Encoder(decimal_rounding=decimal.ROUND_UP).decimal_rounding + == decimal.ROUND_UP + ) + + @pytest.mark.parametrize( + ("decimal_quantize", "expected"), + [ + (decimal.Decimal("1.00000"), "1.34490"), + (decimal.Decimal("1.0000"), "1.3449"), + (decimal.Decimal("1.000"), "1.345"), + (decimal.Decimal("1.00"), "1.34"), + (decimal.Decimal("-1.0"), "1.3"), + (decimal.Decimal("1.0"), "1.3"), + (decimal.Decimal("1"), "1"), + (decimal.Decimal("0"), "1"), + ], + ) + def test_Encoder_encode_decimal_quantize( + self, + proto, + decimal_quantize: decimal.Decimal, + expected: str, + ): + enc = proto.Encoder(decimal_quantize=decimal_quantize) + msg = enc.encode(decimal.Decimal("1.3449")) + assert msg == enc.encode(expected) + + @pytest.mark.parametrize( + ("decimal_quantize", "expected"), + [ + (decimal.Decimal("1.0000"), "1.3449"), + (decimal.Decimal("1.000"), "1.345"), + (decimal.Decimal("1.00"), "1.34"), + (decimal.Decimal("1.0"), "1.3"), + (decimal.Decimal("-1.0"), "1.3"), + ], + ) + def test_Encoder_encode_decimal_quantize_to_number( + self, + proto, + decimal_quantize: decimal.Decimal, + expected: str, + ): + enc = proto.Encoder( + decimal_format="number", + decimal_quantize=decimal_quantize, + ) + msg = enc.encode(decimal.Decimal("1.3449")) + assert msg == enc.encode(float(expected)) + + @pytest.mark.parametrize( + ("rounding", "value", "expected"), + [ + pytest.param( + None, + decimal.Decimal("0.446"), + "0.45", + id="ROUND_HALF_EVEN by default", + ), + pytest.param( + None, + decimal.Decimal("-0.445"), + "-0.44", + id="ROUND_HALF_EVEN by default", + ), + # ROUND_DOWN + (decimal.ROUND_DOWN, decimal.Decimal("3.706"), "3.70"), + (decimal.ROUND_DOWN, decimal.Decimal("2.206"), "2.20"), + (decimal.ROUND_DOWN, decimal.Decimal("-1.504"), "-1.50"), + # ROUND_UP + (decimal.ROUND_UP, decimal.Decimal("2.101"), "2.11"), + (decimal.ROUND_UP, decimal.Decimal("2.501"), "2.51"), + (decimal.ROUND_UP, decimal.Decimal("-2.101"), "-2.11"), + (decimal.ROUND_UP, decimal.Decimal("-2.501"), "-2.51"), + # ROUND_HALF_UP + (decimal.ROUND_HALF_UP, decimal.Decimal("0.444"), "0.44"), + (decimal.ROUND_HALF_UP, decimal.Decimal("0.445"), "0.45"), + (decimal.ROUND_HALF_UP, decimal.Decimal("-0.554"), "-0.55"), + (decimal.ROUND_HALF_UP, decimal.Decimal("-0.555"), "-0.56"), + # ROUND_HALF_EVEN + (decimal.ROUND_HALF_EVEN, decimal.Decimal("-0.445"), "-0.44"), + (decimal.ROUND_HALF_EVEN, decimal.Decimal("0.446"), "0.45"), + (decimal.ROUND_HALF_EVEN, decimal.Decimal("0.555"), "0.56"), + # ROUND_CEILING + (decimal.ROUND_CEILING, decimal.Decimal("0.440000"), "0.44"), + (decimal.ROUND_CEILING, decimal.Decimal("0.440001"), "0.45"), + (decimal.ROUND_CEILING, decimal.Decimal("0.5500001"), "0.56"), + (decimal.ROUND_CEILING, decimal.Decimal("-0.5500001"), "-0.55"), + # ROUND_FLOOR + (decimal.ROUND_FLOOR, decimal.Decimal("0.4488"), "0.44"), + (decimal.ROUND_FLOOR, decimal.Decimal("0.445"), "0.44"), + (decimal.ROUND_FLOOR, decimal.Decimal("0.5595"), "0.55"), + (decimal.ROUND_FLOOR, decimal.Decimal("-0.5595"), "-0.56"), + # ROUND_HALF_DOWN + (decimal.ROUND_HALF_DOWN, decimal.Decimal("0.445"), "0.44"), + (decimal.ROUND_HALF_DOWN, decimal.Decimal("0.446"), "0.45"), + (decimal.ROUND_HALF_DOWN, decimal.Decimal("0.555"), "0.55"), + (decimal.ROUND_HALF_DOWN, decimal.Decimal("0.556"), "0.56"), + # ROUND_05UP + (decimal.ROUND_05UP, decimal.Decimal("10.00001"), "10.01"), + (decimal.ROUND_05UP, decimal.Decimal("10.050101"), "10.06"), + (decimal.ROUND_05UP, decimal.Decimal("10.018"), "10.01"), + (decimal.ROUND_05UP, decimal.Decimal("10.068"), "10.06"), + ], + ) + def test_Encoder_encode_decimal_rounding( + self, + proto, + rounding: str | None, + value: decimal.Decimal, + expected: str, + ): + enc = msgspec.msgpack.Encoder( + decimal_quantize=decimal.Decimal("1.00"), + decimal_rounding=rounding, + ) + msg = enc.encode(value) + assert msg == enc.encode(expected) + + def test_encoder_invalid_decimal_rounding( + self, + proto, + ): + with pytest.raises( + ValueError, + match="`decimal_rounding` must be 'ROUND_DOWN'", + ): + proto.Encoder( + decimal_quantize=decimal.Decimal("1"), + decimal_rounding="DIFFERENT_VALUE", + ) + + def test_encoder_decimal_rounding_has_no_effect( + self, + proto, + ): + enc = msgspec.msgpack.Encoder(decimal_rounding=decimal.ROUND_UP) + msg = enc.encode(decimal.Decimal("1.99")) + assert msg == enc.encode("1.99") + @pytest.mark.parametrize( "val", ["1.5", "InF", "-iNf", "iNfInItY", "-InFiNiTy", "NaN"] )