From 4fbd9b4e74bb1c93a8a953f41a3695d25eea8977 Mon Sep 17 00:00:00 2001 From: Takashi Imamichi Date: Fri, 24 May 2024 15:15:49 +0900 Subject: [PATCH] use mapping-like features of DataBin --- qiskit_ibm_runtime/utils/json.py | 7 +++---- test/integration/test_sampler_v2.py | 6 +++--- test/unit/test_data_serialization.py | 14 +++++--------- 3 files changed, 11 insertions(+), 16 deletions(-) diff --git a/qiskit_ibm_runtime/utils/json.py b/qiskit_ibm_runtime/utils/json.py index b2b005a28..cf180d374 100644 --- a/qiskit_ibm_runtime/utils/json.py +++ b/qiskit_ibm_runtime/utils/json.py @@ -270,10 +270,9 @@ def default(self, obj: Any) -> Any: # pylint: disable=arguments-differ return {"__type__": "BitArray", "__value__": out_val} if isinstance(obj, DataBin): out_val = { - "field_names": obj._FIELDS, - "field_types": [str(field_type) for field_type in obj._FIELD_TYPES], - "shape": obj._SHAPE, - "fields": {field_name: getattr(obj, field_name) for field_name in obj._FIELDS}, + "field_names": list(obj), + "shape": obj.shape, + "fields": dict(obj.items()), } return {"__type__": "DataBin", "__value__": out_val} if isinstance(obj, EstimatorPub): diff --git a/test/integration/test_sampler_v2.py b/test/integration/test_sampler_v2.py index 6647fb45b..783b69f41 100644 --- a/test/integration/test_sampler_v2.py +++ b/test/integration/test_sampler_v2.py @@ -513,10 +513,10 @@ def test_circuit_with_multiple_cregs(self, service): result = sampler.run([qc]).result() self.assertEqual(len(result), 1) data = result[0].data - self.assertEqual(len(data._FIELDS), 3) + self.assertEqual(len(data), 3) for creg in qc.cregs: - self.assertTrue(hasattr(data, creg.name)) - self._assert_allclose(getattr(data, creg.name), np.array(target[creg.name])) + self.assertIn(creg.name, data) + self._assert_allclose(data[creg.name], np.array(target[creg.name])) @run_integration_test def test_samplerv2_options(self, service): diff --git a/test/unit/test_data_serialization.py b/test/unit/test_data_serialization.py index e2be4f306..a08bac75e 100644 --- a/test/unit/test_data_serialization.py +++ b/test/unit/test_data_serialization.py @@ -296,15 +296,11 @@ def assert_data_bins_equal(self, dbin1, dbin2): """Compares two DataBins Field types are compared up to their string representation """ - self.assertEqual(dbin1._FIELDS, dbin2._FIELDS) - self.assertEqual( - [str(field_type) for field_type in dbin1._FIELD_TYPES], - [str(field_type) for field_type in dbin2._FIELD_TYPES], - ) - self.assertEqual(dbin1._SHAPE, dbin2._SHAPE) - for field_name in dbin1._FIELDS: - field_1 = getattr(dbin1, field_name) - field_2 = getattr(dbin2, field_name) + self.assertEqual(tuple(dbin1), tuple(dbin2)) + self.assertEqual(dbin1.shape, dbin2.shape) + for field_name in dbin1: + field_1 = dbin1[field_name] + field_2 = dbin2[field_name] if isinstance(field_1, np.ndarray): np.testing.assert_allclose(field_1, field_2) else: