Skip to content

Commit 42d6ad3

Browse files
authored
fix: return 0-D array for full reductions per Array API standard (#932)
1 parent 3589a7c commit 42d6ad3

6 files changed

Lines changed: 102 additions & 20 deletions

File tree

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,6 @@ docs/examples_ipynb/
8787
# Envs
8888
.pixi/
8989
.venv/
90+
*.pem
91+
*.db
92+
array-api-tests/

sparse/numba_backend/_common.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1344,7 +1344,7 @@ def _einsum_single(lhs, rhs, operand):
13441344

13451345
if lhs == rhs:
13461346
if not rhs:
1347-
# ensure scalar output
1347+
# full contraction — return 0-D array per the Array API standard
13481348
return operand.sum()
13491349
return operand
13501350

@@ -1390,8 +1390,9 @@ def _einsum_single(lhs, rhs, operand):
13901390
new_data = operand.data
13911391

13921392
if not rhs:
1393-
# scalar output - match numpy behaviour by not wrapping as array
1394-
return new_data.sum()
1393+
# full contraction — return 0-D COO array per the Array API standard
1394+
data = np.asarray(new_data.sum())
1395+
return COO.from_numpy(data)
13951396

13961397
return to_output_format(COO(new_coords, new_data, shape=new_shape, has_duplicates=True))
13971398

sparse/numba_backend/_coo/common.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
check_consistent_fill_value,
1515
check_zero_fill_value,
1616
is_unsigned_dtype,
17-
isscalar,
1817
normalize_axis,
1918
)
2019

@@ -416,6 +415,19 @@ def nanmean(x, axis=None, keepdims=False, dtype=None, out=None):
416415
return (num / den).astype(dtype if dtype is not None else x.dtype)
417416

418417

418+
def _contains_nan(ar):
    """Return whether a reduction result contains at least one NaN.

    Used by the NaN-skipping reductions to decide whether the
    "All-NaN slice encountered" warning has to be emitted.

    Parameters
    ----------
    ar : SparseArray or scalar
        The value to inspect. Anything that is not a ``SparseArray`` is
        handed to :func:`numpy.isnan` directly.

    Returns
    -------
    bool
        ``True`` if any explicitly stored value is NaN, or if the fill value
        is NaN while at least one element is implicit; ``False`` otherwise.
    """
    if not isinstance(ar, SparseArray):
        # Scalar (or dense) input: collapse to a single truth value.
        return bool(np.any(np.isnan(ar)))

    # Cheapest check first: only inexact dtypes can represent NaN at all.
    # ``np.inexact`` covers both real and complex floating types; testing
    # against ``np.floating`` alone would miss NaNs in complex results.
    if not np.issubdtype(ar.dtype, np.inexact):
        return False

    # The fill value only matters when some element is not stored explicitly
    # (this includes 0-D results, which have size 1 and nnz 0).
    if ar.nnz != ar.size and np.isnan(ar.fill_value):
        return True

    # Slowest check last: scan the explicitly stored values.
    return bool(np.isnan(ar.data).any())
429+
430+
419431
def nanmax(x, axis=None, keepdims=False, dtype=None, out=None):
420432
"""
421433
Maximize along the given axes, skipping `NaN` values. Uses all axes by default.
@@ -446,7 +458,7 @@ def nanmax(x, axis=None, keepdims=False, dtype=None, out=None):
446458

447459
ar = x.reduce(np.fmax, axis=axis, keepdims=keepdims, dtype=dtype)
448460

449-
if (isscalar(ar) and np.isnan(ar)) or np.isnan(ar.data).any():
461+
if _contains_nan(ar):
450462
warnings.warn("All-NaN slice encountered", RuntimeWarning, stacklevel=1)
451463

452464
return ar
@@ -482,7 +494,7 @@ def nanmin(x, axis=None, keepdims=False, dtype=None, out=None):
482494

483495
ar = x.reduce(np.fmin, axis=axis, keepdims=keepdims, dtype=dtype)
484496

485-
if (isscalar(ar) and np.isnan(ar)) or np.isnan(ar.data).any():
497+
if _contains_nan(ar):
486498
warnings.warn("All-NaN slice encountered", RuntimeWarning, stacklevel=1)
487499

488500
return ar
@@ -901,7 +913,7 @@ def diagonalize(a, axis=0):
901913
>>> a = sparse.random((3, 3, 3, 3, 3), density=0.3)
902914
>>> a_diag = sparse.diagonalize(a, axis=2)
903915
>>> (sparse.diagonal(a_diag, axis1=2, axis2=5) == a.transpose([0, 1, 3, 4, 2])).all()
904-
np.True_
916+
<COO: shape=(), dtype=bool, nnz=0, fill_value=True>
905917
906918
Returns
907919
-------

sparse/numba_backend/_sparse_array.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ def reduce(self, method, axis=(0,), keepdims=False, **kwargs):
400400
axis = (axis,)
401401
out = self._reduce_calc(method, axis, keepdims, **kwargs)
402402
if len(out) == 1:
403-
return out[0]
403+
return out[0] if isinstance(out[0], SparseArray) else type(self).from_numpy(np.array(out[0]))
404404
data, counts, axis, n_cols, arr_attrs = out
405405
result_fill_value = self.fill_value
406406
if reduce_super_ufunc is None:
@@ -422,7 +422,9 @@ def reduce(self, method, axis=(0,), keepdims=False, **kwargs):
422422
out = out.reshape(shape)
423423

424424
if out.ndim == 0:
425-
return out[()]
425+
# Return a 0-D array per the Array API standard.
426+
# The element value becomes the fill_value (nnz=0 is correct for 0-D).
427+
return type(self).from_numpy(out.todense())
426428

427429
return out
428430

@@ -689,7 +691,7 @@ def mean(self, axis=None, keepdims=False, dtype=None, out=None):
689691
mean along all axes.
690692
691693
>>> s.mean()
692-
np.float64(0.5)
694+
<COO: shape=(), dtype=float64, nnz=0, fill_value=0.5>
693695
"""
694696

695697
if axis is None:
@@ -709,10 +711,8 @@ def mean(self, axis=None, keepdims=False, dtype=None, out=None):
709711

710712
num = self.sum(axis=axis, keepdims=keepdims, dtype=inter_dtype)
711713

712-
if num.ndim:
713-
out = np.true_divide(num, den, casting="unsafe")
714-
return out.astype(dtype) if out.dtype != dtype else out
715-
return np.divide(num, den, dtype=dtype, out=out)
714+
out = np.true_divide(num, den, casting="unsafe")
715+
return out.astype(dtype) if out.dtype != dtype else out
716716

717717
def var(self, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
718718
"""
@@ -769,7 +769,7 @@ def var(self, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
769769
variance along all axes.
770770
771771
>>> s.var()
772-
np.float64(0.5)
772+
<COO: shape=(), dtype=float64, nnz=0, fill_value=0.5>
773773
"""
774774
axis = normalize_axis(axis, self.ndim)
775775

@@ -803,7 +803,7 @@ def var(self, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
803803

804804
ret = ret[...]
805805
np.divide(ret, rcount, out=ret, casting="unsafe")
806-
return ret[()]
806+
return ret
807807

808808
def std(self, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
809809
"""

sparse/numba_backend/tests/test_array_function.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,69 @@ def test_asarray(self, input, dtype, format):
136136

137137
if isinstance(input, SparseArray):
138138
assert sparse.asarray(input).__class__ is input.__class__
139+
140+
141+
class TestArrayAPIReductions:
    """
    Checks that reducing over every axis yields a zero-dimensional sparse
    array (as the Array API standard requires) rather than a NumPy scalar,
    while axis-wise and ``keepdims`` reductions keep their usual shapes.

    See: https://github.com/pydata/sparse/issues/921
    """

    @pytest.mark.parametrize("format", ["coo", "gcxs"])
    @pytest.mark.parametrize(
        "fn, expected",
        [
            (sparse.sum, 2.0),
            (sparse.max, 1.0),
            (sparse.min, 0.0),
            (sparse.prod, 0.0),
            (sparse.mean, 0.5),
        ],
    )
    def test_full_reduction_returns_0d_array(self, fn, expected, format):
        """Each numeric reduction of a 2x2 identity is 0-D, sparse and has the expected value."""
        identity = sparse.asarray(np.eye(2), format=format)
        result = fn(identity)

        assert result.ndim == 0, f"{fn.__name__}() over entire array returned ndim={result.ndim}, expected 0-D array"
        assert isinstance(result, SparseArray), (
            f"{fn.__name__}() returned {type(result).__name__}, expected a SparseArray"
        )
        assert abs(float(result) - expected) < 1e-9, f"{fn.__name__}() returned {float(result)}, expected {expected}"

    @pytest.mark.parametrize("fn", [sparse.any, sparse.all])
    def test_boolean_reduction_returns_0d_array(self, fn):
        """Boolean reductions over the whole array are 0-D sparse arrays as well."""
        identity = sparse.asarray(np.eye(2), format="coo")
        result = fn(identity)

        assert result.ndim == 0, f"{fn.__name__}() returned ndim={result.ndim}, expected 0-D array"
        assert isinstance(result, SparseArray), (
            f"{fn.__name__}() returned {type(result).__name__}, expected a SparseArray"
        )

    def test_partial_reduction_still_returns_nd_array(self):
        """Axis-specific reductions must still return N-D sparse arrays."""
        identity = sparse.asarray(np.eye(2), format="coo")

        # Same checks for each single axis, in the order axis 0 then axis 1.
        for axis in (0, 1):
            reduced = sparse.sum(identity, axis=axis)
            assert reduced.shape == (2,), f"Expected shape (2,), got {reduced.shape}"
            assert isinstance(reduced, SparseArray)

    def test_keepdims_full_reduction(self):
        """keepdims=True must preserve all dimensions as size-1."""
        identity = sparse.asarray(np.eye(2), format="coo")
        kept = sparse.sum(identity, keepdims=True)

        assert kept.shape == (1, 1), f"Expected shape (1, 1), got {kept.shape}"
        assert isinstance(kept, SparseArray)

    @pytest.mark.parametrize("format", ["coo", "gcxs"])
    def test_1d_full_reduction_returns_0d_array(self, format):
        """1-D input fully reduced must also give a 0-D array."""
        vector = sparse.asarray(np.array([1.0, 2.0, 3.0]), format=format)
        total = sparse.sum(vector)

        assert total.ndim == 0, f"Expected 0-D array, got ndim={total.ndim}"
        assert isinstance(total, SparseArray)
        assert abs(float(total) - 6.0) < 1e-9

sparse/numba_backend/tests/test_einsum.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ def test_einsum(subscripts, density):
9292
numpy_out = np.einsum(subscripts, *(s.todense() for s in arrays))
9393

9494
if not numpy_out.shape:
95-
# scalar output
96-
assert np.allclose(numpy_out, sparse_out)
95+
# scalar output — sparse_out is a 0-D COO per the Array API standard
96+
assert np.allclose(numpy_out, sparse_out.todense())
9797
else:
9898
# array output
9999
assert np.allclose(numpy_out, sparse_out.todense())
@@ -108,8 +108,8 @@ def test_einsum_nosubscript(input, density):
108108
numpy_out = np.einsum(*(s.todense() for s in arrays), *input)
109109

110110
if not numpy_out.shape:
111-
# scalar output
112-
assert np.allclose(numpy_out, sparse_out)
111+
# scalar output — sparse_out is a 0-D COO per the Array API standard
112+
assert np.allclose(numpy_out, sparse_out.todense())
113113
else:
114114
# array output
115115
assert np.allclose(numpy_out, sparse_out.todense())

0 commit comments

Comments
 (0)