from typing import Dict, List, Optional, Sequence, Tuple, Union

from qlib.contrib.data.handler import _DEFAULT_LEARN_PROCESSORS, check_transform_proc
from qlib.data.dataset.handler import DataHandlerLP

# Either a flat sequence of expressions, or an ``(expressions, names)`` pair.
FieldsConfig = Union[Sequence[str], Tuple[Sequence[str], Sequence[str]]]


class ExpressionDataHandlerLP(DataHandlerLP):
    """DataHandlerLP that accepts feature/label expressions via init args.

    Each field group (``feature`` and ``label``) can be described in one of
    three mutually compatible ways:

    * ``<group>_map``: a non-empty ``{name: expression}`` dict;
    * ``<group>_exprs`` as a flat sequence of expressions, optionally with
      ``<group>_names`` (names default to the expressions themselves);
    * ``<group>_exprs`` as an ``(expressions, names)`` pair of lists/tuples.

    A map cannot be combined with exprs/names for the same group.
    """

    def __init__(
        self,
        instruments="all",
        start_time=None,
        end_time=None,
        freq="day",
        feature_map: Optional[Dict[str, str]] = None,
        feature_exprs: Optional[FieldsConfig] = None,
        feature_names: Optional[Sequence[str]] = None,
        label_map: Optional[Dict[str, str]] = None,
        label_exprs: Optional[FieldsConfig] = None,
        label_names: Optional[Sequence[str]] = None,
        infer_processors: Optional[List] = None,
        learn_processors: Optional[List] = None,
        fit_start_time=None,
        fit_end_time=None,
        process_type=DataHandlerLP.PTYPE_A,
        filter_pipe=None,
        inst_processors=None,
        **kwargs,
    ):
        """Build a ``QlibDataLoader`` config and initialise the base handler.

        Args:
            instruments, start_time, end_time, process_type: Forwarded
                unchanged to ``DataHandlerLP.__init__``.
            freq, filter_pipe, inst_processors: Placed in the data loader's
                ``kwargs``.
            feature_map, feature_exprs, feature_names: Feature definition;
                when none is given, ``$close`` named ``CLOSE`` is used.
            label_map, label_exprs, label_names: Label definition; when none
                is given, ``Ref($close, -2) / Ref($close, -1) - 1`` named
                ``LABEL0`` is used.
            infer_processors: Defaults to an empty list.
            learn_processors: Defaults to ``_DEFAULT_LEARN_PROCESSORS``.
            fit_start_time, fit_end_time: Passed to ``check_transform_proc``
                together with each processor list.
            **kwargs: Forwarded to ``DataHandlerLP.__init__`` after the
                legacy ``label`` and ``feature`` keys have been removed.

        Raises:
            ValueError: If a field group is configured inconsistently (see
                ``_build_fields_config``).
        """
        # Legacy configs may still pass ``label``/``feature``; both must be
        # removed so they are not forwarded to the base class.  Only the
        # legacy label is honoured (as a fallback); the legacy feature value
        # is discarded.
        legacy_label = kwargs.pop("label", None)
        kwargs.pop("feature", None)
        if label_map is None and label_exprs is None and legacy_label is not None:
            label_exprs = legacy_label

        if infer_processors is None:
            infer_processors = []
        if learn_processors is None:
            learn_processors = _DEFAULT_LEARN_PROCESSORS

        infer_processors = check_transform_proc(
            infer_processors, fit_start_time, fit_end_time
        )
        learn_processors = check_transform_proc(
            learn_processors, fit_start_time, fit_end_time
        )

        feature_config = self._build_fields_config(
            fields_map=feature_map,
            exprs=feature_exprs,
            names=feature_names,
            default_exprs=["$close"],
            default_names=["CLOSE"],
            group_name="feature",
        )
        label_config = self._build_fields_config(
            fields_map=label_map,
            exprs=label_exprs,
            names=label_names,
            default_exprs=["Ref($close, -2) / Ref($close, -1) - 1"],
            default_names=["LABEL0"],
            group_name="label",
        )

        data_loader = {
            "class": "QlibDataLoader",
            "kwargs": {
                "config": {
                    "feature": feature_config,
                    "label": label_config,
                },
                "filter_pipe": filter_pipe,
                "freq": freq,
                "inst_processors": inst_processors,
            },
        }

        super().__init__(
            instruments=instruments,
            start_time=start_time,
            end_time=end_time,
            data_loader=data_loader,
            infer_processors=infer_processors,
            learn_processors=learn_processors,
            process_type=process_type,
            **kwargs,
        )

    @staticmethod
    def _build_fields_config(
        fields_map: Optional[Dict[str, str]],
        exprs: Optional[FieldsConfig],
        names: Optional[Sequence[str]],
        default_exprs: Sequence[str],
        default_names: Sequence[str],
        group_name: str,
    ) -> Tuple[List[str], List[str]]:
        """Resolve one field group into an ``(expressions, names)`` pair.

        Args:
            fields_map: ``{name: expression}`` dict, or ``None``.
            exprs: Flat sequence of expressions, a single expression string,
                or an ``(expressions, names)`` pair of lists/tuples.
            names: Names matching ``exprs`` one-to-one, or a single name
                string.  Ignored when ``exprs`` is ``None`` and no map is
                given (the defaults are used instead).
            default_exprs: Expressions used when nothing is configured.
            default_names: Names used when nothing is configured.
            group_name: ``"feature"`` or ``"label"``; used in error messages.

        Returns:
            A tuple ``(expr_list, name_list)`` of two equally long lists.

        Raises:
            ValueError: If ``fields_map`` is not a non-empty dict, if it is
                combined with ``exprs``/``names``, if the resolved expression
                list is empty, or if expressions and names differ in length.
        """
        # A bare string is itself a sequence of characters; without this
        # normalisation ``list("$close")`` would yield one "expression" per
        # character instead of a single expression.
        if isinstance(exprs, str):
            exprs = [exprs]
        if isinstance(names, str):
            names = [names]

        if fields_map is not None:
            if not isinstance(fields_map, dict) or not fields_map:
                raise ValueError(f"{group_name}_map must be a non-empty dict")
            if exprs is not None or names is not None:
                raise ValueError(
                    f"Please use either {group_name}_map or {group_name}_exprs/{group_name}_names"
                )
            name_list = list(fields_map.keys())
            expr_list = list(fields_map.values())
        elif exprs is None:
            expr_list = list(default_exprs)
            name_list = list(default_names)
        elif (
            names is None
            and isinstance(exprs, (list, tuple))
            and len(exprs) == 2
            and isinstance(exprs[0], (list, tuple))
            and isinstance(exprs[1], (list, tuple))
        ):
            # ``exprs`` is an (expressions, names) pair.
            expr_list = list(exprs[0])
            name_list = list(exprs[1])
        else:
            expr_list = list(exprs)
            # Without explicit names each expression doubles as its own name.
            name_list = list(names) if names is not None else list(expr_list)

        if not expr_list:
            raise ValueError(f"{group_name}_exprs must not be empty")
        if len(expr_list) != len(name_list):
            raise ValueError(
                f"{group_name} expressions and names length mismatch: "
                f"{len(expr_list)} != {len(name_list)}"
            )
        return expr_list, name_list