Skip to content

Commit

Permalink
Add parameter alias to inherits_config_params (#193)
Browse files Browse the repository at this point in the history
* add parameter alias

* fix isort

* add docs

* add docstring

* fix typo

* fix typo

* change parameter name
  • Loading branch information
mski-iksm authored Apr 23, 2021
1 parent 78d8431 commit 5eb59d8
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 7 deletions.
16 changes: 16 additions & 0 deletions docs/setting_task_parameters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,19 @@ This is useful when multiple tasks has the same parameter, since parameter setti

Note that parameters which exist in both ``MasterConfig`` and ``SomeTask`` will be inherited.
In the above example, ``param2`` will not be available in ``SomeTask``, since ``SomeTask`` does not have ``param2`` parameter.

.. code:: python
class MasterConfig(luigi.Config):
param: str = luigi.Parameter()
param2: str = luigi.Parameter()
@inherits_config_params(MasterConfig, parameter_alias={'param2': 'param3'})
class SomeTask(gokart.TaskOnKart):
param3: str = luigi.Parameter()
You may also set a parameter name alias by setting ``parameter_alias``.
``parameter_alias`` must be a dictionary of inheriting task's parameter name as keys and decorating task's parameter names as values.

In the above example, ``SomeTask.param3`` will be set to same value as ``MasterConfig.param2``.
25 changes: 18 additions & 7 deletions gokart/config_params.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,34 @@
from typing import Dict, Optional

import luigi

import gokart


class inherits_config_params:
def __init__(self, config_class: luigi.Config):
self.config_class: luigi.Config = config_class
def __init__(self, config_class: luigi.Config, parameter_alias: Optional[Dict[str, str]] = None):
"""
Decorates task to inherit parameter value of `config_class`.
def __call__(self, task: gokart.TaskOnKart):
config_class = self.config_class
* config_class: Inherit parameter value of this task to decorated task. Only parameter values exist in both tasks are inherited.
* parameter_alias: Dictionary to map paramter names between config_class task and decorated task.
key: config_class's parameter name. value: decorated task's parameter name.
"""

self._config_class: luigi.Config = config_class
self._parameter_alias: Dict[str, str] = parameter_alias if parameter_alias is not None else {}

def __call__(self, task: gokart.TaskOnKart):
# wrap task to prevent task name from being changed
@luigi.task._task_wraps(task)
class Wrapped(task):
@classmethod
def get_param_values(cls, params, args, kwargs):
for k, v in config_class().param_kwargs.items():
if hasattr(cls, k) and k not in kwargs:
kwargs[k] = v
for param_key, param_value in self._config_class().param_kwargs.items():
task_param_key = self._parameter_alias.get(param_key, param_key)

if hasattr(cls, task_param_key) and task_param_key not in kwargs:
kwargs[task_param_key] = param_value
return super(Wrapped, cls).get_param_values(params, args, kwargs)

return Wrapped
10 changes: 10 additions & 0 deletions test/test_config_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ class Inherited(gokart.TaskOnKart):
param_b = luigi.Parameter(default='overrided')


@inherits_config_params(ConfigClass, parameter_alias={'param_a': 'param_d'})
class Inherited2(gokart.TaskOnKart):
param_c = luigi.Parameter()
param_d = luigi.Parameter()


class ChildTask(Inherited):
pass

Expand Down Expand Up @@ -57,6 +63,10 @@ def test_inherited_params(self):
with self.assertRaises(AttributeError):
in_parse(['Inherited'], lambda task: task.param_c)

# test parameter name alias
in_parse(['Inherited2'], lambda task: self.assertEqual(task.param_c, 'config c'))
in_parse(['Inherited2'], lambda task: self.assertEqual(task.param_d, 'config a'))

def test_child_task(self):
in_parse(['ChildTask'], lambda task: self.assertEqual(task.param_a, 'config a'))
in_parse(['ChildTask'], lambda task: self.assertEqual(task.param_b, 'config b'))
Expand Down

0 comments on commit 5eb59d8

Please sign in to comment.