Skip to content

Commit

Permalink
Added should_dump_supplementary_log_files option (#291)
Browse files Browse the repository at this point in the history
* add should_dump_supplementary_log_files option

* add docs
  • Loading branch information
mski-iksm authored Jul 5, 2022
1 parent 40299e8 commit c5376e6
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 7 deletions.
6 changes: 6 additions & 0 deletions docs/task_on_kart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,9 @@ TaskOnKart.fail_on_empty_dump
Please refer to :doc:`for_pandas`.


TaskOnKart.should_dump_supplementary_log_files
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Whether to dump supplementary files (task_log, random_seed, task_params, processing_time, module_versions) or not. Default is True.

Note that when set to False, task_info functions (e.g. gokart.tree.task_info.make_task_info_as_tree_str()) cannot be used.
24 changes: 17 additions & 7 deletions gokart/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ class TaskOnKart(luigi.Task):
significant=False)

# Non-significant switch: cache the unique id during runtime.
cache_unique_id: bool = ExplicitBoolParameter(default=True, description='Cache unique id during runtime', significant=False)
# Non-significant switch for writing the supplementary log files.
# The description is built from adjacent string literals instead of a
# backslash continuation inside a single literal, so the help text can never
# absorb the leading whitespace of the continuation line.
should_dump_supplementary_log_files: bool = ExplicitBoolParameter(
    default=True,
    description='Whether to dump supplementary files (task_log, random_seed, task_params, processing_time, module_versions) or not. '
    'Note that when set to False, task_info functions (e.g. gokart.tree.task_info.make_task_info_as_tree_str()) cannot be used.',
    significant=False)

def __init__(self, *args, **kwargs):
self._add_configuration(kwargs, 'TaskOnKart')
Expand Down Expand Up @@ -353,7 +358,8 @@ def get_task_log(self) -> Dict:
@luigi.Task.event_handler(luigi.Event.SUCCESS)
def _dump_task_log(self):
    """Record the output paths in ``task_log`` and optionally persist the log.

    Registered as a handler for ``luigi.Event.SUCCESS``. The ``file_path``
    entry of ``self.task_log`` is always refreshed from the flattened
    outputs; the log itself is dumped to the task-log target only when
    ``should_dump_supplementary_log_files`` is true.
    """
    self.task_log['file_path'] = [target.path() for target in luigi.task.flatten(self.output())]
    # Dump exactly once and only behind the flag: an unconditional dump here
    # would write the file even when the option is set to False.
    if self.should_dump_supplementary_log_files:
        self.dump(self.task_log, self._get_task_log_target())

def _get_task_params_target(self):
    """Return the target built by ``make_target`` for ``log/task_params/<class name>.pkl``."""
    params_target = self.make_target(f'log/task_params/{type(self).__name__}.pkl')
    return params_target
Expand All @@ -366,9 +372,10 @@ def get_task_params(self) -> Dict:

@luigi.Task.event_handler(luigi.Event.START)
def _set_random_seed(self):
    """Set the task's random seed and optionally record it.

    Registered as a handler for ``luigi.Event.START``. The seed obtained
    from ``_get_random_seed`` is always passed to ``try_set_seed`` together
    with ``fix_random_seed_methods``; the seed and the value returned by
    ``try_set_seed`` are dumped to the random-seed target only when
    ``should_dump_supplementary_log_files`` is true.
    """
    # Seeding is done once and regardless of the logging option, so that
    # switching off supplementary files does not also switch off seed fixing.
    random_seed = self._get_random_seed()
    seed_methods = self.try_set_seed(self.fix_random_seed_methods, random_seed)
    # Only the record of the seed is optional.
    if self.should_dump_supplementary_log_files:
        self.dump({'seed': random_seed, 'seed_methods': seed_methods}, self._get_random_seeds_target())

def _get_random_seeds_target(self):
    """Return the target built by ``make_target`` for ``log/random_seed/<class name>.pkl``."""
    seed_target = self.make_target(f'log/random_seed/{type(self).__name__}.pkl')
    return seed_target
Expand Down Expand Up @@ -398,7 +405,8 @@ def _get_random_seed(self):

@luigi.Task.event_handler(luigi.Event.START)
def _dump_task_params(self):
    """Optionally persist the task's significant parameters.

    Registered as a handler for ``luigi.Event.START``. When
    ``should_dump_supplementary_log_files`` is true, the result of
    ``to_str_params(only_significant=True)`` is dumped to the task-params
    target; otherwise nothing is written.
    """
    # Dump exactly once and only behind the flag: an unconditional dump here
    # would write the file even when the option is set to False.
    if self.should_dump_supplementary_log_files:
        self.dump(self.to_str_params(only_significant=True), self._get_task_params_target())

def _get_processing_time_target(self):
    """Return the target built by ``make_target`` for ``log/processing_time/<class name>.pkl``."""
    time_target = self.make_target(f'log/processing_time/{type(self).__name__}.pkl')
    return time_target
Expand All @@ -411,7 +419,8 @@ def get_processing_time(self) -> str:

@luigi.Task.event_handler(luigi.Event.PROCESSING_TIME)
def _dump_processing_time(self, processing_time):
    """Optionally persist the processing time reported for this task.

    Registered as a handler for ``luigi.Event.PROCESSING_TIME``.

    Args:
        processing_time: The value delivered with the event; it is dumped
            unchanged to the processing-time target when
            ``should_dump_supplementary_log_files`` is true.
    """
    # Dump exactly once and only behind the flag: an unconditional dump here
    # would write the file even when the option is set to False.
    if self.should_dump_supplementary_log_files:
        self.dump(processing_time, self._get_processing_time_target())

@classmethod
def restore(cls, unique_id):
Expand All @@ -424,7 +433,8 @@ def _log_unique_id(self, exception):

@luigi.Task.event_handler(luigi.Event.START)
def _dump_module_versions(self):
    """Optionally persist the module-version information for this task.

    Registered as a handler for ``luigi.Event.START``. When
    ``should_dump_supplementary_log_files`` is true, the result of
    ``_get_module_versions`` is dumped to the module-versions target;
    otherwise nothing is written.
    """
    # Dump exactly once and only behind the flag: an unconditional dump here
    # would write the file even when the option is set to False.
    if self.should_dump_supplementary_log_files:
        self.dump(self._get_module_versions(), self._get_module_versions_target())

def _get_module_versions_target(self):
    """Return the target built by ``make_target`` for ``log/module_versions/<class name>.txt``."""
    versions_target = self.make_target(f'log/module_versions/{type(self).__name__}.txt')
    return versions_target
Expand Down

0 comments on commit c5376e6

Please sign in to comment.