106 lines
3.2 KiB
Python
106 lines
3.2 KiB
Python
|
import numpy as np
|
||
|
import pytest
|
||
|
|
||
|
from pandas.compat import (
|
||
|
is_ci_environment,
|
||
|
is_platform_mac,
|
||
|
is_platform_windows,
|
||
|
)
|
||
|
import pandas.util._test_decorators as td
|
||
|
|
||
|
from pandas import (
|
||
|
DataFrame,
|
||
|
Series,
|
||
|
)
|
||
|
import pandas._testing as tm
|
||
|
|
||
|
# TODO(GH#44584): Mark these as pytest.mark.single_cpu
# Computed once at import: set when running on CI and the platform helpers
# report Windows or macOS.
_SKIP_ON_CI_PLATFORM = is_ci_environment() and (
    is_platform_windows() or is_platform_mac()
)

# Module-wide marker: every test in this file is skipped under that condition.
pytestmark = pytest.mark.skipif(
    _SKIP_ON_CI_PLATFORM,
    reason=(
        "On GHA CI, Windows can fail with "
        "'Windows fatal exception: stack overflow' "
        "and MacOS can timeout"
    ),
)
|
||
|
|
||
|
|
||
|
@td.skip_if_no("numba")
# Filter warnings when parallel=True and the function can't be parallelized by Numba
@pytest.mark.filterwarnings("ignore")
class TestEWM:
    """Tests comparing the online EWM ``mean`` against the regular EWM ``mean``."""

    def test_invalid_update(self):
        """Passing ``update`` before any plain ``mean()`` call raises ValueError."""
        frame = DataFrame({"a": range(5), "b": range(5)})
        online = frame.head(2).ewm(0.5).online()
        expected_msg = "Must call mean with update=None first before passing update"
        with pytest.raises(ValueError, match=expected_msg):
            online.mean(update=frame.head(1))

    @pytest.mark.slow
    @pytest.mark.parametrize(
        "obj", [DataFrame({"a": range(5), "b": range(5)}), Series(range(5), name="foo")]
    )
    def test_online_vs_non_online_mean(
        self, obj, nogil, parallel, nopython, adjust, ignore_na
    ):
        """Online mean over the first 2 rows, then updated with the last 3,
        must equal the matching slices of the regular EWM mean."""
        ewm_options = {"adjust": adjust, "ignore_na": ignore_na}
        expected = obj.ewm(0.5, **ewm_options).mean()

        engine_options = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
        online_ewm = obj.head(2).ewm(0.5, **ewm_options).online(
            engine_kwargs=engine_options
        )

        expected_initial = expected.head(2)
        expected_updated = expected.tail(3)

        # Test resetting once: the second pass runs after reset() and has to
        # reproduce the results of the first pass.
        for _pass in range(2):
            initial = online_ewm.mean()
            tm.assert_equal(initial, expected_initial)

            updated = online_ewm.mean(update=obj.tail(3))
            tm.assert_equal(updated, expected_updated)

            online_ewm.reset()

    @pytest.mark.xfail(raises=NotImplementedError)
    @pytest.mark.parametrize(
        "obj", [DataFrame({"a": range(5), "b": range(5)}), Series(range(5), name="foo")]
    )
    def test_update_times_mean(
        self, obj, nogil, parallel, nopython, adjust, ignore_na, halflife_with_times
    ):
        """Same comparison as the non-times test, but with ``times``/``halflife``
        supplied; currently marked as an expected NotImplementedError failure."""
        timestamps = np.array(
            ["2020-01-01", "2020-01-05", "2020-01-07", "2020-01-17", "2020-01-21"],
            dtype="datetime64",
        )
        times = Series(timestamps)

        # Keyword arguments shared by the regular and the online EWM objects.
        ewm_options = {
            "adjust": adjust,
            "ignore_na": ignore_na,
            "halflife": halflife_with_times,
        }
        expected = obj.ewm(0.5, times=times, **ewm_options).mean()

        engine_options = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
        online_ewm = (
            obj.head(2)
            .ewm(0.5, times=times.head(2), **ewm_options)
            .online(engine_kwargs=engine_options)
        )

        expected_initial = expected.head(2)
        expected_updated = expected.tail(3)

        # Test resetting once: the second pass runs after reset() and has to
        # reproduce the results of the first pass.
        for _pass in range(2):
            initial = online_ewm.mean()
            tm.assert_equal(initial, expected_initial)

            updated = online_ewm.mean(update=obj.tail(3), update_times=times.tail(3))
            tm.assert_equal(updated, expected_updated)

            online_ewm.reset()
|