Skip to content

Commit

Permalink
Load DataFrame cache with backward compatibility (#381)
Browse files Browse the repository at this point in the history
* test: add test for ReadableS3File

* refactor: use dill.load for ReadableS3File to unify the interface with other targets

* feat: load dill with pandas backward compatibility

---------

Co-authored-by: Shogo Ujiie <[email protected]>
  • Loading branch information
ujiuji1259 and Shogo Ujiie authored Jul 11, 2024
1 parent 2c0023b commit e26b4ff
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 3 deletions.
6 changes: 4 additions & 2 deletions gokart/file_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from luigi.format import TextFormat

from gokart.object_storage import ObjectStorage
from gokart.utils import load_dill_with_pandas_backward_compatibility

logger = getLogger(__name__)

Expand Down Expand Up @@ -82,8 +83,9 @@ def format(self):

def load(self, file):
    """Deserialize a dill-pickled object from *file*.

    Non-buffered-reader targets (e.g. S3's ReadableS3File) do not expose the
    'readline' method that dill.load requires, so their full contents are
    buffered into a BytesIO first. Buffered readers are wrapped in
    _ChunkedLargeFileReader to read large payloads in chunks.

    Args:
        file: A readable binary file-like object produced by a gokart target.

    Returns:
        The deserialized object.
    """
    if not ObjectStorage.is_buffered_reader(file):
        # we cannot use dill.load(file) because ReadableS3File does not have 'readline' method
        return load_dill_with_pandas_backward_compatibility(BytesIO(file.read()))
    return load_dill_with_pandas_backward_compatibility(_ChunkedLargeFileReader(file))

def dump(self, obj, file):
    """Serialize *obj* with dill (protocol 4) and write it to *file*."""
    serialized = dill.dumps(obj, protocol=4)
    self._write(serialized, file)
Expand Down
21 changes: 20 additions & 1 deletion gokart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,17 @@

import os
import sys
from typing import Iterable, TypeVar
from typing import Any, Iterable, Protocol, TypeVar, Union

import dill
import luigi
import pandas as pd


class FileLike(Protocol):
    """Structural type for readable binary file-like objects.

    Declares the minimal surface needed by
    load_dill_with_pandas_backward_compatibility: pd.read_pickle and
    dill.load both read from the stream, and dill.load additionally
    requires 'readline' (which e.g. ReadableS3File lacks, forcing callers
    to buffer such streams first).
    """

    # Read up to n bytes from the stream.
    def read(self, n: int) -> bytes: ...

    # Read a single line; required by dill.load's underlying Unpickler.
    def readline(self) -> bytes: ...


def add_config(file_path: str):
Expand Down Expand Up @@ -58,3 +66,14 @@ def flatten(targets: FlattenableItems[T]) -> list[T]:
for result in targets:
flat += flatten(result)
return flat


def load_dill_with_pandas_backward_compatibility(file: 'FileLike') -> Any:
    """Load a binary payload dumped by dill, tolerating pandas version skew.

    pd.read_pickle can load pickles written by older pandas versions (and any
    plain pickle), but it is unclear whether it can load every object dumped
    by dill, so dill.load is used as a fallback.

    Args:
        file: A readable binary file-like object positioned at the start of
            the pickled payload.

    Returns:
        The deserialized object.
    """
    try:
        return pd.read_pickle(file)
    except Exception:
        # pd.read_pickle may have consumed part of the stream before failing,
        # which would make dill.load start mid-payload. Rewind when the object
        # supports it (BytesIO does; best effort for other wrappers).
        seek = getattr(file, 'seek', None)
        if callable(seek):
            try:
                seek(0)
            except (OSError, ValueError):
                pass
        return dill.load(file)
21 changes: 21 additions & 0 deletions test/test_file_processor.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import os
import tempfile
import unittest
from typing import Callable

import boto3
import pandas as pd
from luigi import LocalTarget
from moto import mock_aws

from gokart.file_processor import CsvFileProcessor, FeatherFileProcessor, PickleFileProcessor
from gokart.object_storage import ObjectStorage


class TestCsvFileProcessor(unittest.TestCase):
Expand Down Expand Up @@ -115,6 +119,23 @@ def run(self) -> int:

self.assertEqual(loaded.run(), obj.run())

@mock_aws
def test_dump_and_load_with_readables3file(self):
    """PickleFileProcessor round-trips a value through a mocked S3 target."""
    s3 = boto3.resource('s3', region_name='us-east-1')
    s3.create_bucket(Bucket='test')
    file_path = os.path.join('s3://test/', 'test.pkl')

    expected = 'abc'
    processor = PickleFileProcessor()
    target = ObjectStorage.get_object_storage_target(file_path, processor.format())

    with target.open('w') as fobj:
        processor.dump(expected, fobj)
    with target.open('r') as fobj:
        actual = processor.load(fobj)

    self.assertEqual(actual, expected)


class TestFeatherFileProcessor(unittest.TestCase):
def test_feather_should_return_same_dataframe(self):
Expand Down

0 comments on commit e26b4ff

Please sign in to comment.