diff --git a/gokart/file_processor.py b/gokart/file_processor.py index 56510b7f..c52fa105 100644 --- a/gokart/file_processor.py +++ b/gokart/file_processor.py @@ -1,10 +1,10 @@ import os -import pickle import xml.etree.ElementTree as ET from abc import abstractmethod from io import BytesIO from logging import getLogger +import dill import luigi import luigi.contrib.s3 import luigi.format @@ -82,11 +82,11 @@ def format(self): def load(self, file): if not ObjectStorage.is_buffered_reader(file): - return pickle.loads(file.read()) - return pickle.load(_ChunkedLargeFileReader(file)) + return dill.loads(file.read()) + return dill.load(_ChunkedLargeFileReader(file)) def dump(self, obj, file): - self._write(pickle.dumps(obj, protocol=4), file) + self._write(dill.dumps(obj, protocol=4), file) @staticmethod def _write(buffer, file): diff --git a/poetry.lock b/poetry.lock index eb1dc554..9a4e3a73 100644 --- a/poetry.lock +++ b/poetry.lock @@ -494,6 +494,21 @@ files = [ docs = ["ipython", "matplotlib", "numpydoc", "sphinx"] tests = ["pytest", "pytest-cov", "pytest-xdist"] +[[package]] +name = "dill" +version = "0.3.8" +description = "serialize all of Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7"}, + {file = "dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca"}, +] + +[package.extras] +graph = ["objgraph (>=1.7.2)"] +profile = ["gprof2dot (>=2022.7.29)"] + [[package]] name = "distlib" version = "0.3.8" @@ -1694,7 +1709,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -2190,4 +2204,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.9,<4" -content-hash = "7f41cde6c27e1b7ab1c8b75c99d5c0665cd4b3bbdcb62e1470d594213a88f70d" +content-hash = "ff22f83a344e27ef343493c93d88cde98170f6f03e525b47a6d635fc3a076581" diff --git a/pyproject.toml b/pyproject.toml index 1b7db4ec..1fcd37b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ uritemplate = "*" google-api-python-client = "*" APScheduler = "*" redis = "*" +dill = "*" backoff = "^2.2.1" [tool.poetry.group.dev.dependencies] diff --git a/test/test_file_processor.py b/test/test_file_processor.py index b2aad465..f5f1640d 100644 --- a/test/test_file_processor.py +++ b/test/test_file_processor.py @@ -1,10 +1,11 @@ import tempfile import unittest +from typing import Callable import pandas as pd from luigi import LocalTarget -from gokart.file_processor import CsvFileProcessor, FeatherFileProcessor +from gokart.file_processor import CsvFileProcessor, FeatherFileProcessor, PickleFileProcessor class TestCsvFileProcessor(unittest.TestCase): @@ -67,6 +68,54 @@ def test_load_csv_with_cp932(self): pd.testing.assert_frame_equal(df, loaded_df) +class TestPickleFileProcessor(unittest.TestCase): + def test_dump_and_load_normal_obj(self): + var = 'abc' + processor = PickleFileProcessor() + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.pkl' + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(var, f) + with local_target.open('r') as f: + loaded = processor.load(f) + + self.assertEqual(loaded, var) + + def test_dump_and_load_class(self): + import functools + + def plus1(func: Callable[[], int]) -> Callable[[], int]: + @functools.wraps(func) + def wrapped() -> int: + ret = func() + return ret + 1 + + return wrapped + + class A: + run: Callable[[], int] + + def __init__(self) -> None: + self.run = plus1(self.run) + + def run(self) -> int: + return 1 + + obj = A() + processor = PickleFileProcessor() + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.pkl' + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(obj, f) + with local_target.open('r') as f: + loaded = processor.load(f) + + self.assertEqual(loaded.run(), obj.run()) + + class TestFeatherFileProcessor(unittest.TestCase): def test_feather_should_return_same_dataframe(self): df = pd.DataFrame({'a': [1]})