PYTHON-4204 Optimize JSON decoding using lookup table to find $ keys #1512

Merged 7 commits on Feb 7, 2024
.evergreen/run-tests.sh (3 additions & 1 deletion)
@@ -246,7 +246,9 @@ fi
 PIP_QUIET=0 python -m pip list

 if [ -z "$GREEN_FRAMEWORK" ]; then
-    python -m pytest -v --durations=5 --maxfail=10 $TEST_ARGS
+    # Use --capture=tee-sys so pytest prints test output inline:
+    # https://docs.pytest.org/en/stable/how-to/capture-stdout-stderr.html
+    python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 $TEST_ARGS
@ShaneHarvey (Member, Author) commented on Feb 6, 2024:
This lets us see the print()s in the pytest output:

 [2024/02/06 14:45:55.785] + python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 test/performance/perf_test.py
 [2024/02/06 14:45:56.173] ============================= test session starts ==============================
 [2024/02/06 14:45:56.173] platform linux -- Python 3.10.4, pytest-8.0.0, pluggy-1.4.0 -- /data/mci/6a62d81f9a459de6454c8a15a2244eb2/src/.tox/test-eg/bin/python
 [2024/02/06 14:45:56.173] cachedir: .tox/test-eg/.pytest_cache
 [2024/02/06 14:45:56.173] rootdir: /data/mci/6a62d81f9a459de6454c8a15a2244eb2/src
 [2024/02/06 14:45:56.173] configfile: pyproject.toml
 [2024/02/06 14:45:56.335] collecting ... collected 25 items
 [2024/02/06 14:45:56.442] test/performance/perf_test.py::TestFlatEncoding::runTest Completed TestFlatEncoding 330.739 MB/s, MEDIAN=0.018s, total time=0.183s
 [2024/02/06 14:45:56.666] PASSED          [  4%]
 [2024/02/06 14:45:56.666] test/performance/perf_test.py::TestFlatDecoding::runTest Completed TestFlatDecoding 278.886 MB/s, MEDIAN=0.022s, total time=0.217s
 [2024/02/06 14:45:56.833] PASSED          [  8%]

 else
     python green_framework_test.py $GREEN_FRAMEWORK -v $TEST_ARGS
 fi
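For context: by default pytest captures output at the file-descriptor level (--capture=fd), so print() output from a passing test only shows up in the "Captured stdout" section of failure reports. --capture=tee-sys captures via sys.stdout/sys.stderr and also passes the output through, which is what makes the throughput lines appear inline next to each PASSED marker. A minimal sketch (hypothetical file name, not part of this PR):

    # test_capture_demo.py -- hypothetical demo of pytest capture modes.
    def test_reports_throughput() -> None:
        # Under the default --capture=fd this print() is swallowed for a
        # passing test; under --capture=tee-sys it is echoed immediately
        # while still being captured for failure reports.
        print("Completed FlatEncoding 330.739 MB/s")
        assert 330.739 > 0

Running python -m pytest -v --capture=tee-sys test_capture_demo.py prints the throughput line inline, matching the Evergreen log excerpt in the comment above.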
bson/json_util.py (62 additions & 58 deletions)
@@ -526,54 +526,17 @@ def object_pairs_hook(


 def object_hook(dct: Mapping[str, Any], json_options: JSONOptions = DEFAULT_JSON_OPTIONS) -> Any:
-    if "$oid" in dct:
-        return _parse_canonical_oid(dct)
-    if (
-        isinstance(dct.get("$ref"), str)
-        and "$id" in dct
-        and isinstance(dct.get("$db"), (str, type(None)))
-    ):
-        return _parse_canonical_dbref(dct)
-    if "$date" in dct:
-        return _parse_canonical_datetime(dct, json_options)
-    if "$regex" in dct:
-        return _parse_legacy_regex(dct)
-    if "$minKey" in dct:
-        return _parse_canonical_minkey(dct)
-    if "$maxKey" in dct:
-        return _parse_canonical_maxkey(dct)
-    if "$binary" in dct:
-        if "$type" in dct:
-            return _parse_legacy_binary(dct, json_options)
-        else:
-            return _parse_canonical_binary(dct, json_options)
-    if "$code" in dct:
-        return _parse_canonical_code(dct)
-    if "$uuid" in dct:
-        return _parse_legacy_uuid(dct, json_options)
-    if "$undefined" in dct:
-        return None
-    if "$numberLong" in dct:
-        return _parse_canonical_int64(dct)
-    if "$timestamp" in dct:
-        tsp = dct["$timestamp"]
-        return Timestamp(tsp["t"], tsp["i"])
-    if "$numberDecimal" in dct:
-        return _parse_canonical_decimal128(dct)
-    if "$dbPointer" in dct:
-        return _parse_canonical_dbpointer(dct)
-    if "$regularExpression" in dct:
-        return _parse_canonical_regex(dct)
-    if "$symbol" in dct:
-        return _parse_canonical_symbol(dct)
-    if "$numberInt" in dct:
-        return _parse_canonical_int32(dct)
-    if "$numberDouble" in dct:
-        return _parse_canonical_double(dct)
+    match = None
+    for k in dct:
+        if k in _PARSERS_SET:
+            match = k
+            break
+    if match:
+        return _PARSERS[match](dct, json_options)
     return dct


-def _parse_legacy_regex(doc: Any) -> Any:
+def _parse_legacy_regex(doc: Any, dummy0: Any) -> Any:
     pattern = doc["$regex"]
     # Check if this is the $regex query operator.
     if not isinstance(pattern, (str, bytes)):
@@ -709,30 +672,30 @@ def _parse_canonical_datetime(
     return _millis_to_datetime(int(dtm), cast("CodecOptions[Any]", json_options))


-def _parse_canonical_oid(doc: Any) -> ObjectId:
+def _parse_canonical_oid(doc: Any, dummy0: Any) -> ObjectId:
     """Decode a JSON ObjectId to bson.objectid.ObjectId."""
     if len(doc) != 1:
         raise TypeError(f"Bad $oid, extra field(s): {doc}")
     return ObjectId(doc["$oid"])


-def _parse_canonical_symbol(doc: Any) -> str:
+def _parse_canonical_symbol(doc: Any, dummy0: Any) -> str:
     """Decode a JSON symbol to Python string."""
     symbol = doc["$symbol"]
     if len(doc) != 1:
         raise TypeError(f"Bad $symbol, extra field(s): {doc}")
     return str(symbol)


-def _parse_canonical_code(doc: Any) -> Code:
+def _parse_canonical_code(doc: Any, dummy0: Any) -> Code:
     """Decode a JSON code to bson.code.Code."""
     for key in doc:
         if key not in ("$code", "$scope"):
             raise TypeError(f"Bad $code, extra field(s): {doc}")
     return Code(doc["$code"], scope=doc.get("$scope"))


-def _parse_canonical_regex(doc: Any) -> Regex[str]:
+def _parse_canonical_regex(doc: Any, dummy0: Any) -> Regex[str]:
     """Decode a JSON regex to bson.regex.Regex."""
     regex = doc["$regularExpression"]
     if len(doc) != 1:
@@ -749,12 +712,18 @@ def _parse_canonical_regex(doc: Any) -> Regex[str]:
     return Regex(regex["pattern"], opts)


-def _parse_canonical_dbref(doc: Any) -> DBRef:
+def _parse_canonical_dbref(doc: Any, dummy0: Any) -> Any:
     """Decode a JSON DBRef to bson.dbref.DBRef."""
-    return DBRef(doc.pop("$ref"), doc.pop("$id"), database=doc.pop("$db", None), **doc)
+    if (
+        isinstance(doc.get("$ref"), str)
+        and "$id" in doc
+        and isinstance(doc.get("$db"), (str, type(None)))
+    ):
+        return DBRef(doc.pop("$ref"), doc.pop("$id"), database=doc.pop("$db", None), **doc)
+    return doc


-def _parse_canonical_dbpointer(doc: Any) -> Any:
+def _parse_canonical_dbpointer(doc: Any, dummy0: Any) -> Any:
     """Decode a JSON (deprecated) DBPointer to bson.dbref.DBRef."""
     dbref = doc["$dbPointer"]
     if len(doc) != 1:
@@ -773,7 +742,7 @@ def _parse_canonical_dbpointer(doc: Any) -> Any:
raise TypeError(f"Bad $dbPointer, expected a DBRef: {doc}")


-def _parse_canonical_int32(doc: Any) -> int:
+def _parse_canonical_int32(doc: Any, dummy0: Any) -> int:
     """Decode a JSON int32 to python int."""
     i_str = doc["$numberInt"]
     if len(doc) != 1:
@@ -783,15 +752,15 @@ def _parse_canonical_int32(doc: Any) -> int:
     return int(i_str)


-def _parse_canonical_int64(doc: Any) -> Int64:
+def _parse_canonical_int64(doc: Any, dummy0: Any) -> Int64:
     """Decode a JSON int64 to bson.int64.Int64."""
     l_str = doc["$numberLong"]
     if len(doc) != 1:
         raise TypeError(f"Bad $numberLong, extra field(s): {doc}")
     return Int64(l_str)


-def _parse_canonical_double(doc: Any) -> float:
+def _parse_canonical_double(doc: Any, dummy0: Any) -> float:
     """Decode a JSON double to python float."""
     d_str = doc["$numberDouble"]
     if len(doc) != 1:
@@ -801,7 +770,7 @@ def _parse_canonical_double(doc: Any) -> float:
     return float(d_str)


-def _parse_canonical_decimal128(doc: Any) -> Decimal128:
+def _parse_canonical_decimal128(doc: Any, dummy0: Any) -> Decimal128:
     """Decode a JSON decimal128 to bson.decimal128.Decimal128."""
     d_str = doc["$numberDecimal"]
     if len(doc) != 1:
@@ -811,7 +780,7 @@ def _parse_canonical_decimal128(doc: Any) -> Decimal128:
     return Decimal128(d_str)


-def _parse_canonical_minkey(doc: Any) -> MinKey:
+def _parse_canonical_minkey(doc: Any, dummy0: Any) -> MinKey:
     """Decode a JSON MinKey to bson.min_key.MinKey."""
     if type(doc["$minKey"]) is not int or doc["$minKey"] != 1:  # noqa: E721
         raise TypeError(f"$minKey value must be 1: {doc}")
@@ -820,7 +789,7 @@ def _parse_canonical_minkey(doc: Any) -> MinKey:
     return MinKey()


-def _parse_canonical_maxkey(doc: Any) -> MaxKey:
+def _parse_canonical_maxkey(doc: Any, dummy0: Any) -> MaxKey:
     """Decode a JSON MaxKey to bson.max_key.MaxKey."""
     if type(doc["$maxKey"]) is not int or doc["$maxKey"] != 1:  # noqa: E721
         raise TypeError("$maxKey value must be 1: %s", (doc,))
@@ -829,6 +798,41 @@ def _parse_canonical_maxkey(doc: Any) -> MaxKey:
     return MaxKey()


+def _parse_binary(doc: Any, json_options: JSONOptions) -> Union[Binary, uuid.UUID]:
+    if "$type" in doc:
+        return _parse_legacy_binary(doc, json_options)
+    else:
+        return _parse_canonical_binary(doc, json_options)
+
+
+def _parse_timestamp(doc: Any, dummy0: Any) -> Timestamp:
+    tsp = doc["$timestamp"]
+    return Timestamp(tsp["t"], tsp["i"])
+
+
+_PARSERS: dict[str, Callable[[Any, JSONOptions], Any]] = {
+    "$oid": _parse_canonical_oid,
+    "$ref": _parse_canonical_dbref,
+    "$date": _parse_canonical_datetime,
+    "$regex": _parse_legacy_regex,
+    "$minKey": _parse_canonical_minkey,
+    "$maxKey": _parse_canonical_maxkey,
+    "$binary": _parse_binary,
+    "$code": _parse_canonical_code,
+    "$uuid": _parse_legacy_uuid,
+    "$undefined": lambda _, _1: None,
+    "$numberLong": _parse_canonical_int64,
+    "$timestamp": _parse_timestamp,
+    "$numberDecimal": _parse_canonical_decimal128,
+    "$dbPointer": _parse_canonical_dbpointer,
+    "$regularExpression": _parse_canonical_regex,
+    "$symbol": _parse_canonical_symbol,
+    "$numberInt": _parse_canonical_int32,
+    "$numberDouble": _parse_canonical_double,
+}
+_PARSERS_SET = set(_PARSERS)


 def _encode_binary(data: bytes, subtype: int, json_options: JSONOptions) -> Any:
     if json_options.json_mode == JSONMode.LEGACY:
         return {"$binary": base64.b64encode(data).decode(), "$type": "%02x" % subtype}
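Taken together, the json_util.py change replaces a chain of up to eighteen "$key" in dct probes per decoded object with one pass over the object's own keys, a set-membership test per key, and a single dict dispatch. Plain documents with no $-prefixed keys now fall through after only as many set lookups as they have keys. A self-contained sketch of the pattern (simplified illustration with made-up parsers, not PyMongo's actual ones):

    import json
    from typing import Any, Callable

    def _parse_oid(dct: dict, opts: Any) -> str:
        return f"ObjectId({dct['$oid']})"

    def _parse_int64(dct: dict, opts: Any) -> int:
        return int(dct["$numberLong"])

    _PARSERS: dict[str, Callable[[Any, Any], Any]] = {
        "$oid": _parse_oid,
        "$numberLong": _parse_int64,
    }
    _PARSERS_SET = set(_PARSERS)

    def object_hook(dct: dict, json_options: Any = None) -> Any:
        # Iterate the (usually few) keys of the incoming document instead
        # of probing it for every known $-prefixed key.
        for k in dct:
            if k in _PARSERS_SET:
                return _PARSERS[k](dct, json_options)
        return dct

    print(json.loads('{"x": {"$numberLong": "42"}}', object_hook=object_hook))
    # -> {'x': 42}

Giving every parser the same (doc, json_options) calling convention is what lets the table dispatch uniformly; the otherwise-unused dummy0 parameters and the no-op lambda for $undefined exist purely to preserve that signature.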
test/performance/perf_test.py (11 additions & 2 deletions)
@@ -78,15 +78,17 @@ def setUpClass(cls):
         client_context.init()

     def setUp(self):
-        pass
+        self.setup_time = time.monotonic()

     def tearDown(self):
+        duration = time.monotonic() - self.setup_time
         # Remove "Test" so that TestFlatEncoding is reported as "FlatEncoding".
         name = self.__class__.__name__[4:]
         median = self.percentile(50)
         megabytes_per_sec = self.data_size / median / 1000000
         print(
-            f"Running {self.__class__.__name__}. MB/s={megabytes_per_sec}, MEDIAN={self.percentile(50)}"
+            f"Completed {self.__class__.__name__} {megabytes_per_sec:.3f} MB/s, MEDIAN={self.percentile(50):.3f}s, "
+            f"total time={duration:.3f}s"
         )
         result_data.append(
             {
@@ -149,6 +151,7 @@ def mp_map(self, map_func, files):

 class MicroTest(PerformanceTest):
     def setUp(self):
+        super().setUp()
         # Location of test data.
         with open(os.path.join(TEST_PATH, os.path.join("extended_bson", self.dataset))) as data:
             self.file_data = data.read()
@@ -256,6 +259,7 @@ class TestRunCommand(PerformanceTest, unittest.TestCase):
     data_size = len(encode({"hello": True})) * NUM_DOCS

     def setUp(self):
+        super().setUp()
         self.client = client_context.client
         self.client.drop_database("perftest")

@@ -267,6 +271,7 @@ def do_task(self):

 class TestDocument(PerformanceTest):
     def setUp(self):
+        super().setUp()
         # Location of test data.
         with open(
             os.path.join(TEST_PATH, os.path.join("single_and_multi_document", self.dataset))
@@ -458,6 +463,7 @@ def read_gridfs_file(filename):

 class TestJsonMultiImport(PerformanceTest, unittest.TestCase):
     def setUp(self):
+        super().setUp()
         self.client = client_context.client
         self.client.drop_database("perftest")
         ldjson_path = os.path.join(TEST_PATH, os.path.join("parallel", "ldjson_multi"))
@@ -481,6 +487,7 @@ def tearDown(self):

 class TestJsonMultiExport(PerformanceTest, unittest.TestCase):
     def setUp(self):
+        super().setUp()
         self.client = client_context.client
         self.client.drop_database("perftest")
         self.client.perfest.corpus.create_index("file")
@@ -501,6 +508,7 @@ def tearDown(self):

 class TestGridFsMultiFileUpload(PerformanceTest, unittest.TestCase):
     def setUp(self):
+        super().setUp()
         self.client = client_context.client
         self.client.drop_database("perftest")
         gridfs_path = os.path.join(TEST_PATH, os.path.join("parallel", "gridfs_multi"))
@@ -525,6 +533,7 @@ def tearDown(self):

 class TestGridFsMultiFileDownload(PerformanceTest, unittest.TestCase):
     def setUp(self):
+        super().setUp()
         self.client = client_context.client
         self.client.drop_database("perftest")

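The perf_test.py changes thread a monotonic start timestamp through setUp() so that tearDown() can report total wall-clock time per test, which is why every overriding setUp() now calls super().setUp() first. A stripped-down sketch of the pattern (illustrative names, not the full benchmark harness):

    import time
    import unittest

    class TimedTest(unittest.TestCase):
        def setUp(self):
            # time.monotonic() is unaffected by system clock adjustments,
            # which makes it the right clock for measuring durations.
            self.setup_time = time.monotonic()

        def tearDown(self):
            duration = time.monotonic() - self.setup_time
            print(f"Completed {self.__class__.__name__}, total time={duration:.3f}s")

    class TestSomething(TimedTest):
        def setUp(self):
            super().setUp()  # skipping this would break tearDown's math
            self.payload = list(range(1000))

        def test_sum(self):
            self.assertEqual(sum(self.payload), 499500)

    if __name__ == "__main__":
        unittest.main(verbosity=2)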