74 lines
3.1 KiB
Python
74 lines
3.1 KiB
Python
|
import pytest
|
||
|
|
||
|
import pandas.util._test_decorators as td
|
||
|
|
||
|
from pandas import (
|
||
|
DataFrame,
|
||
|
Series,
|
||
|
)
|
||
|
import pandas._testing as tm
|
||
|
|
||
|
|
||
|
# Numba warns when parallel=True is requested for a function it cannot
# parallelize; the filterwarnings mark below silences warnings class-wide.
@td.skip_if_no("numba")
@pytest.mark.filterwarnings("ignore")
class TestEngine:
    """
    Groupby reductions executed with ``engine="numba"``.

    The first three tests compare the numba result with the result of the
    default engine; the last two check that unsupported groupby options
    raise ``NotImplementedError``.
    """

    # Reductions for which the dtype comparison is skipped.
    # The exemption can be removed if GH 44952 is addressed.
    _SKIP_DTYPE_CHECK = ("sum", "min", "max")

    @staticmethod
    def _make_frame():
        """Build the small three-column frame shared by the DataFrame tests."""
        return DataFrame({"a": [3, 2, 3, 2], "b": range(4), "c": range(1, 5)})

    @staticmethod
    def _run_both_engines(grouped, reduction, nogil, parallel, nopython):
        """
        Apply ``reduction`` to ``grouped`` with numba and with the default engine.

        Parameters
        ----------
        grouped : groupby object
            Object on which the reduction method is looked up.
        reduction : tuple
            ``(method name, keyword arguments)`` pair.
        nogil, parallel, nopython
            Values forwarded to numba through ``engine_kwargs``.

        Returns
        -------
        tuple
            ``(numba result, default-engine result, check_dtype flag)``.
        """
        name, extra = reduction
        numba_options = dict(nogil=nogil, parallel=parallel, nopython=nopython)
        # The numba call comes first, matching the order the results are compared.
        via_numba = getattr(grouped, name)(
            engine="numba", engine_kwargs=numba_options, **extra
        )
        via_default = getattr(grouped, name)(**extra)
        compare_dtype = name not in TestEngine._SKIP_DTYPE_CHECK
        return via_numba, via_default, compare_dtype

    def test_cython_vs_numba_frame(
        self, sort, nogil, parallel, nopython, numba_supported_reductions
    ):
        """DataFrameGroupBy: numba result equals the default-engine result."""
        grouped = self._make_frame().groupby("a", sort=sort)
        result, expected, check_dtype = self._run_both_engines(
            grouped, numba_supported_reductions, nogil, parallel, nopython
        )
        tm.assert_frame_equal(result, expected, check_dtype=check_dtype)

    def test_cython_vs_numba_getitem(
        self, sort, nogil, parallel, nopython, numba_supported_reductions
    ):
        """Column selected from a DataFrameGroupBy: numba equals default engine."""
        grouped = self._make_frame().groupby("a", sort=sort)["c"]
        result, expected, check_dtype = self._run_both_engines(
            grouped, numba_supported_reductions, nogil, parallel, nopython
        )
        tm.assert_series_equal(result, expected, check_dtype=check_dtype)

    def test_cython_vs_numba_series(
        self, sort, nogil, parallel, nopython, numba_supported_reductions
    ):
        """SeriesGroupBy grouped on index level 0: numba equals default engine."""
        values = Series(range(3), index=[1, 2, 1], name="foo")
        grouped = values.groupby(level=0, sort=sort)
        result, expected, check_dtype = self._run_both_engines(
            grouped, numba_supported_reductions, nogil, parallel, nopython
        )
        tm.assert_series_equal(result, expected, check_dtype=check_dtype)

    def test_as_index_false_unsupported(self, numba_supported_reductions):
        """A groupby built with ``as_index=False`` must raise under numba."""
        name, extra = numba_supported_reductions
        grouped = self._make_frame().groupby("a", as_index=False)
        with pytest.raises(NotImplementedError, match="as_index=False"):
            getattr(grouped, name)(engine="numba", **extra)

    def test_axis_1_unsupported(self, numba_supported_reductions):
        """A groupby built with ``axis=1`` must raise under numba."""
        name, extra = numba_supported_reductions
        grouped = self._make_frame().groupby("a", axis=1)
        with pytest.raises(NotImplementedError, match="axis=1"):
            getattr(grouped, name)(engine="numba", **extra)