Skip to content

Commit

Permalink
Allow apply udf to reference global modules in cudf.pandas (rapidsai#…
Browse files Browse the repository at this point in the history
…15569)

closes rapidsai#15548

`_replace_closurevars` creates a new function by replacing objects with their fast versions. When creating the new function, it populates `globals` from the result of `inspect.getclosurevars`, but it don't think it comprehensively returns _all_ the globals accessible to the function (`function.__globals__`)

To minimize the change, the "fast globals" are still sourced from `inspect.getclosurevars`, and those update the `old_function.__globals__` when creating a new function.

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

URL: rapidsai#15569
  • Loading branch information
mroeschke authored Apr 19, 2024
1 parent 088be5a commit 21350fc
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
9 changes: 6 additions & 3 deletions python/cudf/cudf/pandas/fast_slow_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,7 +1108,7 @@ def _replace_closurevars(
if any(c == types.CellType() for c in f.__closure__):
return f

f_nonlocals, f_globals, f_builtins, _ = inspect.getclosurevars(f)
f_nonlocals, f_globals, _, _ = inspect.getclosurevars(f)

g_globals = _transform_arg(f_globals, attribute_name, seen)
g_nonlocals = _transform_arg(f_nonlocals, attribute_name, seen)
Expand All @@ -1121,11 +1121,14 @@ def _replace_closurevars(
return f

g_closure = tuple(types.CellType(val) for val in g_nonlocals.values())
g_globals["__builtins__"] = f_builtins

# https://github.com/rapidsai/cudf/issues/15548
new_g_globals = f.__globals__.copy()
new_g_globals.update(g_globals)

g = types.FunctionType(
f.__code__,
g_globals,
new_g_globals,
name=f.__name__,
argdefs=f.__defaults__,
closure=g_closure,
Expand Down
12 changes: 12 additions & 0 deletions python/cudf/cudf_pandas_tests/test_cudf_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,3 +1208,15 @@ def test_pickle_groupby(dataframe):
def test_isinstance_base_offset():
offset = xpd.tseries.frequencies.to_offset("1s")
assert isinstance(offset, xpd.tseries.offsets.BaseOffset)


def test_apply_slow_path_udf_references_global_module():
def my_apply(df, unused):
# `datetime` Raised `KeyError: __import__`
datetime.datetime.strptime(df["Minute"], "%H:%M:%S")
return pd.to_numeric(1)

df = xpd.DataFrame({"Minute": ["09:00:00"]})
result = df.apply(my_apply, axis=1, unused=True)
expected = xpd.Series([1])
tm.assert_series_equal(result, expected)

0 comments on commit 21350fc

Please sign in to comment.