Basic statistics allow computation on sparse data and add test #2095
base: main
Changes from 7 commits
d8bf640
f1fb204
49804f9
353cf0c
7df2a3f
793ab23
f5210ea
168a897
8f50d85
ecd34e3
49f9ad7
32955d4
2bf51fb
f5eb5c0
@@ -17,6 +17,8 @@
import numpy as np
import pytest
from numpy.testing import assert_allclose
from scipy.sparse import csr_matrix
from sklearn.datasets import make_blobs

from daal4py.sklearn._utils import daal_check_version
from onedal.basic_statistics.tests.test_basic_statistics import (

@@ -28,6 +30,7 @@
from onedal.tests.utils._dataframes_support import (
    _convert_to_dataframe,
    get_dataframes_and_queues,
    get_queues,
)
from sklearnex.basic_statistics import BasicStatistics

@@ -178,6 +181,55 @@ def test_multiple_options_on_random_data(
    assert_allclose(gtr_sum, res_sum, atol=tol)


@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("row_count", [100, 1000])
@pytest.mark.parametrize("column_count", [10, 100])
@pytest.mark.parametrize("weighted", [True, False])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_multiple_options_on_random_sparse_data(
    queue, row_count, column_count, weighted, dtype
):
    seed = 77
    random_state = 42
    gen = np.random.default_rng(seed)
    X, _ = make_blobs(
        n_samples=row_count, n_features=column_count, random_state=random_state
    )
    density = 0.05
    # Draw the sparsity mask from the seeded generator so the test data is reproducible.
    X_sparse = csr_matrix(X * (gen.random(size=X.shape) < density))
    # Dense copy used to compute the reference statistics.
    X_dense = X_sparse.toarray()

    if weighted:
        weights = gen.uniform(low=-0.5, high=1.0, size=row_count)
        weights = weights.astype(dtype=dtype)
    basicstat = BasicStatistics(result_options=["mean", "max", "sum"])

According to the onedal tests, "max" needs to be excluded as it contains bugs:

@olegkkruglov please post a link to the ticket associated with this error (just to make sure it isn't lost)

I'm not sure if I have it. I didn't add this skip.

I observed that issue as well and temporarily removed "max" from the tests for this PR.


    if weighted:
        result = basicstat.fit(X_sparse, sample_weight=weights)
    else:
        result = basicstat.fit(X_sparse)

    res_mean, res_max, res_sum = result.mean, result.max, result.sum
    if weighted:
        weighted_data = np.diag(weights) @ X_dense
        gtr_mean, gtr_max, gtr_sum = (
            expected_mean(weighted_data),
            expected_max(weighted_data),
            expected_sum(weighted_data),
        )
    else:
        gtr_mean, gtr_max, gtr_sum = (
            expected_mean(X_dense),
            expected_max(X_dense),
            expected_sum(X_dense),
        )

    tol = 5e-4 if res_mean.dtype == np.float32 else 1e-7
    assert_allclose(gtr_mean, res_mean, atol=tol)
    assert_allclose(gtr_max, res_max, atol=tol)
    assert_allclose(gtr_sum, res_sum, atol=tol)


@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
@pytest.mark.parametrize("row_count", [100, 1000])
@pytest.mark.parametrize("column_count", [10, 100])

Please use _get_dataframes_and_queues instead.

Sparse data can't work with dataframes.
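
For reference, the ground-truth side of the sparse test can be sketched with NumPy alone. This is a minimal sketch, not the code from the PR: `make_masked_data` and `weighted_reference` are hypothetical helper names, the data comes from a normal distribution rather than `make_blobs`, and it assumes that `expected_mean`, `expected_max` and `expected_sum` are plain column-wise reductions (their definitions are not shown in this diff). It draws the sparsity mask from a seeded generator and applies the row weights by broadcasting, which gives the same result as `np.diag(weights) @ X_dense` without building a dense `row_count` x `row_count` matrix.

```python
import numpy as np


def make_masked_data(row_count, column_count, density, seed):
    """Return a dense array in which each entry is kept with probability `density`."""
    gen = np.random.default_rng(seed)
    X = gen.normal(size=(row_count, column_count))
    # Mask drawn from the same seeded generator, so repeated calls agree.
    mask = gen.random(size=X.shape) < density
    return X * mask


def weighted_reference(X_dense, weights):
    """Column-wise mean, max and sum of the row-weighted data (assumed reductions)."""
    # Broadcasting scales row i by weights[i], like np.diag(weights) @ X_dense.
    weighted = weights[:, np.newaxis] * X_dense
    return weighted.mean(axis=0), weighted.max(axis=0), weighted.sum(axis=0)


X_dense = make_masked_data(row_count=100, column_count=10, density=0.05, seed=77)
weights = np.random.default_rng(42).uniform(low=-0.5, high=1.0, size=100)
ref_mean, ref_max, ref_sum = weighted_reference(X_dense, weights)
```

Note that the mean here divides by the number of rows, not by the sum of the weights, which mirrors how the test applies `expected_mean` to the already weighted data.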