Skip to content

Commit

Permalink
Use cache in make task info (#283)
Browse files Browse the repository at this point in the history
* add lru_cache in task_info

* add cache

* add test

* fix

* fix yapf
  • Loading branch information
ujiuji1259 authored May 6, 2022
1 parent 33b44b5 commit 207666f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 12 deletions.
30 changes: 19 additions & 11 deletions gokart/tree/task_info_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,20 @@ def _make_requires_info(requires):
raise TypeError(f'`requires` has unexpected type {type(requires)}. Must be `TaskOnKart`, `Iterarble[TaskOnKart]`, or `Dict[str, TaskOnKart]`')


def make_task_info_tree(task: TaskOnKart, ignore_task_names: Optional[List[str]] = None) -> TaskInfo:
def make_task_info_tree(task: TaskOnKart, ignore_task_names: Optional[List[str]] = None, cache: Optional[Dict[str, TaskInfo]] = None) -> TaskInfo:
with warnings.catch_warnings():
warnings.filterwarnings(action='ignore', message='Task .* without outputs has no custom complete() method')
is_task_complete = task.complete()

name = task.__class__.__name__
unique_id = task.make_unique_id()
output_paths = [t.path() for t in luigi.task.flatten(task.output())]

cache = {} if cache is None else cache
cache_id = f'{name}_{unique_id}_{is_task_complete}'
if cache_id in cache:
return cache[cache_id]

params = task.get_info(only_significant=True)
processing_time = task.get_processing_time()
if type(processing_time) == float:
Expand All @@ -76,16 +82,18 @@ def make_task_info_tree(task: TaskOnKart, ignore_task_names: Optional[List[str]]
children_task_infos: List[TaskInfo] = []
for child in children:
if ignore_task_names is None or child.__class__.__name__ not in ignore_task_names:
children_task_infos.append(make_task_info_tree(child, ignore_task_names=ignore_task_names))
return TaskInfo(name=name,
unique_id=unique_id,
output_paths=output_paths,
params=params,
processing_time=processing_time,
is_complete=is_complete,
task_log=task_log,
requires=requires,
children_task_infos=children_task_infos)
children_task_infos.append(make_task_info_tree(child, ignore_task_names=ignore_task_names, cache=cache))
task_info = TaskInfo(name=name,
unique_id=unique_id,
output_paths=output_paths,
params=params,
processing_time=processing_time,
is_complete=is_complete,
task_log=task_log,
requires=requires,
children_task_infos=children_task_infos)
cache[cache_id] = task_info
return task_info


def make_tree_info(task_info: TaskInfo, indent: str, last: bool, details: bool, abbr: bool, visited_tasks: Set[str]):
Expand Down
13 changes: 12 additions & 1 deletion test/tree/test_task_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from luigi.mock import MockFileSystem, MockTarget

import gokart
from gokart.tree.task_info import dump_task_info_table, dump_task_info_tree, make_task_info_as_tree_str
from gokart.tree.task_info import dump_task_info_table, dump_task_info_tree, make_task_info_as_tree_str, make_task_info_tree


class _SubTask(gokart.TaskOnKart):
Expand Down Expand Up @@ -125,6 +125,17 @@ def test_make_tree_info_not_compress_ignore_task(self):
└─-\(COMPLETE\) _DoubleLoadSubTask\[[a-z0-9]*\]$"""
self.assertRegex(tree, expected)

@patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs))
def test_make_tree_info_with_cache(self):
    """Check that two parameter-identical child tasks yield the very same TaskInfo object."""
    # Build the two dependency chains separately but with equal parameters.
    first_branch = _Task(param=1, sub=_SubTask(param=2))
    second_branch = _Task(param=1, sub=_SubTask(param=2))
    root_task = _DoubleLoadSubTask(sub1=first_branch, sub2=second_branch)

    root_info = make_task_info_tree(root_task)

    # Identity (not mere equality) is what is being verified here.
    left_info = root_info.children_task_infos[0]
    right_info = root_info.children_task_infos[1]
    self.assertTrue(left_info is right_info)


class _TaskInfoExampleTaskA(gokart.TaskOnKart):
task_namespace = __name__
Expand Down

0 comments on commit 207666f

Please sign in to comment.