Skip to content

Commit

Permalink
Fix an issue with creating a series from scalar when `dtype='category…
Browse files Browse the repository at this point in the history
…'` (rapidsai#15476)

## Description
When `dtype='category'` we seem to error:
```

File "/pyenv/versions/3.11.9/lib/python3.11/site-packages/cuml/preprocessing/LabelEncoder.py", line 218, in transform
2024-04-05T19:37:35.8255262Z E                 y = cudf.Series('a', dtype="category")
2024-04-05T19:37:35.8257445Z E                 ^^^^^^^^^^^^^^^^^
2024-04-05T19:37:35.8260865Z E               File "/pyenv/versions/3.11.9/lib/python3.11/site-packages/nvtx/nvtx.py", line 116, in inner
2024-04-05T19:37:35.8264174Z E                 result = func(*args, **kwargs)
2024-04-05T19:37:35.8266324Z E                 ^^^^^^^^^^^^^^^^^
2024-04-05T19:37:35.8270003Z E               File "/pyenv/versions/3.11.9/lib/python3.11/site-packages/cudf/core/series.py", line 648, in __init__
2024-04-05T19:37:35.8273382Z E                 column = as_column(
2024-04-05T19:37:35.8275420Z E                 ^^^^^^^^^^^^^^^^^
2024-04-05T19:37:35.8279989Z E               File "/pyenv/versions/3.11.9/lib/python3.11/site-packages/cudf/core/column/column.py", line 2022, in as_column
2024-04-05T19:37:35.8281584Z E                 arbitrary = cudf.Scalar(arbitrary, dtype=dtype)
2024-04-05T19:37:35.8282461Z E                 ^^^^^^^^^^^^^^^^^
2024-04-05T19:37:35.8283768Z E               File "/pyenv/versions/3.11.9/lib/python3.11/site-packages/cudf/core/scalar.py", line 57, in __call__
2024-04-05T19:37:35.8285137Z E                 obj = super().__call__(value, dtype=dtype)
2024-04-05T19:37:35.8285959Z E                 ^^^^^^^^^^^^^^^^^
2024-04-05T19:37:35.8287757Z E               File "/pyenv/versions/3.11.9/lib/python3.11/site-packages/cudf/core/scalar.py", line 128, in __init__
2024-04-05T19:37:35.8289232Z E                 self._host_value, self._host_dtype = self._preprocess_host_value(
2024-04-05T19:37:35.8290183Z E                 ^^^^^^^^^^^^^^^^^
2024-04-05T19:37:35.8291705Z E               File "/pyenv/versions/3.11.9/lib/python3.11/site-packages/cudf/core/scalar.py", line 222, in _preprocess_host_value
2024-04-05T19:37:35.8293212Z E                 value = to_cudf_compatible_scalar(value, dtype=dtype)
2024-04-05T19:37:35.8294438Z E                 ^^^^^^^^^^^^^^^^^
2024-04-05T19:37:35.8296026Z E               File "/pyenv/versions/3.11.9/lib/python3.11/site-packages/cudf/utils/dtypes.py", line 257, in to_cudf_compatible_scalar
2024-04-05T19:37:35.8297604Z E                 if isinstance(val, str) and np.dtype(dtype).kind == "M":
2024-04-05T19:37:35.8298543Z E                 ^^^^^^^^^^^^^^^^^
2024-04-05T19:37:35.8308752Z E             TypeError: data type 'category' not understood
```
## Checklist
- [x] I am familiar with the [Contributing
Guidelines](https://github.com/rapidsai/cudf/blob/HEAD/CONTRIBUTING.md).
- [x] New or existing tests cover these changes.
- [x] The documentation is up to date with these changes.
  • Loading branch information
galipremsagar authored Apr 8, 2024
1 parent 35f818b commit e6cfd45
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -2009,7 +2009,7 @@ def as_column(
length = 1
elif length < 0:
raise ValueError(f"{length=} must be >=0.")
if isinstance(arbitrary, pd.Interval):
if isinstance(arbitrary, pd.Interval) or _is_categorical_dtype(dtype):
# No cudf.Scalar support yet
return as_column(
pd.Series([arbitrary] * length),
Expand Down
8 changes: 8 additions & 0 deletions python/cudf/cudf/tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,3 +846,11 @@ def test_empty_series_category_cast(ordered):

assert_eq(expected, actual)
assert_eq(expected.dtype.ordered, actual.dtype.ordered)


@pytest.mark.parametrize("scalar", [1, "a", None, 10.2])
def test_cat_from_scalar(scalar):
ps = pd.Series(scalar, dtype="category")
gs = cudf.Series(scalar, dtype="category")

assert_eq(ps, gs)

0 comments on commit e6cfd45

Please sign in to comment.