249 lines
8.0 KiB
Python
249 lines
8.0 KiB
Python
|
from __future__ import annotations
|
||
|
|
||
|
from typing import Any
|
||
|
|
||
|
import numpy as np
|
||
|
import pytest
|
||
|
|
||
|
import pandas as pd
|
||
|
import pandas._testing as tm
|
||
|
|
||
|
# Module-level test data: one masked array per extension dtype, paired
# positionally with a scalar operand.  The ``data`` fixture zips the two
# lists together, so they must always have the same length.

# integer dtypes
arrays = [pd.array([1, 2, 3, None], dtype=dtype) for dtype in tm.ALL_INT_EA_DTYPES]
scalars: list[Any] = [2] * len(arrays)

# floating dtypes
arrays += [pd.array([0.1, 0.2, 0.3, None], dtype=dtype) for dtype in tm.FLOAT_EA_DTYPES]
# Size the scalar list from the dtype list rather than hard-coding the number
# of float dtypes; zip() would otherwise silently truncate or misalign the
# (array, scalar) pairs if the dtype list ever changed length.
scalars += [0.2] * len(tm.FLOAT_EA_DTYPES)

# boolean
arrays += [pd.array([True, False, True, None], dtype="boolean")]
scalars += [False]
|
||
|
|
||
|
|
||
|
@pytest.fixture(
    params=zip(arrays, scalars),
    ids=[arr.dtype.name for arr in arrays],
)
def data(request):
    """Parametrized fixture yielding one ``(array, scalar)`` pair per dtype.

    The pairs come from zipping the module-level ``arrays`` and ``scalars``
    lists; each test id is the name of the array's dtype.  Tests use the pair
    to compare array ops against scalar / numpy-array operands and to compare
    DataFrame and Series ops against the underlying array ops.
    """
    pair = request.param
    return pair
|
||
|
|
||
|
|
||
|
def check_skip(data, op_name):
    """Skip the running test when ``data`` is boolean and the op is a subtraction.

    Parameters
    ----------
    data : array whose ``dtype`` is inspected
    op_name : str
        Operator name; any name containing ``"sub"`` counts as subtraction.
    """
    if not isinstance(data.dtype, pd.BooleanDtype):
        return
    if "sub" not in op_name:
        return
    pytest.skip("subtract not implemented for boolean")
|
||
|
|
||
|
|
||
|
def is_bool_not_implemented(data, op_name):
    """Return True if ``op_name`` is expected to be unsupported for ``data``.

    This holds only when ``data.dtype.kind`` is ``"b"`` and the operator,
    after removing surrounding underscores and leading ``"r"`` characters,
    is one of ``pow``, ``truediv`` or ``floordiv``.
    """
    # match non-masked behavior
    if data.dtype.kind != "b":
        return False
    base_name = op_name.strip("_").lstrip("r")
    unsupported = ["pow", "truediv", "floordiv"]
    return base_name in unsupported
|
||
|
|
||
|
|
||
|
# Test equivalence of scalars, numpy arrays with array ops
|
||
|
# -----------------------------------------------------------------------------
|
||
|
|
||
|
|
||
|
def test_array_scalar_like_equivalence(data, all_arithmetic_operators):
    """Array-op-scalar must match array-op-(array filled with that scalar).

    The scalar is tried both as given and wrapped in ``data.dtype.type``.
    For operators flagged by ``is_bool_not_implemented`` both forms must
    raise ``NotImplementedError`` instead.
    """
    arr, scalar = data
    op_name = all_arithmetic_operators
    op = tm.get_op_from_name(op_name)
    check_skip(arr, op_name)

    filled = pd.array([scalar] * len(arr), dtype=arr.dtype)
    expect_raise = is_bool_not_implemented(arr, op_name)

    # TODO also add len-1 array (np.array([scalar], dtype=data.dtype.numpy_dtype))
    candidates = (scalar, arr.dtype.type(scalar))
    for value in candidates:
        if not expect_raise:
            result = op(arr, value)
            expected = op(arr, filled)
            tm.assert_extension_array_equal(result, expected)
            continue

        msg = "operator '.*' not implemented for bool dtypes"
        for operand in (value, filled):
            with pytest.raises(NotImplementedError, match=msg):
                op(arr, operand)
|
||
|
|
||
|
|
||
|
def test_array_NA(data, all_arithmetic_operators):
    """Array-op-``pd.NA`` must match array-op-(all-NA array) and leave the mask alone.

    A copy of ``data._mask`` is taken up front and compared after every
    operation, including the branch where the op is expected to raise.
    """
    arr, _ = data
    op_name = all_arithmetic_operators
    op = tm.get_op_from_name(op_name)
    check_skip(arr, op_name)

    na_scalar = pd.NA
    na_array = pd.array([pd.NA] * len(arr), dtype=arr.dtype)

    mask_before = arr._mask.copy()

    if is_bool_not_implemented(arr, op_name):
        msg = "operator '.*' not implemented for bool dtypes"
        with pytest.raises(NotImplementedError, match=msg):
            op(arr, na_scalar)
        # GH#45421 check op doesn't alter data._mask inplace
        tm.assert_numpy_array_equal(mask_before, arr._mask)
    else:
        result = op(arr, na_scalar)
        # GH#45421 check op doesn't alter data._mask inplace
        tm.assert_numpy_array_equal(mask_before, arr._mask)

        expected = op(arr, na_array)
        tm.assert_numpy_array_equal(mask_before, arr._mask)

        tm.assert_extension_array_equal(result, expected)
|
||
|
|
||
|
|
||
|
def test_numpy_array_equivalence(data, all_arithmetic_operators):
    """Array-op-ndarray must match array-op-(masked array built from that ndarray).

    For operators flagged by ``is_bool_not_implemented`` both operand kinds
    must raise ``NotImplementedError`` instead.
    """
    arr, scalar = data
    op_name = all_arithmetic_operators
    op = tm.get_op_from_name(op_name)
    check_skip(arr, op_name)

    as_numpy = np.array([scalar] * len(arr), dtype=arr.dtype.numpy_dtype)
    as_masked = pd.array(as_numpy, dtype=arr.dtype)

    if not is_bool_not_implemented(arr, op_name):
        result = op(arr, as_numpy)
        expected = op(arr, as_masked)
        tm.assert_extension_array_equal(result, expected)
        return

    msg = "operator '.*' not implemented for bool dtypes"
    for operand in (as_numpy, as_masked):
        with pytest.raises(NotImplementedError, match=msg):
            op(arr, operand)
|
||
|
|
||
|
|
||
|
# Test equivalence with Series and DataFrame ops
|
||
|
# -----------------------------------------------------------------------------
|
||
|
|
||
|
|
||
|
def test_frame(data, all_arithmetic_operators):
    """DataFrame-op-scalar must match a frame built from array-op-scalar.

    For operators flagged by ``is_bool_not_implemented`` both the frame op and
    the array op must raise ``NotImplementedError`` instead.
    """
    arr, scalar = data
    op_name = all_arithmetic_operators
    op = tm.get_op_from_name(op_name)
    check_skip(arr, op_name)

    # DataFrame with scalar
    frame = pd.DataFrame({"A": arr})

    if not is_bool_not_implemented(arr, op_name):
        result = op(frame, scalar)
        expected = pd.DataFrame({"A": op(arr, scalar)})
        tm.assert_frame_equal(result, expected)
        return

    msg = "operator '.*' not implemented for bool dtypes"
    for left in (frame, arr):
        with pytest.raises(NotImplementedError, match=msg):
            op(left, scalar)
|
||
|
|
||
|
|
||
|
def test_series(data, all_arithmetic_operators):
    """Series-op-other must match a Series built from array-op-other.

    ``other`` is, in turn, the bare scalar and the scalar repeated as a numpy
    array, a masked array and a Series.  For operators flagged by
    ``is_bool_not_implemented`` every Series op must raise
    ``NotImplementedError`` instead.
    """
    arr, scalar = data
    op_name = all_arithmetic_operators
    op = tm.get_op_from_name(op_name)
    check_skip(arr, op_name)

    ser = pd.Series(arr)

    length = len(arr)
    operands = (
        scalar,
        np.array([scalar] * length, dtype=arr.dtype.numpy_dtype),
        pd.array([scalar] * length, dtype=arr.dtype),
        pd.Series([scalar] * length, dtype=arr.dtype),
    )

    for other in operands:
        if not is_bool_not_implemented(arr, op_name):
            result = op(ser, other)
            expected = pd.Series(op(arr, other))
            tm.assert_series_equal(result, expected)
            continue

        msg = "operator '.*' not implemented for bool dtypes"
        with pytest.raises(NotImplementedError, match=msg):
            op(ser, other)
|
||
|
|
||
|
|
||
|
# Test generic characteristics / errors
|
||
|
# -----------------------------------------------------------------------------
|
||
|
|
||
|
|
||
|
def test_error_invalid_object(data, all_arithmetic_operators):
    """Check how the array's operator method reacts to 2-d operands.

    Called directly on the array, the dunder method must return
    ``NotImplemented`` for a DataFrame operand and raise
    ``NotImplementedError`` for a 2-d numpy array.
    """
    arr, _ = data

    method = getattr(arr, all_arithmetic_operators)

    # 2d -> return NotImplemented
    frame_operand = pd.DataFrame({"A": arr})
    assert method(frame_operand) is NotImplemented

    size = len(arr)
    two_dim = np.arange(size).reshape(-1, size)
    with pytest.raises(
        NotImplementedError, match=r"can only perform ops with 1-d structures"
    ):
        method(two_dim)
|
||
|
|
||
|
|
||
|
def test_error_len_mismatch(data, all_arithmetic_operators):
    """Check that a list-like operand of non-matching length raises.

    The operand is one element shorter than ``data`` and is tried both as a
    list and as a numpy array, against the array itself and against a Series
    wrapping it.  The expected exception depends on the dtype and operator:

    * boolean data with ``sub`` / ``rsub``: ``TypeError``
    * operators flagged by ``is_bool_not_implemented``: ``NotImplementedError``
    * everything else: ``ValueError`` with a broadcasting message
    """
    # operating with a list-like with non-matching length raises
    data, scalar = data
    op = tm.get_op_from_name(all_arithmetic_operators)

    full = len(data)
    short = full - 1
    other = [scalar] * short

    err = ValueError
    # Build the expected shapes from the actual lengths instead of hard-coding
    # 3 and 4, so the pattern does not depend on the fixture arrays having
    # exactly four elements.  Either operand order is accepted.
    broadcast = "operands could not be broadcast together with shapes"
    msg = "|".join(
        [
            rf"{broadcast} \({short},\) \({full},\)",
            rf"{broadcast} \({full},\) \({short},\)",
        ]
    )
    if data.dtype.kind == "b" and all_arithmetic_operators.strip("_") in [
        "sub",
        "rsub",
    ]:
        err = TypeError
        msg = (
            r"numpy boolean subtract, the `\-` operator, is not supported, use "
            r"the bitwise_xor, the `\^` operator, or the logical_xor function instead"
        )
    elif is_bool_not_implemented(data, all_arithmetic_operators):
        msg = "operator '.*' not implemented for bool dtypes"
        err = NotImplementedError

    # The Series is loop-invariant; build it once.
    ser = pd.Series(data)
    # A distinct loop name avoids rebinding ``other`` while iterating over a
    # list built from it.
    for mismatched in (other, np.array(other)):
        with pytest.raises(err, match=msg):
            op(data, mismatched)

        with pytest.raises(err, match=msg):
            op(ser, mismatched)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("op", ["__neg__", "__abs__", "__invert__"])
def test_unary_op_does_not_propagate_mask(data, op):
    """A unary-op result must not change when the source Series is modified.

    The result is deep-copied, the first element of the source Series is then
    set to ``None``, and the result is compared with the copy.  For
    ``__invert__`` on float data the op must raise ``TypeError`` on the
    Series, on the array and on ``data._data``.
    """
    # https://github.com/pandas-dev/pandas/issues/39943
    arr, _ = data
    ser = pd.Series(arr)

    invert_on_float = op == "__invert__" and arr.dtype.kind == "f"
    if not invert_on_float:
        result = getattr(ser, op)()
        snapshot = result.copy(deep=True)
        ser[0] = None
        tm.assert_series_equal(result, snapshot)
        return

    # we follow numpy in raising
    msg = "ufunc 'invert' not supported for the input types"
    # The last target, ``arr._data``, checks that this is still the numpy
    # behavior.
    for target in (ser, arr, arr._data):
        with pytest.raises(TypeError, match=msg):
            getattr(target, op)()
|