This commit is contained in:
yhydev
2026-03-18 09:29:15 +08:00
commit 52187fc1cf
6 changed files with 159 additions and 0 deletions

10
.gitignore vendored Normal file
View File

@@ -0,0 +1,10 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info
# Virtual environments
.venv

1
.python-version Normal file
View File

@@ -0,0 +1 @@
3.13

0
README.md Normal file
View File

6
main.py Normal file
View File

@@ -0,0 +1,6 @@
def main():
    """Write the project's greeting line to standard output."""
    greeting = "Hello from qlib-ext!"
    print(greeting)


# Run the greeting only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

7
pyproject.toml Normal file
View File

@@ -0,0 +1,7 @@
[project]
name = "qlib-ext"
version = "0.1.0"
description = "Expression-configurable data handlers for Qlib"
readme = "README.md"
requires-python = ">=3.13"
# The handler module imports `qlib`, which is distributed on PyPI as `pyqlib`;
# without it the package installs but cannot be imported.
# NOTE(review): confirm a pyqlib release supports the Python version required
# above and add a lower bound once the minimum tested version is known.
dependencies = ["pyqlib"]

View File

@@ -0,0 +1,135 @@
from typing import Dict, List, Optional, Sequence, Tuple, Union
from qlib.contrib.data.handler import _DEFAULT_LEARN_PROCESSORS, check_transform_proc
from qlib.data.dataset.handler import DataHandlerLP
# Accepted shapes for the ``*_exprs`` arguments: either a flat sequence of
# expression strings, or an ``(expressions, names)`` pair of sequences.
FieldsConfig = Union[Sequence[str], Tuple[Sequence[str], Sequence[str]]]
class ExpressionDataHandlerLP(DataHandlerLP):
    """DataHandlerLP that accepts feature/label expressions via init args.

    Feature and label columns can each be described in one of three ways:

    * ``*_map``: a ``{name: expression}`` dict;
    * ``*_exprs`` (+ optional ``*_names``): a sequence of expressions, with
      names defaulting to the expressions themselves;
    * ``*_exprs`` as an ``(expressions, names)`` pair.

    When nothing is given, the feature group defaults to ``$close`` named
    ``CLOSE`` and the label group to ``Ref($close, -2) / Ref($close, -1) - 1``
    named ``LABEL0``.
    """

    def __init__(
        self,
        instruments="all",
        start_time=None,
        end_time=None,
        freq="day",
        feature_map: Optional[Dict[str, str]] = None,
        feature_exprs: Optional[FieldsConfig] = None,
        feature_names: Optional[Sequence[str]] = None,
        label_map: Optional[Dict[str, str]] = None,
        label_exprs: Optional[FieldsConfig] = None,
        label_names: Optional[Sequence[str]] = None,
        infer_processors: Optional[List] = None,
        learn_processors: Optional[List] = None,
        fit_start_time=None,
        fit_end_time=None,
        process_type=DataHandlerLP.PTYPE_A,
        filter_pipe=None,
        inst_processors=None,
        **kwargs,
    ):
        """Build a ``QlibDataLoader`` config and initialise the base handler.

        Args:
            instruments: Forwarded unchanged to ``DataHandlerLP``.
            start_time: Forwarded unchanged to ``DataHandlerLP``.
            end_time: Forwarded unchanged to ``DataHandlerLP``.
            freq: Placed in the loader kwargs as ``freq``.
            feature_map: ``{name: expression}`` dict for the feature group.
                Mutually exclusive with ``feature_exprs``/``feature_names``.
            feature_exprs: Feature expressions, or an ``(expressions, names)``
                pair. A single string is treated as one expression.
            feature_names: Names matching ``feature_exprs`` one-to-one.
            label_map: Same as ``feature_map`` for the label group.
            label_exprs: Same as ``feature_exprs`` for the label group.
            label_names: Same as ``feature_names`` for the label group.
            infer_processors: Processor list; defaults to an empty list.
            learn_processors: Processor list; defaults to
                ``_DEFAULT_LEARN_PROCESSORS``.
            fit_start_time: Passed to ``check_transform_proc`` together with
                both processor lists.
            fit_end_time: Passed to ``check_transform_proc`` together with
                both processor lists.
            process_type: Forwarded unchanged to ``DataHandlerLP``.
            filter_pipe: Placed in the loader kwargs as ``filter_pipe``.
            inst_processors: Placed in the loader kwargs as
                ``inst_processors``.
            **kwargs: ``label`` is consumed as a legacy alias for
                ``label_exprs``; ``feature`` is consumed and discarded; all
                remaining entries are forwarded to ``DataHandlerLP``.

        Raises:
            ValueError: If a feature/label specification is empty,
                contradictory, or has mismatched expression/name lengths.
        """
        # Both legacy keys are popped so they never reach the base class.
        legacy_label = kwargs.pop("label", None)
        # NOTE(review): a legacy ``feature`` value is dropped rather than used
        # as ``feature_exprs`` (unlike ``label``) -- confirm this is intended.
        kwargs.pop("feature", None)
        # The legacy label only applies when no explicit label spec was given.
        if label_map is None and label_exprs is None and legacy_label is not None:
            label_exprs = legacy_label

        if infer_processors is None:
            infer_processors = []
        if learn_processors is None:
            learn_processors = _DEFAULT_LEARN_PROCESSORS

        infer_processors = check_transform_proc(
            infer_processors, fit_start_time, fit_end_time
        )
        learn_processors = check_transform_proc(
            learn_processors, fit_start_time, fit_end_time
        )

        feature_config = self._build_fields_config(
            fields_map=feature_map,
            exprs=feature_exprs,
            names=feature_names,
            default_exprs=["$close"],
            default_names=["CLOSE"],
            group_name="feature",
        )
        label_config = self._build_fields_config(
            fields_map=label_map,
            exprs=label_exprs,
            names=label_names,
            default_exprs=["Ref($close, -2) / Ref($close, -1) - 1"],
            default_names=["LABEL0"],
            group_name="label",
        )

        data_loader = {
            "class": "QlibDataLoader",
            "kwargs": {
                "config": {
                    "feature": feature_config,
                    "label": label_config,
                },
                "filter_pipe": filter_pipe,
                "freq": freq,
                "inst_processors": inst_processors,
            },
        }
        super().__init__(
            instruments=instruments,
            start_time=start_time,
            end_time=end_time,
            data_loader=data_loader,
            infer_processors=infer_processors,
            learn_processors=learn_processors,
            process_type=process_type,
            **kwargs,
        )

    @staticmethod
    def _build_fields_config(
        fields_map: Optional[Dict[str, str]],
        exprs: Optional[FieldsConfig],
        names: Optional[Sequence[str]],
        default_exprs: Sequence[str],
        default_names: Sequence[str],
        group_name: str,
    ):
        """Normalise one field group into an ``(expressions, names)`` tuple.

        Args:
            fields_map: ``{name: expression}`` dict, or ``None``.
            exprs: Expressions, an ``(expressions, names)`` pair, a single
                expression string, or ``None``.
            names: Names for ``exprs``, a single name string, or ``None``.
            default_exprs: Expressions used when neither ``fields_map`` nor
                ``exprs`` is given.
            default_names: Names used together with ``default_exprs``.
            group_name: ``"feature"`` or ``"label"``; only used to build
                error messages.

        Returns:
            A ``(expr_list, name_list)`` tuple of two equal-length lists.

        Raises:
            ValueError: If ``fields_map`` is not a non-empty dict, if it is
                combined with ``exprs``/``names``, if an ``(expressions,
                names)`` pair is combined with ``names``, if no expression
                is left, or if the two lists differ in length.
        """
        if fields_map is not None:
            if not isinstance(fields_map, dict) or not fields_map:
                raise ValueError(f"{group_name}_map must be a non-empty dict")
            if exprs is not None or names is not None:
                raise ValueError(
                    f"Please use either {group_name}_map or {group_name}_exprs/{group_name}_names"
                )
            name_list = list(fields_map)
            expr_list = list(fields_map.values())
        elif exprs is None:
            # NOTE: ``names`` is ignored on this path; defaults come as a set.
            expr_list = list(default_exprs)
            name_list = list(default_names)
        else:
            # A bare string is a single expression/name; ``list("...")`` would
            # otherwise split it into individual characters.
            if isinstance(exprs, str):
                exprs = [exprs]
            if isinstance(names, str):
                names = [names]

            is_pair = (
                isinstance(exprs, (list, tuple))
                and len(exprs) == 2
                and all(isinstance(part, (list, tuple)) for part in exprs)
            )
            if is_pair:
                # Two nested sequences can only mean (expressions, names);
                # accepting separate names as well would silently treat the
                # inner sequences themselves as expressions.
                if names is not None:
                    raise ValueError(
                        f"{group_name}_exprs is an (expressions, names) pair; "
                        f"do not also pass {group_name}_names"
                    )
                expr_list = list(exprs[0])
                name_list = list(exprs[1])
            else:
                expr_list = list(exprs)
                # Without explicit names each expression names itself.
                name_list = list(names) if names is not None else list(expr_list)

        if not expr_list:
            raise ValueError(f"{group_name}_exprs must not be empty")
        if len(expr_list) != len(name_list):
            raise ValueError(
                f"{group_name} expressions and names length mismatch: "
                f"{len(expr_list)} != {len(name_list)}"
            )
        return expr_list, name_list