Skip to content

Commit

Permalink
Merge pull request #190 from google/gbm_strings_are_bytes
Browse files Browse the repository at this point in the history
Internally, tp uses bytes instead of int32 unicodes
  • Loading branch information
achoum authored Jul 12, 2023
2 parents 2410cd3 + 9210c9d commit 4e27d4f
Show file tree
Hide file tree
Showing 47 changed files with 4,881 additions and 3,449 deletions.
1 change: 1 addition & 0 deletions docs/public_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"int64",
"bool_",
"str_",
"bytes_",
# SERIALIZATION
"save",
"load",
Expand Down
Empty file.
382 changes: 195 additions & 187 deletions docs/src/tutorials/getting_started.ipynb

Large diffs are not rendered by default.

2,696 changes: 1,380 additions & 1,316 deletions docs/src/tutorials/loan_outcomes_prediction.ipynb

Large diffs are not rendered by default.

4,972 changes: 3,097 additions & 1,875 deletions docs/src/user_guide.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions temporian/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from temporian.core.data.dtype import int64
from temporian.core.data.dtype import bool_
from temporian.core.data.dtype import str_
from temporian.core.data.dtype import bytes_

# Schema
from temporian.core.data.schema import Schema
Expand Down
14 changes: 10 additions & 4 deletions temporian/beam/io.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Utilities to import/export Beam-Event-Set from/to dataset containers."""

from typing import Iterable, Dict, Any, Tuple, Union, Optional
from typing import Iterable, Dict, Any, Tuple, Union, Optional, List

import csv
import io
Expand All @@ -27,7 +27,7 @@
# In the numpy backend, index values are represented as numpy primitives.
# However, Beam does not support numpy primitives as index values. Therefore,
# all index values are converted to Python primitives of type "BeamIndex".
BeamIndex = Union[int, float, str, bool]
BeamIndex = Union[int, float, str, bytes, bool]

# Temporian index or Feature index in Beam.
#
Expand Down Expand Up @@ -265,14 +265,18 @@ def read_csv(
)


def _bytes_to_strs(list: List) -> List:
return [x.decode() if isinstance(x, bytes) else x for x in list]


def _convert_to_csv(
item: Tuple[
Tuple[BeamIndex, ...],
Iterable[IndexValue],
]
) -> str:
index, feature_blocks = item
index_data = list(index)
index_data = _bytes_to_strs(list(index))

# Sort the feature by feature index.
# The feature index is the last value (-1) of the key (first element of the
Expand All @@ -286,7 +290,9 @@ def _convert_to_csv(
output = io.StringIO()
writer = csv.writer(output)
for event_idx, timestamp in enumerate(timestamps):
feature_data = [f[1][1][event_idx] for f in feature_blocks]
feature_data = _bytes_to_strs(
[f[1][1][event_idx] for f in feature_blocks]
)
writer.writerow([timestamp] + index_data + feature_data)

return output.getvalue()
Expand Down
14 changes: 10 additions & 4 deletions temporian/core/data/dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def is_float(self) -> bool:
def is_integer(self) -> bool:
return self in (DType.INT64, DType.INT32)

def missing_value(self) -> Union[float, int, str]:
def missing_value(self) -> Union[float, int, bytes]:
"""
Returns missing value for specific dtype.
Expand All @@ -58,7 +58,7 @@ def missing_value(self) -> Union[float, int, str]:
return 0

if self == DType.STRING:
return ""
return b""

raise ValueError(f"Non-implemented type {self}")

Expand Down Expand Up @@ -86,6 +86,9 @@ def from_python_type(cls, python_type: type) -> "DType":
if python_type is str:
return DType.STRING

if python_type is bytes:
return DType.STRING

if python_type is bool:
return DType.BOOLEAN

Expand Down Expand Up @@ -124,5 +127,8 @@ def check_is_valid_index_dtype(dtype: DType):
bool_ = DType.BOOLEAN
"""Boolean value."""

str_ = DType.STRING
"""String value."""
bytes_ = DType.STRING
"""String value (stored as bytes)."""

str_ = bytes_
"""String value (stored as bytes)."""
7 changes: 6 additions & 1 deletion temporian/core/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

# Valid types for operator attributes
AttributeType = Union[
str, int, float, bool, List[str], Dict[str, str], List[DType]
str, int, float, bool, bytes, List[str], Dict[str, str], List[DType]
]

# Generic type for defining the input/output types of operators.
Expand Down Expand Up @@ -254,6 +254,10 @@ def is_list_dtype(value):
and not isinstance(value, str)
):
raise ValueError(f"Attribute {value=} mismatch: string expected")
# A BYTES attribute must be provided as a `bytes` instance. The error message
# names the expected type so it is not confused with the STRING check.
if attr_type == pb.OperatorDef.Attribute.Type.BYTES and not isinstance(
    value, bytes
):
    raise ValueError(f"Attribute {value=} mismatch: bytes expected")
if (
attr_type == pb.OperatorDef.Attribute.Type.INTEGER_64
and not isinstance(value, int)
Expand Down Expand Up @@ -294,6 +298,7 @@ def is_list_dtype(value):
if (
attr_type == pb.OperatorDef.Attribute.Type.ANY
and not isinstance(value, str)
and not isinstance(value, bytes)
and not isinstance(value, bool)
and not isinstance(value, int)
and not isinstance(value, float)
Expand Down
6 changes: 3 additions & 3 deletions temporian/core/operators/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,8 @@ def cast(
events:
(2 events):
timestamps: [1. 2.]
'A': [False True]
'B': ['a' 'b']
'A': [False True]
'B': [b'a' b'b']
'C': [5 5]
...
Expand All @@ -258,7 +258,7 @@ def cast(
(2 events):
timestamps: [1. 2.]
'A': [0. 2.]
'B': ['a' 'b']
'B': [b'a' b'b']
'C': [5 5]
...
Expand Down
4 changes: 2 additions & 2 deletions temporian/core/operators/enumerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ def enumerate(input: EventSetNode) -> EventSetNode:
indexes: [('cat', str_)]
features: [('enumerate', int64)]
events:
cat=A (4 events):
cat=b'A' (4 events):
timestamps: [-1. 2. 3. 5.]
'enumerate': [0 1 2 3]
cat=B (1 events):
cat=b'B' (1 events):
timestamps: [0.]
'enumerate': [0]
...
Expand Down
6 changes: 5 additions & 1 deletion temporian/core/operators/scalar/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,14 @@ class BaseScalarOperator(Operator):
def __init__(
self,
input: EventSetNode,
value: Union[float, int, str, bool],
value: Union[float, int, str, bytes, bool],
is_value_first: bool = False, # useful for non-commutative operators
):
super().__init__()

if isinstance(value, str):
value = value.encode()

self.value = value
self.is_value_first = is_value_first

Expand Down Expand Up @@ -69,6 +72,7 @@ def __init__(
float: [DType.FLOAT32, DType.FLOAT64],
int: [DType.INT32, DType.INT64, DType.FLOAT32, DType.FLOAT64],
str: [DType.STRING],
bytes: [DType.STRING],
bool: [
DType.BOOLEAN,
DType.INT32,
Expand Down
12 changes: 6 additions & 6 deletions temporian/core/operators/scalar/relational_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class LessScalarOperator(RelationalScalarOperator):
@compile
def equal_scalar(
input: EventSetNode,
value: Union[float, int, str, bool],
value: Union[float, int, str, bool, bytes],
) -> EventSetNode:
"""Checks for equality between a node and a scalar element-wise.
Expand Down Expand Up @@ -124,7 +124,7 @@ def equal_scalar(
@compile
def not_equal_scalar(
input: EventSetNode,
value: Union[float, int, str, bool],
value: Union[float, int, str, bytes, bool],
) -> EventSetNode:
"""Checks for differences between a node and a scalar element-wise.
Expand Down Expand Up @@ -169,7 +169,7 @@ def not_equal_scalar(
@compile
def greater_equal_scalar(
input: EventSetNode,
value: Union[float, int, str, bool],
value: Union[float, int, str, bytes, bool],
) -> EventSetNode:
"""Check if the input node is greater or equal than a scalar element-wise.
Expand Down Expand Up @@ -214,7 +214,7 @@ def greater_equal_scalar(
@compile
def less_equal_scalar(
input: EventSetNode,
value: Union[float, int, str, bool],
value: Union[float, int, str, bytes, bool],
) -> EventSetNode:
"""Check if the input node is less or equal than a scalar element-wise.
Expand Down Expand Up @@ -259,7 +259,7 @@ def less_equal_scalar(
@compile
def greater_scalar(
input: EventSetNode,
value: Union[float, int, str, bool],
value: Union[float, int, str, bytes, bool],
) -> EventSetNode:
"""Check if the input node is greater than a scalar element-wise.
Expand Down Expand Up @@ -304,7 +304,7 @@ def greater_scalar(
@compile
def less_scalar(
input: EventSetNode,
value: Union[float, int, str, bool],
value: Union[float, int, str, bytes, bool],
) -> EventSetNode:
"""Check if the input node is less than a scalar element-wise.
Expand Down
4 changes: 2 additions & 2 deletions temporian/core/operators/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def select(
events:
(2 events):
timestamps: [1. 2.]
'B': ['s' 'm']
'B': [b's' b'm']
...
>>> # Select multiple features
Expand All @@ -130,7 +130,7 @@ def select(
events:
(2 events):
timestamps: [1. 2.]
'B': ['s' 'm']
'B': [b's' b'm']
'C': [5. 5.5]
...
Expand Down
4 changes: 2 additions & 2 deletions temporian/core/operators/window/moving_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,10 @@ def moving_count(
indexes: [('idx', str_)]
features: [('value', int32)]
events:
idx=i2 (3 events):
idx=b'i2' (3 events):
timestamps: [0. 1. 2.]
'value': [1 2 2]
idx=i1 (3 events):
idx=b'i1' (3 events):
timestamps: [1. 2. 3.]
'value': [1 2 2]
...
Expand Down
5 changes: 4 additions & 1 deletion temporian/core/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,6 @@ def _serialize_node(
src: EventSetNode, operators: Set[base.Operator]
) -> pb.EventSetNode:
assert len(src.schema.features) == len(src.feature_nodes)
logging.info("aaaa")
logging.info(operators)
return pb.EventSetNode(
id=_identifier(src),
Expand Down Expand Up @@ -625,6 +624,8 @@ def _attribute_to_proto(
) -> pb.Operator.Attribute:
if isinstance(value, str):
return pb.Operator.Attribute(key=key, str=value)
if isinstance(value, bytes):
return pb.Operator.Attribute(key=key, bytes_=value)
if isinstance(value, bool):
# NOTE: Check this before int (isinstance(False, int) is also True)
return pb.Operator.Attribute(key=key, boolean=value)
Expand Down Expand Up @@ -664,6 +665,8 @@ def _attribute_from_proto(src: pb.Operator.Attribute) -> base.AttributeType:
return src.integer_64
if src.HasField("str"):
return src.str
if src.HasField("bytes_"):
return src.bytes_
if src.HasField("float_64"):
return src.float_64
if src.HasField("list_str"):
Expand Down
Loading

0 comments on commit 4e27d4f

Please sign in to comment.