Skip to content

Commit

Permalink
Add tests for Dataset.rebuild_with_new_data()
Browse files Browse the repository at this point in the history
  • Loading branch information
chezou committed Aug 15, 2024
1 parent ea33c05 commit c3d1625
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 8 deletions.
6 changes: 3 additions & 3 deletions rectools/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,11 @@ def get_hot_item_features(self) -> tp.Optional[Features]:
return None
return self.item_features.take(range(self.n_hot_items))

def get_hot_users(self) -> InternalIdsArray:
def get_hot_users_internal(self) -> InternalIdsArray:
"""Return internal ids of hot users."""
return self.interactions.df[Columns.User].unique()

def get_hot_items(self) -> InternalIdsArray:
def get_hot_items_internal(self) -> InternalIdsArray:
"""Return internal ids of hot items."""
return self.interactions.df[Columns.Item].unique()

Expand Down Expand Up @@ -259,7 +259,7 @@ def get_raw_interactions(self, include_weight: bool = True, include_datetime: bo
"""
return self.interactions.to_external(self.user_id_map, self.item_id_map, include_weight, include_datetime)

def construct_new_datasets(
def rebuild_with_new_data(
self,
interactions_df: pd.DataFrame,
user_features_df: tp.Optional[pd.DataFrame] = None,
Expand Down
4 changes: 2 additions & 2 deletions rectools/models/implicit_als.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def _fit_partial(self, dataset: Dataset) -> None:
# Implicit ALS assumes to fit first on all data.
if not self.is_fitted:
raise NotFittedError(self.__class__.__name__)

Check warning on line 96 in rectools/models/implicit_als.py

View check run for this annotation

Codecov / codecov/patch

rectools/models/implicit_als.py#L96

Added line #L96 was not covered by tests
users = dataset.get_hot_users()
items = dataset.get_hot_items()
users = dataset.get_hot_users_internal()
items = dataset.get_hot_items_internal()

ui_csr = dataset.get_user_item_matrix(
include_weights=True, include_warm_users=True, include_warm_items=True
Expand Down
113 changes: 112 additions & 1 deletion tests/dataset/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,15 @@ def assert_dataset_equal_to_expected(
expected_item_features: tp.Optional[Features],
expected_user_id_map: tp.Optional[IdMap] = None,
expected_item_id_map: tp.Optional[IdMap] = None,
expected_interactions: tp.Optional[Interactions] = None,
) -> None:
expected_user_id_map = expected_user_id_map or self.expected_user_id_map
expected_item_id_map = expected_item_id_map or self.expected_item_id_map
expected_interactions = expected_interactions or self.expected_interactions

assert_id_map_equal(actual.user_id_map, expected_user_id_map)
assert_id_map_equal(actual.item_id_map, expected_item_id_map)
assert_interactions_set_equal(actual.interactions, self.expected_interactions)
assert_interactions_set_equal(actual.interactions, expected_interactions)
assert_feature_set_equal(actual.user_features, expected_user_features)
assert_feature_set_equal(actual.item_features, expected_item_features)

Expand Down Expand Up @@ -284,3 +286,112 @@ def test_get_raw_interactions(self, include_weight: bool, include_datetime: bool
if not include_datetime:
expected.drop(columns=Columns.Datetime, inplace=True)
pd.testing.assert_frame_equal(actual, expected)

def test_rebuild_with_new_data_without_feature(self) -> None:
    """Check that rebuilding a featureless dataset extends id maps and remaps new interactions."""
    base_dataset = Dataset.construct(self.interactions_df)

    incoming_interactions = pd.DataFrame(
        [
            ["u2", "i3", 5, "2021-09-03"],
            ["u4", "i1", 3, "2021-09-09"],
        ],
        columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime],
    )
    rebuilt = base_dataset.rebuild_with_new_data(incoming_interactions)

    # Ids unseen in the base data ("u4", "i3") are expected to be appended
    # after the already-known ids in the rebuilt maps.
    expected_user_id_map = IdMap.from_values(["u1", "u2", "u3", "u4"])
    expected_item_id_map = IdMap.from_values(["i1", "i2", "i5", "i3"])

    # New interactions expressed in internal ids of the extended maps.
    expected_interactions = Interactions(
        pd.DataFrame(
            [
                [1, 3, 5.0, datetime(2021, 9, 3)],
                [3, 0, 3.0, datetime(2021, 9, 9)],
            ],
            columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime],
        ),
    )

    self.assert_dataset_equal_to_expected(
        rebuilt,
        None,
        None,
        expected_user_id_map=expected_user_id_map,
        expected_item_id_map=expected_item_id_map,
        expected_interactions=expected_interactions,
    )

def test_rebuild_with_new_data_with_feature(self) -> None:
    """Check that rebuilding with feature frames merges warm entities from old and new data."""
    # Base features: "u4" / "i7" carry features without interactions (warm entities).
    user_features_df = pd.DataFrame(
        [
            ["u1", 77, 99],
            ["u2", 33, 55],
            ["u3", 22, 11],
            ["u4", 22, 11],  # Warm user
        ],
        columns=[Columns.User, "f1", "f2"],
    )
    item_features_df = pd.DataFrame(
        [
            ["i2", "f1", 3],
            ["i2", "f2", 20],
            ["i5", "f2", 20],
            ["i5", "f2", 30],
            ["i7", "f2", 70],  # Warm item
        ],
        columns=[Columns.Item, "feature", "value"],
    )
    base_dataset = Dataset.construct(
        self.interactions_df,
        user_features_df=user_features_df,
        make_dense_user_features=True,
        item_features_df=item_features_df,
        cat_item_features=["f2"],
    )

    # New data introduces one previously unseen user ("u5") and item ("i8").
    incoming_interactions = pd.DataFrame(
        [
            ["u2", "i8", 5, "2021-09-03"],  # Warm item in interactions
            ["u5", "i1", 3, "2021-09-09"],  # Warm user in interactions
        ],
        columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime],
    )
    updated_user_features_df = pd.DataFrame(
        [
            ["u1", 77, 99],
            ["u2", 33, 55],
            ["u3", 22, 11],
            ["u4", 22, 11],  # Warm user in old data
            ["u5", 55, 22],  # Warm user in new data
        ],
        columns=[Columns.User, "f1", "f2"],
    )
    updated_item_features_df = pd.DataFrame(
        [
            ["i2", "f1", 3],
            ["i2", "f2", 20],
            ["i5", "f2", 20],
            ["i5", "f2", 30],
            ["i7", "f2", 70],  # Warm item in old data
            ["i8", "f2", 70],  # Warm item in new data
        ],
        columns=[Columns.Item, "feature", "value"],
    )

    rebuilt = base_dataset.rebuild_with_new_data(
        incoming_interactions,
        user_features_df=updated_user_features_df,
        make_dense_user_features=True,
        item_features_df=updated_item_features_df,
        cat_item_features=["f2"],
    )

    # Maps keep the old ordering and append the ids first seen in new data.
    expected_user_id_map = IdMap.from_values(["u1", "u2", "u3", "u4", "u5"])
    expected_item_id_map = IdMap.from_values(["i1", "i2", "i5", "i7", "i8"])
    expected_interactions = Interactions(
        pd.DataFrame(
            [
                [1, 4, 5.0, datetime(2021, 9, 3)],
                [4, 0, 3.0, datetime(2021, 9, 9)],
            ],
            columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime],
        ),
    )

    # Expected features are built from the updated frames against the extended maps.
    expected_user_features = DenseFeatures.from_dataframe(
        updated_user_features_df, expected_user_id_map, Columns.User
    )
    expected_item_features = SparseFeatures.from_flatten(
        updated_item_features_df,
        expected_item_id_map,
        ["f2"],
        id_col=Columns.Item,
    )

    self.assert_dataset_equal_to_expected(
        rebuilt,
        expected_user_features,
        expected_item_features,
        expected_user_id_map,
        expected_item_id_map,
        expected_interactions,
    )
2 changes: 1 addition & 1 deletion tests/models/test_implicit_als.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def test_fit_partial(self, use_gpu: bool, dataset: Dataset) -> None:
new_interactions = pd.DataFrame(data, columns=Columns.UserItem)
new_interactions[Columns.Weight] = 1
new_interactions[Columns.Datetime] = "2021-09-10"
new_dataset = dataset.construct_new_datasets(new_interactions)
new_dataset = dataset.rebuild_with_new_data(new_interactions)
model.fit_partial(new_dataset)
actual = model.recommend(
users=[10, 150], # old user, new user
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_lightfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def test_fit_partial(self, dataset: Dataset) -> None:
new_interactions = pd.DataFrame(data, columns=Columns.UserItem)
new_interactions[Columns.Weight] = 1
new_interactions[Columns.Datetime] = "2021-09-10"
new_dataset = dataset.construct_new_datasets(interactions_df=new_interactions)
new_dataset = dataset.rebuild_with_new_data(interactions_df=new_interactions)
model.fit_partial(new_dataset)
actual = model.recommend(
users=np.array([10, 150]), # new user
Expand Down

0 comments on commit c3d1625

Please sign in to comment.