diff --git a/msgpack/_packer.pyx b/msgpack/_packer.pyx index 94d1462c..75b3f7f5 100644 --- a/msgpack/_packer.pyx +++ b/msgpack/_packer.pyx @@ -291,6 +291,10 @@ cdef class Packer: raise ValueError("ext data too large") msgpack_pack_ext(&self.pk, typecode, len(data)) msgpack_pack_raw_body(&self.pk, data, len(data)) + if self.autoreset: + buf = PyBytes_FromStringAndSize(self.pk.buf, self.pk.length) + self.pk.length = 0 + return buf @cython.critical_section def pack_array_header(self, long long size): diff --git a/msgpack/fallback.py b/msgpack/fallback.py index b02e47cf..6b54ffbf 100644 --- a/msgpack/fallback.py +++ b/msgpack/fallback.py @@ -861,6 +861,10 @@ def pack_ext_type(self, typecode, data): self._buffer.write(b"\xc9" + struct.pack(">I", L)) self._buffer.write(struct.pack("B", typecode)) self._buffer.write(data) + if self._autoreset: + ret = self._buffer.getvalue() + self._buffer = BytesIO() + return ret def _pack_array_header(self, n): if n <= 0x0F: diff --git a/pyproject.toml b/pyproject.toml index c69d5a7c..4e4510b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] +dependencies = [] [project.urls] Homepage = "https://msgpack.org/" @@ -43,3 +44,9 @@ lint.select = [ "I", # isort #"UP", pyupgrade ] + +[dependency-groups] +dev = [ + "cython>=3.2.5", + "pytest>=9.0.3", +] diff --git a/test/test_extension.py b/test/test_extension.py index aaf0fd92..61852f15 100644 --- a/test/test_extension.py +++ b/test/test_extension.py @@ -6,7 +6,7 @@ def test_pack_ext_type(): def p(s): - packer = msgpack.Packer() + packer = msgpack.Packer(autoreset=False) packer.pack_ext_type(0x42, s) return packer.bytes() @@ -20,6 +20,15 @@ def p(s): assert p(b"A" * 0x00012345) == b"\xc9\x00\x01\x23\x45\x42" + b"A" * 0x00012345 # ext 32 +def test_pack_ext_type_autoreset(): + packer = msgpack.Packer() + + assert packer.pack_ext_type(0x42, b"A") == b"\xd4\x42A" + assert packer.bytes() == b"" + assert packer.pack_ext_type(0x42, b"ABC") == b"\xc7\x03\x42ABC" + assert packer.bytes() == b"" + + def test_unpack_ext_type(): def check(b, expected): assert msgpack.unpackb(b) == expected