136 lines
4.6 KiB
Python
136 lines
4.6 KiB
Python
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
|
|
|
from qlib.contrib.data.handler import _DEFAULT_LEARN_PROCESSORS, check_transform_proc
|
|
from qlib.data.dataset.handler import DataHandlerLP
|
|
|
|
|
|
# Expression config for one field group: either a flat sequence of
# expression strings, or an ``(expressions, names)`` pair of sequences.
_ExprSeq = Sequence[str]
FieldsConfig = Union[_ExprSeq, Tuple[_ExprSeq, _ExprSeq]]
|
|
|
|
|
|
class ExpressionDataHandlerLP(DataHandlerLP):
    """DataHandlerLP that accepts feature/label expressions via init args.

    Each field group ("feature" and "label") is configured in exactly one of
    these ways:

    * ``<group>_map``: a non-empty ``{name: expression}`` dict;
    * ``<group>_exprs`` (optionally with ``<group>_names``): a list of
      expression strings, a single expression string, or an
      ``(expressions, names)`` pair;
    * nothing, in which case a built-in default is used
      (``$close`` named ``CLOSE`` for features,
      ``Ref($close, -2) / Ref($close, -1) - 1`` named ``LABEL0`` for labels).

    The resulting ``(expressions, names)`` pairs become the ``config`` of a
    ``QlibDataLoader``.
    """

    def __init__(
        self,
        instruments="all",
        start_time=None,
        end_time=None,
        freq="day",
        feature_map: Optional[Dict[str, str]] = None,
        feature_exprs: Optional[FieldsConfig] = None,
        feature_names: Optional[Sequence[str]] = None,
        label_map: Optional[Dict[str, str]] = None,
        label_exprs: Optional[FieldsConfig] = None,
        label_names: Optional[Sequence[str]] = None,
        infer_processors: Optional[List] = None,
        learn_processors: Optional[List] = None,
        fit_start_time=None,
        fit_end_time=None,
        process_type=DataHandlerLP.PTYPE_A,
        filter_pipe=None,
        inst_processors=None,
        **kwargs,
    ):
        """Build the QlibDataLoader config and initialise ``DataHandlerLP``.

        Args:
            instruments, start_time, end_time, process_type: forwarded to
                ``DataHandlerLP.__init__``.
            freq, filter_pipe, inst_processors: forwarded to the
                ``QlibDataLoader`` kwargs.
            feature_map, feature_exprs, feature_names: feature definition;
                see ``_build_fields_config`` for the accepted combinations.
            label_map, label_exprs, label_names: label definition; same rules.
            infer_processors: inference processors; defaults to none.
            learn_processors: learning processors; defaults to a copy of
                ``_DEFAULT_LEARN_PROCESSORS``.
            fit_start_time, fit_end_time: passed to ``check_transform_proc``
                together with both processor lists.
            **kwargs: forwarded to ``DataHandlerLP.__init__``. Legacy
                ``feature`` / ``label`` keys are consumed here. Each is used
                as ``<group>_exprs`` when neither ``<group>_map`` nor
                ``<group>_exprs`` was given.

        Raises:
            ValueError: if a field group is configured inconsistently or
                resolves to no expressions.
        """
        import copy

        # Legacy configs may pass ``feature``/``label`` directly. Consume
        # them so they are not forwarded to the parent, and treat both groups
        # the same way: use the legacy value only when nothing explicit is set.
        legacy_feature = kwargs.pop("feature", None)
        legacy_label = kwargs.pop("label", None)
        if feature_map is None and feature_exprs is None and legacy_feature is not None:
            feature_exprs = legacy_feature
        if label_map is None and label_exprs is None and legacy_label is not None:
            label_exprs = legacy_label

        if infer_processors is None:
            infer_processors = []
        if learn_processors is None:
            # Defensive copy: the module-level default is shared by every
            # instance, so never hand it out by reference.
            learn_processors = copy.deepcopy(_DEFAULT_LEARN_PROCESSORS)

        infer_processors = check_transform_proc(
            infer_processors, fit_start_time, fit_end_time
        )
        learn_processors = check_transform_proc(
            learn_processors, fit_start_time, fit_end_time
        )

        feature_config = self._build_fields_config(
            fields_map=feature_map,
            exprs=feature_exprs,
            names=feature_names,
            default_exprs=["$close"],
            default_names=["CLOSE"],
            group_name="feature",
        )
        label_config = self._build_fields_config(
            fields_map=label_map,
            exprs=label_exprs,
            names=label_names,
            default_exprs=["Ref($close, -2) / Ref($close, -1) - 1"],
            default_names=["LABEL0"],
            group_name="label",
        )

        data_loader = {
            "class": "QlibDataLoader",
            "kwargs": {
                "config": {
                    "feature": feature_config,
                    "label": label_config,
                },
                "filter_pipe": filter_pipe,
                "freq": freq,
                "inst_processors": inst_processors,
            },
        }

        super().__init__(
            instruments=instruments,
            start_time=start_time,
            end_time=end_time,
            data_loader=data_loader,
            infer_processors=infer_processors,
            learn_processors=learn_processors,
            process_type=process_type,
            **kwargs,
        )

    @staticmethod
    def _build_fields_config(
        fields_map: Optional[Dict[str, str]],
        exprs: Optional[FieldsConfig],
        names: Optional[Sequence[str]],
        default_exprs: Sequence[str],
        default_names: Sequence[str],
        group_name: str,
    ) -> Tuple[List[str], List[str]]:
        """Normalise one field group into an ``(expressions, names)`` pair.

        Resolution order:

        1. ``fields_map`` (``{name: expression}``): must be a non-empty dict
           and cannot be combined with ``exprs`` or ``names``.
        2. ``exprs is None``: use ``default_exprs`` / ``default_names``.
           ``names`` must also be ``None``, otherwise it would be silently
           ignored.
        3. ``exprs`` is an ``(expressions, names)`` pair of lists/tuples:
           ``names`` must be ``None``.
        4. ``exprs`` is a sequence of expression strings or one bare
           expression string. Names come from ``names`` or, when absent, are
           the expressions themselves.

        Args:
            fields_map: optional ``{name: expression}`` mapping.
            exprs: optional expression config (see above).
            names: optional column names matching ``exprs``; a bare string is
                treated as a single name.
            default_exprs: expressions used when nothing is configured.
            default_names: names matching ``default_exprs``.
            group_name: group label used in error messages
                (e.g. ``"feature"``).

        Returns:
            A tuple ``(expr_list, name_list)`` of two non-empty lists of equal
            length.

        Raises:
            ValueError: on conflicting arguments, an empty or non-dict map,
                empty expressions, non-string expressions, or an
                expression/name length mismatch.
        """
        if fields_map is not None:
            if not isinstance(fields_map, dict) or len(fields_map) == 0:
                raise ValueError(f"{group_name}_map must be a non-empty dict")
            name_list = list(fields_map.keys())
            expr_list = list(fields_map.values())
            if exprs is not None or names is not None:
                raise ValueError(
                    f"Please use either {group_name}_map or {group_name}_exprs/{group_name}_names"
                )
        elif exprs is None:
            if names is not None:
                raise ValueError(
                    f"{group_name}_names was given without {group_name}_exprs"
                )
            expr_list = list(default_exprs)
            name_list = list(default_names)
        else:
            # A bare ``str`` is itself a Sequence[str]. list() would split it
            # into single characters, so wrap it as one expression/name.
            if isinstance(exprs, str):
                exprs = [exprs]
            if isinstance(names, str):
                names = [names]

            is_pair = (
                isinstance(exprs, (list, tuple))
                and len(exprs) == 2
                and isinstance(exprs[0], (list, tuple))
                and isinstance(exprs[1], (list, tuple))
            )
            if is_pair:
                # The pair already carries names; explicit names would be
                # ambiguous.
                if names is not None:
                    raise ValueError(
                        f"{group_name}_exprs is an (exprs, names) pair; "
                        f"do not also pass {group_name}_names"
                    )
                expr_list = list(exprs[0])
                name_list = list(exprs[1])
            else:
                expr_list = list(exprs)
                name_list = list(names) if names is not None else list(expr_list)

        if len(expr_list) == 0:
            raise ValueError(f"{group_name}_exprs must not be empty")
        non_str = [e for e in expr_list if not isinstance(e, str)]
        if non_str:
            raise ValueError(
                f"{group_name} expressions must be strings, got {non_str!r}"
            )
        if len(expr_list) != len(name_list):
            raise ValueError(
                f"{group_name} expressions and names length mismatch: "
                f"{len(expr_list)} != {len(name_list)}"
            )

        return expr_list, name_list
|