From fd260776d54487652f3ce5279aed19d8edc1433f Mon Sep 17 00:00:00 2001
From: Yuchen Fang
Date: Tue, 19 Jul 2022 11:07:22 +0800
Subject: [PATCH 01/13] minor

---
 qlib/rl/utils/log.py | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/qlib/rl/utils/log.py b/qlib/rl/utils/log.py
index 409a48a768..0faf952f29 100644
--- a/qlib/rl/utils/log.py
+++ b/qlib/rl/utils/log.py
@@ -510,6 +510,34 @@ def on_env_all_done(self) -> None:
 class PickleWriter(LogWriter):
     """Dump logs to pickle files."""
 
+    SUPPORTED_TYPES = Any
+
+    all_records: list[dict[str, Any]]
+
+    def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC):
+        super().__init__(loglevel)
+        self.output_dir = output_dir
+        self.output_dir.mkdir(exist_ok=True, parents=True)
+
+    def clear(self) -> None:
+        super().clear()
+        self.all_records = []
+
+    def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
+        # FIXME Same as ConsoleLogger, needs a refactor to eliminate code-dup
+        episode_wise_contents: dict[str, list] = defaultdict(list)
+
+        for step_contents in contents:
+            for name, value in step_contents.items():
+                # pickle can serialize objects of any type, so no isinstance() filtering is needed here
+                episode_wise_contents[name].append(value)
+
+        self.all_records.append(episode_wise_contents)
+
+    def on_env_all_done(self) -> None:
+        # FIXME: this is temporary
+        pd.DataFrame.from_records(self.all_records).to_pickle(self.output_dir / "result.pkl")
+
 
 class TensorboardWriter(LogWriter):
     """Write logs to event files that can be visualized with tensorboard."""

From a99910e21e08543aba5d29cf686b038863bb00cd Mon Sep 17 00:00:00 2001
From: Yuchen Fang
Date: Mon, 29 Aug 2022 10:50:28 +0800
Subject: [PATCH 02/13] Add opd interpreter

---
 qlib/rl/data/pickle_styled.py          |  11 +++
 qlib/rl/order_execution/interpreter.py | 103 +++++++++++++++++++++++++
 2 files changed, 114 insertions(+)

diff --git a/qlib/rl/data/pickle_styled.py b/qlib/rl/data/pickle_styled.py
index 43fe9dd5ad..c020a139b1 100644
--- a/qlib/rl/data/pickle_styled.py
+++ b/qlib/rl/data/pickle_styled.py
@@ -258,6 +258,17 @@ def load_intraday_processed_data(
     return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
 
 
+class TeacherActionData:
+    teacher_action: pd.DataFrame
+
+    def __init__(self, teacher_action_file: Path, stock_id: str, date: pd.Timestamp) -> None:  # type: ignore
+        self.teacher_action = pd.read_pickle(teacher_action_file).loc[pd.IndexSlice[stock_id, date]]  # type: ignore
+
+
+def load_teacher_action_data(teacher_action_file: Path, stock_id: str, date: pd.Timestamp) -> TeacherActionData:  # type: ignore
+    return TeacherActionData(teacher_action_file, stock_id, date)
+
+
 def load_orders(
     order_path: Path,
     start_time: pd.Timestamp = None,
diff --git a/qlib/rl/order_execution/interpreter.py b/qlib/rl/order_execution/interpreter.py
index 089fc553cf..a7d42dc862 100644
--- a/qlib/rl/order_execution/interpreter.py
+++ b/qlib/rl/order_execution/interpreter.py
@@ -52,6 +52,10 @@ class FullHistoryObs(TypedDict):
     position_history: Any
 
 
+class OPDObs(FullHistoryObs):
+    teacher_action: Any
+
+
 class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
     """The observation of all the history, including today (until this moment), and yesterday.
@@ -132,6 +136,105 @@ def _mask_future_info(arr: pd.DataFrame, current: pd.Timestamp) -> pd.DataFrame: return arr +class OracleObsInterpreter(FullHistoryStateInterpreter): + def interpret(self, state: SAOEState) -> FullHistoryObs: + processed = pickle_styled.load_intraday_processed_data( + self.data_dir, + state.order.stock_id, + pd.Timestamp(state.order.start_time.date()), + self.data_dim, + state.ticks_index, + ) + + position_history = np.full(self.max_step + 1, 0.0, dtype=np.float32) + position_history[0] = state.order.amount + position_history[1 : len(state.history_steps) + 1] = state.history_steps["position"].to_numpy() + + assert self.env is not None + + return cast( + FullHistoryObs, + canonicalize( + { + "data_processed": processed.today, + "data_processed_prev": processed.yesterday, + "acquiring": state.order.direction == state.order.BUY, + "cur_tick": min(int(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1), + "cur_step": min(self.env.status["cur_step"], self.max_step - 1), + "num_step": self.max_step, + "target": state.order.amount, + "position": state.position, + "position_history": position_history[: self.max_step], + }, + ), + ) + + +class OPDObsInterpreter(FullHistoryStateInterpreter): + def __init__( + self, data_dir: Path, max_step: int, data_ticks: int, data_dim: int, teacher_action_file: Path + ) -> None: + super().__init__(data_dir, max_step, data_ticks, data_dim) + self.teacher_action_file = teacher_action_file + + def interpret(self, state: SAOEState) -> OPDObs: + processed = pickle_styled.load_intraday_processed_data( + self.data_dir, + state.order.stock_id, + pd.Timestamp(state.order.start_time.date()), + self.data_dim, + state.ticks_index, + ) + + position_history = np.full(self.max_step + 1, 0.0, dtype=np.float32) + position_history[0] = state.order.amount + position_history[1 : len(state.history_steps) + 1] = state.history_steps["position"].to_numpy() + + teacher_action = pickle_styled.load_teacher_action_data( + self.teacher_action_file, state.order.stock_id, pd.Timestamp(state.order.start_time.date()) + ).teacher_action + + assert self.env is not None + + # The min, slice here are to make sure that indices fit into the range, + # even after the final step of the simulator (in the done step), + # to make network in policy happy. 
+ return cast( + OPDObs, + canonicalize( + { + "data_processed": self._mask_future_info(processed.today, state.cur_time), + "data_processed_prev": processed.yesterday, + "acquiring": state.order.direction == state.order.BUY, + "cur_tick": min(int(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1), + "cur_step": min(self.env.status["cur_step"], self.max_step - 1), + "num_step": self.max_step, + "target": state.order.amount, + "position": state.position, + "position_history": position_history[: self.max_step], + "teacher_action": teacher_action.values[self.env.status["cur_step"]], + }, + ), + ) + + @property + def observation_space(self) -> spaces.Dict: + space = { + "data_processed": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)), + "data_processed_prev": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)), + "acquiring": spaces.Discrete(2), + "cur_tick": spaces.Box(0, self.data_ticks - 1, shape=(), dtype=np.int32), + "cur_step": spaces.Box(0, self.max_step - 1, shape=(), dtype=np.int32), + # TODO: support arbitrary length index + "num_step": spaces.Box(self.max_step, self.max_step, shape=(), dtype=np.int32), + "target": spaces.Box(-EPS, np.inf, shape=()), + "position": spaces.Box(-EPS, np.inf, shape=()), + "position_history": spaces.Box(-EPS, np.inf, shape=(self.max_step,)), + "teacher_action": spaces.Box(-EPS, np.inf, shape=()), + } + return spaces.Dict(space) + + class CurrentStateObs(TypedDict): acquiring: bool cur_step: int From c50ba2ff6b0e29ed38b18cf9b085c4a482125d6d Mon Sep 17 00:00:00 2001 From: Yuchen Fang Date: Mon, 26 Sep 2022 13:17:31 +0800 Subject: [PATCH 03/13] log --- qlib/rl/utils/log.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/qlib/rl/utils/log.py b/qlib/rl/utils/log.py index bf50762d9f..dc12c10396 100644 --- a/qlib/rl/utils/log.py +++ b/qlib/rl/utils/log.py @@ -543,6 +543,33 @@ def on_env_all_done(self) -> None: pd.DataFrame.from_records(self.all_records).to_pickle(self.output_dir / "result.pkl") +class ActionWriter(LogWriter): + """Dump policy actions to pickle files""" + + all_records: dict[str, list[Any]] + + def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None: + super().__init__(loglevel) + self.output_dir = output_dir + self.output_dir.mkdir(exist_ok=True) + + def clear(self) -> None: + super().clear() + self.all_records = defaultdict(list) + + def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None: + + for step_index, step_contents in enumerate(contents): + for name, value in step_contents.items(): + if name in ["policy_act", "stock_id", "datetime"]: + self.all_records[name].append(value) + self.all_records["step"].append(step_index) + + def on_env_all_done(self) -> None: + # FIXME: this is temporary + pd.DataFrame.from_dict(self.all_records).to_pickle(self.output_dir / "action.pkl") + + class TensorboardWriter(LogWriter): """Write logs to event files that can be visualized with tensorboard.""" From dbba68a311b4444f70b8fd7ceb9dc5b26ee6c327 Mon Sep 17 00:00:00 2001 From: Yuchen Fang Date: Sun, 12 Feb 2023 14:58:53 +0800 Subject: [PATCH 04/13] update pickle_styled --- qlib/rl/data/pickle_styled.py | 89 +++++++++++++++-------------------- 1 file changed, 39 insertions(+), 50 deletions(-) diff --git a/qlib/rl/data/pickle_styled.py b/qlib/rl/data/pickle_styled.py index c020a139b1..981d313837 100644 --- a/qlib/rl/data/pickle_styled.py +++ b/qlib/rl/data/pickle_styled.py @@ -19,7 +19,6 @@ from __future__ 
import annotations
 
-from abc import abstractmethod
 from functools import lru_cache
 from pathlib import Path
 from typing import List, Sequence, cast
@@ -30,6 +29,7 @@
 from cachetools.keys import hashkey
 
 from qlib.backtest.decision import Order, OrderDir
+from qlib.rl.data.base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
 from qlib.typehint import Literal
 
 DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"]
@@ -83,36 +83,16 @@ def _find_pickle(filename_without_suffix: Path) -> Path:
 
 @lru_cache(maxsize=10)  # 10 * 40M = 400MB
 def _read_pickle(filename_without_suffix: Path) -> pd.DataFrame:
-    return pd.read_pickle(_find_pickle(filename_without_suffix))
+    df = pd.read_pickle(_find_pickle(filename_without_suffix))
+    index_cols = df.index.names
+    df = df.reset_index()
+    for date_col_name in ["date", "datetime"]:
+        if date_col_name in df:
+            df[date_col_name] = pd.to_datetime(df[date_col_name])
+    df = df.set_index(index_cols)
 
-class BaseIntradayBacktestData:
-    """
-    Raw market data that is often used in backtesting (thus called BacktestData).
-
-    Base class for all types of backtest data. Currently, each type of simulator has its corresponding backtest
-    data type.
-    """
-
-    @abstractmethod
-    def __repr__(self) -> str:
-        raise NotImplementedError
-
-    @abstractmethod
-    def __len__(self) -> int:
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_deal_price(self) -> pd.Series:
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_volume(self) -> pd.Series:
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_time_index(self) -> pd.DatetimeIndex:
-        raise NotImplementedError
+    return df
 
 
 class SimpleIntradayBacktestData(BaseIntradayBacktestData):
 
     def __init__(
         self,
-        data_dir: Path,
+        data_dir: Path | str,
         stock_id: str,
         date: pd.Timestamp,
         deal_price: DealPriceType = "close",
     ) -> None:
         super(SimpleIntradayBacktestData, self).__init__()
 
-        backtest = _read_pickle(data_dir / stock_id)
+        backtest = _read_pickle((data_dir if isinstance(data_dir, Path) else Path(data_dir)) / stock_id)
         backtest = backtest.loc[pd.IndexSlice[stock_id, :, date]]
 
         # No longer need for pandas >= 1.4
@@ -178,30 +158,19 @@ def get_time_index(self) -> pd.DatetimeIndex:
         return cast(pd.DatetimeIndex, self.data.index)
 
 
-class IntradayProcessedData:
-    """Processed market data after data cleanup and feature engineering.
-
-    It contains both processed data for "today" and "yesterday", as some algorithms
-    might use the market information of the previous day to assist decision making.
-    """
-
-    today: pd.DataFrame
-    """Processed data for "today".
-    Number of records must be ``time_length``, and columns must be ``feature_dim``."""
-
-    yesterday: pd.DataFrame
-    """Processed data for "yesterday".
-    Number of records must be ``time_length``, and columns must be ``feature_dim``."""
+class IntradayProcessedData(BaseIntradayProcessedData):
+    """Subclass of BaseIntradayProcessedData. Used to handle Dataset Handler style data."""
 
     def __init__(
         self,
-        data_dir: Path,
+        data_dir: Path | str,
         stock_id: str,
         date: pd.Timestamp,
         feature_dim: int,
         time_index: pd.Index,
     ) -> None:
-        proc = _read_pickle(data_dir / stock_id)
+        proc = _read_pickle((data_dir if isinstance(data_dir, Path) else Path(data_dir)) / stock_id)
+
         # We have to infer the names here because,
         # unfortunately they are not included in the original data.
        cnames = _infer_processed_data_column_names(feature_dim)
@@ -246,15 +215,15 @@ def load_simple_intraday_backtest_data(
 
 @cachetools.cached(  # type: ignore
     cache=cachetools.LRUCache(100),  # 100 * 50K = 5MB
-    key=lambda data_dir, stock_id, date, _, __: hashkey(data_dir, stock_id, date),
+    key=lambda data_dir, stock_id, date, feature_dim, time_index: hashkey(data_dir, stock_id, date),
 )
-def load_intraday_processed_data(
+def load_pickled_intraday_processed_data(
     data_dir: Path,
     stock_id: str,
     date: pd.Timestamp,
     feature_dim: int,
     time_index: pd.Index,
-) -> IntradayProcessedData:
+) -> BaseIntradayProcessedData:
     return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
 
 
@@ -267,6 +236,26 @@ def __init__(self, teacher_action_file: Path, stock_id: str, date: pd.Timestamp)
 def load_teacher_action_data(teacher_action_file: Path, stock_id: str, date: pd.Timestamp) -> TeacherActionData:  # type: ignore
     return TeacherActionData(teacher_action_file, stock_id, date)
 
+class PickleProcessedDataProvider(ProcessedDataProvider):
+    def __init__(self, data_dir: Path) -> None:
+        super().__init__()
+
+        self._data_dir = data_dir
+
+    def get_data(
+        self,
+        stock_id: str,
+        date: pd.Timestamp,
+        feature_dim: int,
+        time_index: pd.Index,
+    ) -> BaseIntradayProcessedData:
+        return load_pickled_intraday_processed_data(
+            data_dir=self._data_dir,
+            stock_id=stock_id,
+            date=date,
+            feature_dim=feature_dim,
+            time_index=time_index,
+        )
 
 def load_orders(
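
A minimal usage sketch of the PickleProcessedDataProvider added in this patch (the data directory, instrument, and dimensions below are illustrative assumptions, not part of the series):

    from pathlib import Path

    import pandas as pd

    from qlib.rl.data.pickle_styled import PickleProcessedDataProvider

    # Hypothetical directory holding one <stock_id>.pkl file of processed features per instrument
    provider = PickleProcessedDataProvider(data_dir=Path("./data/pickle_features"))
    processed = provider.get_data(
        stock_id="SH600000",              # hypothetical instrument
        date=pd.Timestamp("2020-01-02"),
        feature_dim=6,                    # must match the number of feature columns in the pickle
        time_index=pd.Index(range(240)),  # intraday minute grid expected by the interpreters
    )
    # "today" and "yesterday" frames carry feature_dim columns each
    assert processed.today.shape[1] == processed.yesterday.shape[1]
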
From 216a6b1290e895e02f7425fdc06b1846b4f1af12 Mon Sep 17 00:00:00 2001
From: Yuchen Fang
Date: Sun, 12 Feb 2023 16:23:51 +0800
Subject: [PATCH 06/13] update data provider

---
 qlib/rl/data/pickle_styled.py          | 28 +++++---
 qlib/rl/order_execution/interpreter.py | 90 ++++++++++++++------------
 2 files changed, 67 insertions(+), 51 deletions(-)

diff --git a/qlib/rl/data/pickle_styled.py b/qlib/rl/data/pickle_styled.py
index 981d313837..a62503ac9d 100644
--- a/qlib/rl/data/pickle_styled.py
+++ b/qlib/rl/data/pickle_styled.py
@@ -227,15 +227,6 @@ def load_pickled_intraday_processed_data(
     return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
 
 
-class TeacherActionData:
-    teacher_action: pd.DataFrame
-
-    def __init__(self, teacher_action_file: Path, stock_id: str, date: pd.Timestamp) -> None:  # type: ignore
-        self.teacher_action = pd.read_pickle(teacher_action_file).loc[pd.IndexSlice[stock_id, date]]  # type: ignore
-
-
-def load_teacher_action_data(teacher_action_file: Path, stock_id: str, date: pd.Timestamp) -> TeacherActionData:  # type: ignore
-    return TeacherActionData(teacher_action_file, stock_id, date)
 class PickleProcessedDataProvider(ProcessedDataProvider):
     def __init__(self, data_dir: Path) -> None:
         super().__init__()
@@ -258,6 +249,25 @@ def get_data(
         )
 
 
+class TeacherActionData:
+    teacher_action: pd.DataFrame
+
+    def __init__(self, teacher_action_file: Path, stock_id: str, date: pd.Timestamp) -> None:  # type: ignore
+        self.teacher_action = pd.read_pickle(teacher_action_file).loc[pd.IndexSlice[stock_id, date]]  # type: ignore
+
+
+def load_teacher_action_data(teacher_action_file: Path, stock_id: str, date: pd.Timestamp) -> TeacherActionData:  # type: ignore
+    return TeacherActionData(teacher_action_file, stock_id, date)
+
+
+class TeacherActionDataProvider:
+    def __init__(self, teacher_action_file: Path) -> None:
+        self._teacher_action_file = teacher_action_file
+
+    def get_data(self, stock_id: str, date: pd.Timestamp) -> TeacherActionData:
+        return load_teacher_action_data(self._teacher_action_file, stock_id, date)
+
+
 def
load_orders( order_path: Path, start_time: pd.Timestamp = None, diff --git a/qlib/rl/order_execution/interpreter.py b/qlib/rl/order_execution/interpreter.py index e816987122..c00bde4caf 100644 --- a/qlib/rl/order_execution/interpreter.py +++ b/qlib/rl/order_execution/interpreter.py @@ -12,6 +12,7 @@ from qlib.constant import EPS from qlib.rl.data.base import ProcessedDataProvider +from qlib.rl.data.pickle_styled import TeacherActionDataProvider from qlib.rl.interpreter import ActionInterpreter, StateInterpreter from qlib.rl.order_execution.state import SAOEState from qlib.typehint import TypedDict @@ -147,33 +148,30 @@ def _mask_future_info(arr: pd.DataFrame, current: pd.Timestamp) -> pd.DataFrame: class OracleObsInterpreter(FullHistoryStateInterpreter): def interpret(self, state: SAOEState) -> FullHistoryObs: - processed = pickle_styled.load_intraday_processed_data( - self.data_dir, - state.order.stock_id, - pd.Timestamp(state.order.start_time.date()), - self.data_dim, - state.ticks_index, + processed = self.processed_data_provider.get_data( + stock_id=state.order.stock_id, + date=pd.Timestamp(state.order.start_time.date()), + feature_dim=self.data_dim, + time_index=state.ticks_index, ) position_history = np.full(self.max_step + 1, 0.0, dtype=np.float32) position_history[0] = state.order.amount position_history[1 : len(state.history_steps) + 1] = state.history_steps["position"].to_numpy() - assert self.env is not None - return cast( FullHistoryObs, canonicalize( { - "data_processed": processed.today, - "data_processed_prev": processed.yesterday, - "acquiring": state.order.direction == state.order.BUY, - "cur_tick": min(int(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1), - "cur_step": min(self.env.status["cur_step"], self.max_step - 1), - "num_step": self.max_step, - "target": state.order.amount, - "position": state.position, - "position_history": position_history[: self.max_step], + "data_processed": np.array(processed.today), + "data_processed_prev": np.array(processed.yesterday), + "acquiring": _to_int32(state.order.direction == state.order.BUY), + "cur_tick": _to_int32(min(int(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1)), + "cur_step": _to_int32(min(state.cur_step, self.max_step - 1)), + "num_step": _to_int32(self.max_step), + "target": _to_float32(state.order.amount), + "position": _to_float32(state.position), + "position_history": _to_float32(position_history[: self.max_step]), }, ), ) @@ -181,29 +179,37 @@ def interpret(self, state: SAOEState) -> FullHistoryObs: class OPDObsInterpreter(FullHistoryStateInterpreter): def __init__( - self, data_dir: Path, max_step: int, data_ticks: int, data_dim: int, teacher_action_file: Path + self, + max_step: int, + data_ticks: int, + data_dim: int, + processed_data_provider: dict | ProcessedDataProvider, + teacher_action_data_provider: dict | TeacherActionDataProvider, ) -> None: - super().__init__(data_dir, max_step, data_ticks, data_dim) - self.teacher_action_file = teacher_action_file + super().__init__(max_step, data_ticks, data_dim, processed_data_provider) + # self.teacher_action_file = teacher_action_file + self.teacher_action_data_provider = init_instance_by_config( + teacher_action_data_provider, accept_types=TeacherActionDataProvider + ) def interpret(self, state: SAOEState) -> OPDObs: - processed = pickle_styled.load_intraday_processed_data( - self.data_dir, - state.order.stock_id, - pd.Timestamp(state.order.start_time.date()), - self.data_dim, - state.ticks_index, + processed = 
self.processed_data_provider.get_data( + stock_id=state.order.stock_id, + date=pd.Timestamp(state.order.start_time.date()), + feature_dim=self.data_dim, + time_index=state.ticks_index, ) position_history = np.full(self.max_step + 1, 0.0, dtype=np.float32) position_history[0] = state.order.amount position_history[1 : len(state.history_steps) + 1] = state.history_steps["position"].to_numpy() + teacher_action = self.teacher_action_data_provider.get_data( + stock_id=state.order.stock_id, date=pd.Timestamp(state.order.start_time.date()) + ) - teacher_action = pickle_styled.load_teacher_action_data( - self.teacher_action_file, state.order.stock_id, pd.Timestamp(state.order.start_time.date()) - ).teacher_action - - assert self.env is not None + # teacher_action = pickle_styled.load_teacher_action_data( + # self.teacher_action_file, state.order.stock_id, pd.Timestamp(state.order.start_time.date()) + # ).teacher_action # The min, slice here are to make sure that indices fit into the range, # even after the final step of the simulator (in the done step), @@ -212,16 +218,16 @@ def interpret(self, state: SAOEState) -> OPDObs: OPDObs, canonicalize( { - "data_processed": self._mask_future_info(processed.today, state.cur_time), - "data_processed_prev": processed.yesterday, - "acquiring": state.order.direction == state.order.BUY, - "cur_tick": min(int(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1), - "cur_step": min(self.env.status["cur_step"], self.max_step - 1), - "num_step": self.max_step, - "target": state.order.amount, - "position": state.position, - "position_history": position_history[: self.max_step], - "teacher_action": teacher_action.values[self.env.status["cur_step"]], + "data_processed": np.array(self._mask_future_info(processed.today, state.cur_time)), + "data_processed_prev": np.array(processed.yesterday), + "acquiring": _to_int32(state.order.direction == state.order.BUY), + "cur_tick": _to_int32(min(int(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1)), + "cur_step": _to_int32(min(state.cur_step, self.max_step - 1)), + "num_step": _to_int32(self.max_step), + "target": _to_float32(state.order.amount), + "position": _to_float32(state.position), + "position_history": _to_float32(position_history[: self.max_step]), + "teacher_action": _to_int32(teacher_action.values[state.cur_step]), }, ), ) @@ -239,7 +245,7 @@ def observation_space(self) -> spaces.Dict: "target": spaces.Box(-EPS, np.inf, shape=()), "position": spaces.Box(-EPS, np.inf, shape=()), "position_history": spaces.Box(-EPS, np.inf, shape=(self.max_step,)), - "teacher_action": spaces.Box(-EPS, np.inf, shape=()), + "teacher_action": spaces.Box(-EPS, np.inf, shape=(), dtype=np.int32), } return spaces.Dict(space) From 52ee9993453bc5abfb37f8a45dbf9a89fbf6c198 Mon Sep 17 00:00:00 2001 From: Yuchen Fang Date: Mon, 13 Feb 2023 13:09:54 +0800 Subject: [PATCH 07/13] logger --- qlib/rl/utils/log.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/qlib/rl/utils/log.py b/qlib/rl/utils/log.py index de622d61b1..5dabb5126b 100644 --- a/qlib/rl/utils/log.py +++ b/qlib/rl/utils/log.py @@ -548,7 +548,7 @@ class ActionWriter(LogWriter): all_records: dict[str, list[Any]] - def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None: + def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.DEBUG) -> None: super().__init__(loglevel) self.output_dir = output_dir self.output_dir.mkdir(exist_ok=True) @@ -558,16 +558,19 @@ def clear(self) -> None: 
self.all_records = defaultdict(list) def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None: - for step_index, step_contents in enumerate(contents): for name, value in step_contents.items(): - if name in ["policy_act", "stock_id", "datetime"]: + if name == "policy_act": self.all_records[name].append(value) + if name in ["stock_id", "datetime"]: + self.all_records[name].extend([value] * len(contents)) self.all_records["step"].append(step_index) def on_env_all_done(self) -> None: # FIXME: this is temporary - pd.DataFrame.from_dict(self.all_records).to_pickle(self.output_dir / "action.pkl") + pd.DataFrame.from_dict(self.all_records).set_index(["stock_id", "datetime"]).sort_index().to_pickle( + self.output_dir / "action.pkl" + ) class TensorboardWriter(LogWriter): From 62228f50520c2daba30ab7b76db12627853a9d8b Mon Sep 17 00:00:00 2001 From: Yuchen Fang Date: Mon, 13 Feb 2023 13:30:41 +0800 Subject: [PATCH 08/13] minor --- qlib/rl/data/pickle_styled.py | 20 -------------------- qlib/rl/utils/log.py | 4 +++- 2 files changed, 3 insertions(+), 21 deletions(-) diff --git a/qlib/rl/data/pickle_styled.py b/qlib/rl/data/pickle_styled.py index 7e2d0739bf..a62503ac9d 100644 --- a/qlib/rl/data/pickle_styled.py +++ b/qlib/rl/data/pickle_styled.py @@ -258,26 +258,6 @@ def __init__(self, teacher_action_file: Path, stock_id: str, date: pd.Timestamp) def load_teacher_action_data(teacher_action_file: Path, stock_id: str, date: pd.Timestamp) -> TeacherActionData: # type: ignore return TeacherActionData(teacher_action_file, stock_id, date) -class PickleProcessedDataProvider(ProcessedDataProvider): - def __init__(self, data_dir: Path) -> None: - super().__init__() - - self._data_dir = data_dir - - def get_data( - self, - stock_id: str, - date: pd.Timestamp, - feature_dim: int, - time_index: pd.Index, - ) -> BaseIntradayProcessedData: - return load_pickled_intraday_processed_data( - data_dir=self._data_dir, - stock_id=stock_id, - date=date, - feature_dim=feature_dim, - time_index=time_index, - ) class TeacherActionDataProvider: diff --git a/qlib/rl/utils/log.py b/qlib/rl/utils/log.py index 5dabb5126b..0553f89ffa 100644 --- a/qlib/rl/utils/log.py +++ b/qlib/rl/utils/log.py @@ -562,7 +562,9 @@ def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str for name, value in step_contents.items(): if name == "policy_act": self.all_records[name].append(value) - if name in ["stock_id", "datetime"]: + if name == "datetime": + self.all_records["date"].extend([value.date()] * len(contents)) + if name == "stock_id": self.all_records[name].extend([value] * len(contents)) self.all_records["step"].append(step_index) From b44e03fb4a651608212da0d50bfd072376b18734 Mon Sep 17 00:00:00 2001 From: Yuchen Fang Date: Mon, 13 Feb 2023 13:33:22 +0800 Subject: [PATCH 09/13] debug --- qlib/rl/utils/log.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qlib/rl/utils/log.py b/qlib/rl/utils/log.py index 0553f89ffa..23f89be7be 100644 --- a/qlib/rl/utils/log.py +++ b/qlib/rl/utils/log.py @@ -570,7 +570,7 @@ def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str def on_env_all_done(self) -> None: # FIXME: this is temporary - pd.DataFrame.from_dict(self.all_records).set_index(["stock_id", "datetime"]).sort_index().to_pickle( + pd.DataFrame.from_dict(self.all_records).set_index(["stock_id", "date"]).sort_index().to_pickle( self.output_dir / "action.pkl" ) From 6aeb2674a0a6d1f1103d1e5ebe206a45acb0ae68 Mon Sep 17 00:00:00 2001 From: Yuchen Fang 
Date: Mon, 13 Feb 2023 15:12:08 +0800 Subject: [PATCH 10/13] OPDT network --- qlib/rl/order_execution/network.py | 31 ++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/qlib/rl/order_execution/network.py b/qlib/rl/order_execution/network.py index d6a11189cf..f0116e52cc 100644 --- a/qlib/rl/order_execution/network.py +++ b/qlib/rl/order_execution/network.py @@ -119,6 +119,37 @@ def forward(self, batch: Batch) -> torch.Tensor: return self.fc(out) +class OracleRecurrent(Recurrent): + def _source_features(self, obs: FullHistoryObs, device: torch.device) -> Tuple[List[torch.Tensor], torch.Tensor]: + bs, _, data_dim = obs["data_processed"].size() + data = torch.cat((torch.zeros(bs, 1, data_dim, device=device), obs["data_processed"]), 1) + cur_step = obs["cur_step"].long() + bs_indices = torch.arange(bs, device=device) + + position = obs["position_history"] / obs["target"].unsqueeze(-1) # [bs, num_step] + steps = ( + torch.arange(position.size(-1), device=device).unsqueeze(0).repeat(bs, 1).float() + / obs["num_step"].unsqueeze(-1).float() + ) # [bs, num_step] + priv = torch.stack((position.float(), steps), -1) + + data_in = self.raw_fc(data) + data_out, _ = self.raw_rnn(data_in) + # get last minute output + data_out_slice = data_out[bs_indices, -1] + + priv_in = self.pri_fc(priv) + priv_out = self.pri_rnn(priv_in)[0] + priv_out = priv_out[bs_indices, cur_step] + + sources = [data_out_slice, priv_out] + + dir_out = self.dire_fc(torch.stack((obs["acquiring"], 1 - obs["acquiring"]), -1).float()) + sources.append(dir_out) + + return sources, data_out + + class Attention(nn.Module): def __init__(self, in_dim, out_dim): super().__init__() From 36dd47d4dfe8f11be0763ddeedf96d67db27e4c9 Mon Sep 17 00:00:00 2001 From: Yuchen Fang Date: Mon, 13 Feb 2023 15:12:15 +0800 Subject: [PATCH 11/13] OPD policy --- qlib/rl/order_execution/policy.py | 100 +++++++++++++++++++++++++++++- 1 file changed, 99 insertions(+), 1 deletion(-) diff --git a/qlib/rl/order_execution/policy.py b/qlib/rl/order_execution/policy.py index 598e6b589a..4255abe499 100644 --- a/qlib/rl/order_execution/policy.py +++ b/qlib/rl/order_execution/policy.py @@ -4,7 +4,7 @@ from __future__ import annotations from pathlib import Path -from typing import Any, Dict, Generator, Iterable, Optional, OrderedDict, Tuple, cast +from typing import Any, List, Dict, Generator, Iterable, Optional, OrderedDict, Tuple, cast import gym import numpy as np @@ -153,6 +153,104 @@ def __init__( set_weight(self, Trainer.get_policy_state_dict(weight_file)) +class OPD(PPO): + def __init__( + self, + network: nn.Module, + obs_space: gym.Space, + action_space: gym.Space, + lr: float, + weight_decay: float = 0.0, + discount_factor: float = 1.0, + max_grad_norm: float = 100.0, + reward_normalization: bool = True, + eps_clip: float = 0.3, + value_clip: bool = True, + vf_coef: float = 1.0, + gae_lambda: float = 1.0, + max_batch_size: int = 256, + deterministic_eval: bool = True, + dis_coef: float = 0.01, + weight_file: Optional[Path] = None, + ) -> None: + self._weight_dis = dis_coef + super().__init__( + network, + obs_space, + action_space, + lr, + weight_decay, + discount_factor, + max_grad_norm, + reward_normalization, + eps_clip, + value_clip, + vf_coef, + gae_lambda, + max_batch_size, + deterministic_eval, + weight_file, + ) + + def learn( # type: ignore + self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any + ) -> Dict[str, List[float]]: + losses, clip_losses, vf_losses, dis_losses, ent_losses = [], [], [], [], [] + for step in 
range(repeat):
+            if self._recompute_adv and step > 0:
+                batch = self._compute_returns(batch, self._buffer, self._indices)
+            for minibatch in batch.split(batch_size, merge_last=True):
+                # calculate loss for actor
+                out = self(minibatch)
+                dist = out.dist
+                if self._norm_adv:
+                    mean, std = minibatch.adv.mean(), minibatch.adv.std()
+                    minibatch.adv = (minibatch.adv - mean) / std  # per-batch norm
+                ratio = (dist.log_prob(minibatch.act) - minibatch.logp_old).exp().float()
+                ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
+                surr1 = ratio * minibatch.adv
+                surr2 = ratio.clamp(1.0 - self._eps_clip, 1.0 + self._eps_clip) * minibatch.adv
+                if self._dual_clip:
+                    clip1 = torch.min(surr1, surr2)
+                    clip2 = torch.max(clip1, self._dual_clip * minibatch.adv)
+                    clip_loss = -torch.where(minibatch.adv < 0, clip2, clip1).mean()
+                else:
+                    clip_loss = -torch.min(surr1, surr2).mean()
+                # calculate loss for critic
+                value = self.critic(minibatch.obs).flatten()
+                if self._value_clip:
+                    v_clip = minibatch.v_s + (value - minibatch.v_s).clamp(-self._eps_clip, self._eps_clip)
+                    vf1 = (minibatch.returns - value).pow(2)
+                    vf2 = (minibatch.returns - v_clip).pow(2)
+                    vf_loss = torch.max(vf1, vf2).mean()
+                else:
+                    vf_loss = (minibatch.returns - value).pow(2).mean()
+                # calculate distillation loss
+                teacher_action = minibatch.obs["teacher_action"]
+                logits = out.logits
+                dis_loss = nn.functional.nll_loss(logits.log(), teacher_action)
+                # calculate regularization and overall loss
+                ent_loss = dist.entropy().mean()
+                loss = clip_loss + self._weight_vf * vf_loss - self._weight_ent * ent_loss + self._weight_dis * dis_loss
+                self.optim.zero_grad()
+                loss.backward()
+                if self._grad_norm:  # clip large gradient
+                    nn.utils.clip_grad_norm_(self._actor_critic.parameters(), max_norm=self._grad_norm)
+                self.optim.step()
+                clip_losses.append(clip_loss.item())
+                vf_losses.append(vf_loss.item())
+                dis_losses.append(dis_loss.item())
+                ent_losses.append(ent_loss.item())
+                losses.append(loss.item())
+
+        return {
+            "loss": losses,
+            "loss/clip": clip_losses,
+            "loss/vf": vf_losses,
+            "loss/dis": dis_losses,
+            "loss/ent": ent_losses,
+        }
+
+
 # utilities: these should be put in a separate (common) file.
# From aa42812111e95059f6dca2c3156ee1e748c5d89b Mon Sep 17 00:00:00 2001 From: Yuchen Fang Date: Mon, 13 Feb 2023 15:12:21 +0800 Subject: [PATCH 12/13] data --- qlib/rl/data/pickle_styled.py | 5 ++++- qlib/rl/order_execution/interpreter.py | 6 +----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/qlib/rl/data/pickle_styled.py b/qlib/rl/data/pickle_styled.py index a62503ac9d..f63fde6e46 100644 --- a/qlib/rl/data/pickle_styled.py +++ b/qlib/rl/data/pickle_styled.py @@ -251,9 +251,12 @@ def get_data( class TeacherActionData: teacher_action: pd.DataFrame + step: int def __init__(self, teacher_action_file: Path, stock_id: str, date: pd.Timestamp) -> None: # type: ignore - self.teacher_action = pd.read_pickle(teacher_action_file).loc[pd.IndexSlice[stock_id, date]] # type: ignore + data = pd.read_pickle(teacher_action_file).loc[pd.IndexSlice[stock_id, date]] # type: ignore + self.teacher_action = data["policy_act"] + self.step = data["step"] def load_teacher_action_data(teacher_action_file: Path, stock_id: str, date: pd.Timestamp) -> TeacherActionData: # type: ignore diff --git a/qlib/rl/order_execution/interpreter.py b/qlib/rl/order_execution/interpreter.py index c00bde4caf..60160f4c91 100644 --- a/qlib/rl/order_execution/interpreter.py +++ b/qlib/rl/order_execution/interpreter.py @@ -205,11 +205,7 @@ def interpret(self, state: SAOEState) -> OPDObs: position_history[1 : len(state.history_steps) + 1] = state.history_steps["position"].to_numpy() teacher_action = self.teacher_action_data_provider.get_data( stock_id=state.order.stock_id, date=pd.Timestamp(state.order.start_time.date()) - ) - - # teacher_action = pickle_styled.load_teacher_action_data( - # self.teacher_action_file, state.order.stock_id, pd.Timestamp(state.order.start_time.date()) - # ).teacher_action + ).teacher_action # The min, slice here are to make sure that indices fit into the range, # even after the final step of the simulator (in the done step), From 28f0ff308a9ea7f5ebb9b3d8998e57e126117c70 Mon Sep 17 00:00:00 2001 From: Yuchen Fang Date: Tue, 14 Mar 2023 02:29:47 +0000 Subject: [PATCH 13/13] OPD entry --- qlib/rl/contrib/train_opdt.py | 252 +++++++++++++++++++++++++ qlib/rl/data/pickle_styled.py | 2 +- qlib/rl/order_execution/interpreter.py | 6 +- qlib/rl/order_execution/policy.py | 4 +- 4 files changed, 260 insertions(+), 4 deletions(-) create mode 100644 qlib/rl/contrib/train_opdt.py diff --git a/qlib/rl/contrib/train_opdt.py b/qlib/rl/contrib/train_opdt.py new file mode 100644 index 0000000000..5af351c66a --- /dev/null +++ b/qlib/rl/contrib/train_opdt.py @@ -0,0 +1,252 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+import argparse +import os +import random +from pathlib import Path +from typing import cast, List, Optional + +import numpy as np +import pandas as pd +import qlib +import torch +import yaml +from qlib.backtest import Order +from qlib.backtest.decision import OrderDir +from qlib.constant import ONE_MIN +from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data +from qlib.rl.interpreter import ActionInterpreter, StateInterpreter +from qlib.rl.order_execution import SingleAssetOrderExecutionSimple +from qlib.rl.reward import Reward +from qlib.rl.trainer import Checkpoint, backtest, train +from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter +from qlib.rl.utils.log import CsvWriter, ActionWriter +from qlib.utils import init_instance_by_config +from tianshou.policy import BasePolicy +from torch import nn +from torch.utils.data import Dataset + + +def seed_everything(seed: int) -> None: + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + + +def _read_orders(order_dir: Path) -> pd.DataFrame: + if os.path.isfile(order_dir): + return pd.read_pickle(order_dir) + else: + orders = [] + for file in order_dir.iterdir(): + order_data = pd.read_pickle(file) + orders.append(order_data) + return pd.concat(orders) + + +class LazyLoadDataset(Dataset): + def __init__( + self, + order_file_path: Path, + data_dir: Path, + default_start_time_index: int, + default_end_time_index: int, + ) -> None: + self._default_start_time_index = default_start_time_index + self._default_end_time_index = default_end_time_index + + self._order_file_path = order_file_path + self._order_df = _read_orders(order_file_path).reset_index() + + self._data_dir = data_dir + self._ticks_index: Optional[pd.DatetimeIndex] = None + + def __len__(self) -> int: + return len(self._order_df) + + def __getitem__(self, index: int) -> Order: + row = self._order_df.iloc[index] + date = pd.Timestamp(str(row["date"])) + + if self._ticks_index is None: + # TODO: We only load ticks index once based on the assumption that ticks index of different dates + # TODO: in one experiment are all the same. If that assumption is not hold, we need to load ticks index + # TODO: of all dates. 
+ backtest_data = load_simple_intraday_backtest_data( + data_dir=self._data_dir, + stock_id=row["instrument"], + date=date, + ) + self._ticks_index = [t - date for t in backtest_data.get_time_index()] + + order = Order( + stock_id=row["instrument"], + amount=row["amount"], + direction=OrderDir(int(row["order_type"])), + start_time=date + self._ticks_index[self._default_start_time_index], + end_time=date + self._ticks_index[self._default_end_time_index - 1] + ONE_MIN, + ) + + return order + + +def train_and_test( + env_config: dict, + simulator_config: dict, + trainer_config: dict, + data_config: dict, + state_interpreter: StateInterpreter, + action_interpreter: ActionInterpreter, + policy: BasePolicy, + reward: Reward, + run_backtest: bool, +) -> None: + qlib.init() + + order_root_path = Path(data_config["source"]["order_dir"]) + + data_granularity = simulator_config.get("data_granularity", 1) + + def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple: + return SingleAssetOrderExecutionSimple( + order=order, + data_dir=Path(data_config["source"]["data_dir"]), + ticks_per_step=simulator_config["time_per_step"], + data_granularity=data_granularity, + deal_price_type=data_config["source"].get("deal_price_column", "close"), + vol_threshold=simulator_config["vol_limit"], + ) + + assert data_config["source"]["default_start_time_index"] % data_granularity == 0 + assert data_config["source"]["default_end_time_index"] % data_granularity == 0 + + train_dataset, valid_dataset, test_dataset = [ + LazyLoadDataset( + order_file_path=order_root_path / tag, + data_dir=Path(data_config["source"]["data_dir"]), + default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity, + default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity, + ) + for tag in ("train", "valid", "all") + ] + + if "checkpoint_path" in trainer_config: + callbacks: List[Callback] = [] + callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"]))) + callbacks.append( + Checkpoint( + dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints", + every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1), + save_latest="copy", + ), + ) + if "earlystop_patience" in trainer_config: + callbacks.append( + EarlyStopping( + patience=trainer_config["earlystop_patience"], + monitor="val/pa", + ) + ) + + trainer_kwargs = { + "max_iters": trainer_config["max_epoch"], + "finite_env_type": env_config["parallel_mode"], + "concurrency": env_config["concurrency"], + "val_every_n_iters": trainer_config.get("val_every_n_epoch", None), + "callbacks": callbacks, + } + vessel_kwargs = { + "episode_per_iter": trainer_config["episode_per_collect"], + "update_kwargs": { + "batch_size": trainer_config["batch_size"], + "repeat": trainer_config["repeat_per_collect"], + }, + "val_initial_states": valid_dataset, + } + + train( + simulator_fn=_simulator_factory_simple, + state_interpreter=state_interpreter, + action_interpreter=action_interpreter, + policy=policy, + reward=reward, + initial_states=cast(List[Order], train_dataset), + trainer_kwargs=trainer_kwargs, + vessel_kwargs=vessel_kwargs, + ) + + if run_backtest: + backtest( + simulator_fn=_simulator_factory_simple, + state_interpreter=state_interpreter, + action_interpreter=action_interpreter, + initial_states=test_dataset, + policy=policy, + logger=[CsvWriter(Path(trainer_config["checkpoint_path"])), ActionWriter(Path(trainer_config["checkpoint_path"]))], + # logger = 
CsvWriter(Path(trainer_config["checkpoint_path"]), loglevel=10), + reward=reward, + finite_env_type=trainer_kwargs["finite_env_type"], + concurrency=trainer_kwargs["concurrency"], + ) + + +def main(config: dict, run_backtest: bool) -> None: + if "seed" in config["runtime"]: + seed_everything(config["runtime"]["seed"]) + + state_config = config["state_interpreter"] + state_interpreter: StateInterpreter = init_instance_by_config(state_config) + + action_interpreter: ActionInterpreter = init_instance_by_config(config["action_interpreter"]) + reward: Reward = init_instance_by_config(config["reward"]) + + # Create torch network + if "kwargs" not in config["network"]: + config["network"]["kwargs"] = {} + config["network"]["kwargs"].update({"obs_space": state_interpreter.observation_space}) + network: nn.Module = init_instance_by_config(config["network"]) + + # Create policy + config["policy"]["kwargs"].update( + { + "network": network, + "obs_space": state_interpreter.observation_space, + "action_space": action_interpreter.action_space, + } + ) + policy: BasePolicy = init_instance_by_config(config["policy"]) + + use_cuda = config["runtime"].get("use_cuda", False) + if use_cuda: + policy.cuda() + + train_and_test( + env_config=config["env"], + simulator_config=config["simulator"], + data_config=config["data"], + trainer_config=config["trainer"], + action_interpreter=action_interpreter, + state_interpreter=state_interpreter, + policy=policy, + reward=reward, + run_backtest=run_backtest, + ) + + +if __name__ == "__main__": + import warnings + + warnings.filterwarnings("ignore", category=DeprecationWarning) + warnings.filterwarnings("ignore", category=RuntimeWarning) + + parser = argparse.ArgumentParser() + parser.add_argument("--config_path", type=str, required=True, help="Path to the config file") + parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow after training is finished") + args = parser.parse_args() + + with open(args.config_path, "r") as input_stream: + config = yaml.safe_load(input_stream) + + main(config, run_backtest=args.run_backtest) diff --git a/qlib/rl/data/pickle_styled.py b/qlib/rl/data/pickle_styled.py index f63fde6e46..1ff449111e 100644 --- a/qlib/rl/data/pickle_styled.py +++ b/qlib/rl/data/pickle_styled.py @@ -254,7 +254,7 @@ class TeacherActionData: step: int def __init__(self, teacher_action_file: Path, stock_id: str, date: pd.Timestamp) -> None: # type: ignore - data = pd.read_pickle(teacher_action_file).loc[pd.IndexSlice[stock_id, date]] # type: ignore + data = pd.read_pickle(teacher_action_file).loc[pd.IndexSlice[stock_id, date.date()]] # type: ignore self.teacher_action = data["policy_act"] self.step = data["step"] diff --git a/qlib/rl/order_execution/interpreter.py b/qlib/rl/order_execution/interpreter.py index 60160f4c91..02f5d14e56 100644 --- a/qlib/rl/order_execution/interpreter.py +++ b/qlib/rl/order_execution/interpreter.py @@ -206,6 +206,10 @@ def interpret(self, state: SAOEState) -> OPDObs: teacher_action = self.teacher_action_data_provider.get_data( stock_id=state.order.stock_id, date=pd.Timestamp(state.order.start_time.date()) ).teacher_action + try: + this_teacher_action = teacher_action.values[state.cur_step] + except IndexError: + this_teacher_action = 0 # The min, slice here are to make sure that indices fit into the range, # even after the final step of the simulator (in the done step), @@ -223,7 +227,7 @@ def interpret(self, state: SAOEState) -> OPDObs: "target": _to_float32(state.order.amount), "position": 
_to_float32(state.position), "position_history": _to_float32(position_history[: self.max_step]), - "teacher_action": _to_int32(teacher_action.values[state.cur_step]), + "teacher_action": _to_int32(this_teacher_action), }, ), ) diff --git a/qlib/rl/order_execution/policy.py b/qlib/rl/order_execution/policy.py index 4255abe499..5991e04fc0 100644 --- a/qlib/rl/order_execution/policy.py +++ b/qlib/rl/order_execution/policy.py @@ -4,7 +4,7 @@ from __future__ import annotations from pathlib import Path -from typing import Any, List, Dict, Generator, Iterable, Optional, OrderedDict, Tuple, cast +from typing import Any, Dict, Generator, Iterable, List, Optional, OrderedDict, Tuple, cast import gym import numpy as np @@ -226,7 +226,7 @@ def learn( # type: ignore else: vf_loss = (minibatch.returns - value).pow(2).mean() # calculate distillation loss - teacher_action = minibatch.obs["teacher_action"] + teacher_action = torch.tensor(minibatch.obs["teacher_action"]).long() logits = out.logits dis_loss = nn.functional.nll_loss(logits.log(), teacher_action) # calculate regularization and overall loss
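
A sketch of how the pieces from this series might be wired together through the config that train_opdt.py consumes. The class paths follow the patches above; every file path and hyper-parameter below is an illustrative assumption, not a tested configuration:

    import yaml

    from qlib.rl.contrib.train_opdt import main

    # Equivalent of a YAML file passed via --config_path; all paths and numbers are placeholders
    config = yaml.safe_load(
        """
        simulator: {time_per_step: 30, vol_limit: null}
        env: {concurrency: 4, parallel_mode: shared}
        data:
          source:
            order_dir: ./orders                  # hypothetical
            data_dir: ./data/backtest            # hypothetical
            default_start_time_index: 0
            default_end_time_index: 240
            deal_price_column: close
        state_interpreter:
          class: OPDObsInterpreter
          module_path: qlib.rl.order_execution.interpreter
          kwargs:
            max_step: 8
            data_ticks: 240
            data_dim: 6
            processed_data_provider:
              class: PickleProcessedDataProvider
              module_path: qlib.rl.data.pickle_styled
              kwargs: {data_dir: ./data/features}                    # hypothetical
            teacher_action_data_provider:
              class: TeacherActionDataProvider
              module_path: qlib.rl.data.pickle_styled
              kwargs: {teacher_action_file: ./teacher/action.pkl}    # e.g. dumped by ActionWriter
        action_interpreter:
          class: CategoricalActionInterpreter
          module_path: qlib.rl.order_execution.interpreter
          kwargs: {values: 14, max_step: 8}
        network: {class: Recurrent, module_path: qlib.rl.order_execution.network}
        policy:
          class: OPD
          module_path: qlib.rl.order_execution.policy
          kwargs: {lr: 0.0001, dis_coef: 0.01}
        reward:
          class: PAPenaltyReward
          module_path: qlib.rl.order_execution.reward
        trainer:
          max_epoch: 100
          episode_per_collect: 20
          batch_size: 64
          repeat_per_collect: 5
          checkpoint_path: ./outputs/opd         # hypothetical
        runtime: {seed: 42, use_cuda: false}
        """
    )
    main(config, run_backtest=False)

Note that main() injects the interpreters' observation/action spaces and the network into the policy kwargs at runtime, so only the policy-specific hyper-parameters (such as dis_coef for the distillation weight) need to appear in the config itself.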