Skip to content

Commit

Permalink
Fix/warp-kos loss for LightFMWrapperModel (#175)
Browse files Browse the repository at this point in the history
- Fixed `NotImplementedError` for LightFM with warp-kos loss
  • Loading branch information
chezou authored Aug 13, 2024
1 parent 0f4034c commit f73d054
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- `display()` method in `MetricsApp` ([#169](https://github.com/MobileTeleSystems/RecTools/pull/169))

### Fixed
- Allow warp-kos loss for LightFMWrapperModel ([#175](https://github.com/MobileTeleSystems/RecTools/pull/175))

## [0.7.0] - 29.07.2024

### Added
Expand Down
3 changes: 2 additions & 1 deletion rectools/models/lightfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,13 @@ def _fit(self, dataset: Dataset) -> None: # type: ignore
ui_coo = dataset.get_user_item_matrix(include_weights=True).tocoo(copy=False)
user_features = self._prepare_features(dataset.get_hot_user_features(), dataset.n_hot_users)
item_features = self._prepare_features(dataset.get_hot_item_features(), dataset.n_hot_items)
sample_weight = None if self._model.loss == "warp-kos" else ui_coo

self.model.fit(
ui_coo,
user_features=user_features,
item_features=item_features,
sample_weight=ui_coo,
sample_weight=sample_weight,
epochs=self.n_epochs,
num_threads=self.n_threads,
verbose=self.verbose > 0,
Expand Down
10 changes: 10 additions & 0 deletions tests/models/test_lightfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,16 @@ def test_with_weights(self, interactions_df: pd.DataFrame) -> None:
actual,
)

def test_with_warp_kos(self, dataset: Dataset) -> None:
base_model = DeterministicLightFM(no_components=2, loss="warp-kos")
try:
LightFMWrapperModel(model=base_model, epochs=10).fit(dataset)
except NotImplementedError:
pytest.fail("Should not raise NotImplementedError")
except ValueError:
# LightFM raises ValueError with the dataset
pass

def test_get_vectors(self, dataset_with_features: Dataset) -> None:
base_model = LightFM(no_components=2, loss="logistic")
model = LightFMWrapperModel(model=base_model).fit(dataset_with_features)
Expand Down

0 comments on commit f73d054

Please sign in to comment.