Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: Dump with Type Check #360

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions gokart/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from datetime import datetime
from glob import glob
from logging import getLogger
from typing import Any, Optional
from typing import Any, Callable, Optional

import luigi
import numpy as np
Expand Down Expand Up @@ -78,10 +78,12 @@ def __init__(
target: luigi.target.FileSystemTarget,
processor: FileProcessor,
task_lock_params: TaskLockParams,
validator: Callable[[Any], bool] = lambda x: True,
) -> None:
self._target = target
self._processor = processor
self._task_lock_params = task_lock_params
self._validator = validator

def _exists(self) -> bool:
return self._target.exists()
Expand All @@ -91,9 +93,14 @@ def _get_task_lock_params(self) -> TaskLockParams:

def _load(self) -> Any:
with self._target.open('r') as f:
return self._processor.load(f)
obj = self._processor.load(f)
self._validator(obj)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ujiuji1259 You can use retrun value of validator and let this process fail if return value is False


return obj

def _dump(self, obj) -> None:
self._validator(obj)

with self._target.open('w') as f:
self._processor.dump(obj, f)

Expand Down Expand Up @@ -216,12 +223,13 @@ def make_target(
processor: Optional[FileProcessor] = None,
task_lock_params: Optional[TaskLockParams] = None,
store_index_in_feather: bool = True,
validator: Callable[[Any], bool] = lambda x: True,
) -> TargetOnKart:
_task_lock_params = task_lock_params if task_lock_params is not None else make_task_lock_params(file_path=file_path, unique_id=unique_id)
file_path = _make_file_path(file_path, unique_id)
processor = processor or make_file_processor(file_path, store_index_in_feather=store_index_in_feather)
file_system_target = _make_file_system_target(file_path, processor=processor, store_index_in_feather=store_index_in_feather)
return SingleFileTarget(target=file_system_target, processor=processor, task_lock_params=_task_lock_params)
return SingleFileTarget(target=file_system_target, processor=processor, task_lock_params=_task_lock_params, validator=validator)


def make_model_target(
Expand Down
15 changes: 13 additions & 2 deletions gokart/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,13 @@ def clone(self, cls=None, **kwargs):

return cls(**new_k)

def make_target(self, relative_file_path: Optional[str] = None, use_unique_id: bool = True, processor: Optional[FileProcessor] = None) -> TargetOnKart:
def make_target(
self,
relative_file_path: Optional[str] = None,
use_unique_id: bool = True,
processor: Optional[FileProcessor] = None,
validator: Callable[[Any], bool] = lambda x: True,
) -> TargetOnKart:
formatted_relative_file_path = (
relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace('.', '/'), f'{type(self).__name__}.pkl')
)
Expand All @@ -210,7 +216,12 @@ def make_target(self, relative_file_path: Optional[str] = None, use_unique_id: b
)

return gokart.target.make_target(
file_path=file_path, unique_id=unique_id, processor=processor, task_lock_params=task_lock_params, store_index_in_feather=self.store_index_in_feather
file_path=file_path,
unique_id=unique_id,
processor=processor,
task_lock_params=task_lock_params,
store_index_in_feather=self.store_index_in_feather,
validator=validator,
)

def make_large_data_frame_target(self, relative_file_path: Optional[str] = None, use_unique_id: bool = True, max_byte=int(2**26)) -> TargetOnKart:
Expand Down
15 changes: 2 additions & 13 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

44 changes: 42 additions & 2 deletions test/test_target.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
import io
import os
import shutil
import tempfile
import unittest
from datetime import datetime
from unittest.mock import patch

import boto3
import luigi
import numpy as np
import pandas as pd
import pandera as pa
from matplotlib import pyplot
from moto import mock_aws

from gokart.file_processor import _ChunkedLargeFileReader
from gokart.target import make_model_target, make_target
from gokart.file_processor import _ChunkedLargeFileReader, make_file_processor
from gokart.target import SingleFileTarget, make_model_target, make_target


def _get_temporary_directory():
Expand Down Expand Up @@ -280,5 +283,42 @@ def test_model_target_on_s3(self):
self.assertEqual(loaded, obj)


class SingleFileTargetTest(unittest.TestCase):
def test_typed_target(self):
def validate_dataframe(x):
return isinstance(x, pd.DataFrame)

test_case = pd.DataFrame(dict(a=[1, 2]))

with tempfile.TemporaryDirectory() as temp_dir:
_task_lock_params = None
file_path = os.path.join(temp_dir, 'test.pkl')
processor = make_file_processor(file_path, store_index_in_feather=False)
file_system_target = luigi.LocalTarget(file_path, format=processor.format())
file_target = SingleFileTarget(target=file_system_target, processor=processor, task_lock_params=_task_lock_params, validator=validate_dataframe)

file_target.dump(test_case)
dumped_data = file_target.load()
self.assertIsInstance(dumped_data, self.DummyDataFrameSchema)

def test_invalid_typed_target(self):
def validate_int(x):
return isinstance(x, int)

test_case = pd.DataFrame(dict(a=['1', '2']))

with tempfile.TemporaryDirectory() as temp_dir:
_task_lock_params = None
file_path = os.path.join(temp_dir, 'test.csv')
processor = make_file_processor(file_path, store_index_in_feather=False)
file_system_target = luigi.LocalTarget(file_path, format=processor.format())
file_target = SingleFileTarget(
target=file_system_target, processor=processor, task_lock_params=_task_lock_params, expected_dataframe_type=validate_int
)

with self.assertRaises(pa.errors.SchemaError):
file_target.dump(test_case)


if __name__ == '__main__':
unittest.main()
Loading