Skip to content

Commit

Permalink
Use dill instead of pickle for processing .pkl files (#354)
Browse files Browse the repository at this point in the history
* use dill instead of pickle for file processor

* add unittest for pickle processor

* relock

* fix unittest

* fmt

* rename dill

* relock

* relock with poetry 1.8.2

* relock

* lock

* fmt

* relock

---------

Co-authored-by: yutaro-oguri <[email protected]>
  • Loading branch information
maronuu and yutaro-oguri authored Apr 23, 2024
1 parent 0a26811 commit 7a00e34
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 7 deletions.
8 changes: 4 additions & 4 deletions gokart/file_processor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os
import pickle
import xml.etree.ElementTree as ET
from abc import abstractmethod
from io import BytesIO
from logging import getLogger

import dill
import luigi
import luigi.contrib.s3
import luigi.format
Expand Down Expand Up @@ -82,11 +82,11 @@ def format(self):

def load(self, file):
if not ObjectStorage.is_buffered_reader(file):
return pickle.loads(file.read())
return pickle.load(_ChunkedLargeFileReader(file))
return dill.loads(file.read())
return dill.load(_ChunkedLargeFileReader(file))

def dump(self, obj, file):
self._write(pickle.dumps(obj, protocol=4), file)
self._write(dill.dumps(obj, protocol=4), file)

@staticmethod
def _write(buffer, file):
Expand Down
18 changes: 16 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ uritemplate = "*"
google-api-python-client = "*"
APScheduler = "*"
redis = "*"
dill = "*"
backoff = "^2.2.1"

[tool.poetry.group.dev.dependencies]
Expand Down
51 changes: 50 additions & 1 deletion test/test_file_processor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import tempfile
import unittest
from typing import Callable

import pandas as pd
from luigi import LocalTarget

from gokart.file_processor import CsvFileProcessor, FeatherFileProcessor
from gokart.file_processor import CsvFileProcessor, FeatherFileProcessor, PickleFileProcessor


class TestCsvFileProcessor(unittest.TestCase):
Expand Down Expand Up @@ -67,6 +68,54 @@ def test_load_csv_with_cp932(self):
pd.testing.assert_frame_equal(df, loaded_df)


class TestPickleFileProcessor(unittest.TestCase):
def test_dump_and_load_normal_obj(self):
var = 'abc'
processor = PickleFileProcessor()

with tempfile.TemporaryDirectory() as temp_dir:
temp_path = f'{temp_dir}/temp.pkl'
local_target = LocalTarget(path=temp_path, format=processor.format())
with local_target.open('w') as f:
processor.dump(var, f)
with local_target.open('r') as f:
loaded = processor.load(f)

self.assertEqual(loaded, var)

def test_dump_and_load_class(self):
import functools

def plus1(func: Callable[[], int]) -> Callable[[], int]:
@functools.wraps(func)
def wrapped() -> int:
ret = func()
return ret + 1

return wrapped

class A:
run: Callable[[], int]

def __init__(self) -> None:
self.run = plus1(self.run)

def run(self) -> int:
return 1

obj = A()
processor = PickleFileProcessor()
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = f'{temp_dir}/temp.pkl'
local_target = LocalTarget(path=temp_path, format=processor.format())
with local_target.open('w') as f:
processor.dump(obj, f)
with local_target.open('r') as f:
loaded = processor.load(f)

self.assertEqual(loaded.run(), obj.run())


class TestFeatherFileProcessor(unittest.TestCase):
def test_feather_should_return_same_dataframe(self):
df = pd.DataFrame({'a': [1]})
Expand Down

0 comments on commit 7a00e34

Please sign in to comment.