提交oss代码
This commit is contained in:
269
cloudbt/run_cloudbt.py
Normal file
269
cloudbt/run_cloudbt.py
Normal file
@@ -0,0 +1,269 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
from typing import List, Optional
|
||||
import requests
|
||||
from myscripts import working_tool
|
||||
import dotenv
|
||||
import json
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
|
||||
def get_pairs_from_data_dir(config_path: str, timeframe: str) -> List[str]:
|
||||
"""
|
||||
Get valid pairs from the data directory, using configuration from config file
|
||||
"""
|
||||
try:
|
||||
# Read configuration from config file
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = json.load(f)
|
||||
|
||||
exchange_name = config.get('exchange', {}).get('name', 'binance')
|
||||
trading_mode = config.get('trading_mode', 'futures')
|
||||
|
||||
# Build data directory path
|
||||
data_dir = Path('user_data') / 'data' / exchange_name / trading_mode
|
||||
|
||||
if not data_dir.exists():
|
||||
print(f"Warning: Data directory does not exist: {data_dir}")
|
||||
return []
|
||||
|
||||
# Build filename pattern
|
||||
pattern = f"*-{timeframe}-{trading_mode}.feather"
|
||||
|
||||
# Find all matching files
|
||||
pairs = []
|
||||
for file_path in data_dir.glob(pattern):
|
||||
# Extract pair from filename
|
||||
filename = file_path.stem # Remove extension
|
||||
# Remove timeframe and trading_mode suffix
|
||||
pair_part = filename.rsplit(f"-{timeframe}-{trading_mode}", 1)[0]
|
||||
|
||||
# Convert from filename format to standard format
|
||||
# Example: TRX_USDT_USDT-1m-futures.feather -> TRX/USDT:USDT
|
||||
parts = pair_part.split('_')
|
||||
if len(parts) >= 2:
|
||||
base = parts[0]
|
||||
quote = parts[1] if len(parts) > 1 else 'USDT'
|
||||
# For futures, format is BASE/QUOTE:QUOTE
|
||||
if trading_mode == 'futures':
|
||||
pair = f"{base}/{quote}:{quote}"
|
||||
else:
|
||||
pair = f"{base}/{quote}"
|
||||
pairs.append(pair)
|
||||
|
||||
# Deduplicate and sort
|
||||
pairs = sorted(list(set(pairs)))
|
||||
|
||||
print(f"Found {len(pairs)} pairs from data directory {data_dir} (timeframe: {timeframe})")
|
||||
|
||||
return pairs
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: Failed to get pairs from data directory: {e}")
|
||||
return []
|
||||
|
||||
def get_pair_data_file_path(pair: str, config_path: str, timeframe: str) -> Optional[Path]:
|
||||
"""
|
||||
Get the data file path for a specific pair
|
||||
"""
|
||||
try:
|
||||
# Read configuration from config file
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = json.load(f)
|
||||
|
||||
exchange_name = config.get('exchange', {}).get('name', 'binance')
|
||||
trading_mode = config.get('trading_mode', 'futures')
|
||||
|
||||
# Build data directory path
|
||||
data_dir = Path('user_data') / 'data' / exchange_name / trading_mode
|
||||
|
||||
# Convert pair to filename format
|
||||
# Example: BARD/USDT:USDT -> BARD_USDT_USDT-1m-futures.feather
|
||||
pair_normalized = pair.replace('/', '_').replace(':', '_')
|
||||
filename = f"{pair_normalized}-{timeframe}-{trading_mode}.feather"
|
||||
file_path = data_dir / filename
|
||||
|
||||
return file_path if file_path.exists() else None
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to get data file path for pair {pair}: {e}")
|
||||
return None
|
||||
|
||||
def filter_pairs_by_timerange(pairs: List[str], timerange: str, config_path: str, timeframe: str) -> List[str]:
|
||||
"""Filter pairs based on timerange, checking if data exists for the given timerange"""
|
||||
print(f"Checking pair data availability for timerange: {timerange}...")
|
||||
valid_pairs = []
|
||||
invalid_pairs = []
|
||||
|
||||
try:
|
||||
# Parse timerange string into start and end dates
|
||||
start_str, end_str = timerange.split('-')
|
||||
start_date = datetime.datetime.strptime(start_str, '%Y%m%d')
|
||||
end_date = datetime.datetime.strptime(end_str, '%Y%m%d')
|
||||
|
||||
# Convert to pandas Timestamps for comparison
|
||||
start_ts = pd.Timestamp(start_date)
|
||||
end_ts = pd.Timestamp(end_date)
|
||||
|
||||
for pair in pairs:
|
||||
file_path = get_pair_data_file_path(pair, config_path, timeframe)
|
||||
if file_path is not None:
|
||||
try:
|
||||
# Read only the date column to check data availability
|
||||
df = pd.read_feather(file_path, columns=['date'])
|
||||
if not df.empty:
|
||||
# Convert date column to datetime
|
||||
df['date'] = pd.to_datetime(df['date'])
|
||||
|
||||
# Remove timezone if present
|
||||
if df['date'].dt.tz is not None:
|
||||
df['date'] = df['date'].dt.tz_convert('UTC').dt.tz_localize(None)
|
||||
|
||||
# Check if there's data within the timerange
|
||||
mask = (df['date'] >= start_ts) & (df['date'] <= end_ts)
|
||||
if any(mask):
|
||||
valid_pairs.append(pair)
|
||||
else:
|
||||
invalid_pairs.append(pair)
|
||||
else:
|
||||
invalid_pairs.append(pair)
|
||||
except Exception as e:
|
||||
print(f"Warning: Error reading data for pair {pair}: {e}")
|
||||
invalid_pairs.append(pair)
|
||||
else:
|
||||
invalid_pairs.append(pair)
|
||||
except Exception as e:
|
||||
print(f"Error parsing timerange: {e}")
|
||||
return []
|
||||
|
||||
if invalid_pairs:
|
||||
print(f"Filtered out {len(invalid_pairs)} pairs with no data in the specified timerange:")
|
||||
for pair in invalid_pairs[:10]: # Show only first 10
|
||||
print(f" - {pair}")
|
||||
if len(invalid_pairs) > 10:
|
||||
print(f" ... and {len(invalid_pairs) - 10} more pairs")
|
||||
|
||||
print(f"Kept {len(valid_pairs)} pairs with available data")
|
||||
return valid_pairs
|
||||
|
||||
def split_pairs(pairs: List[str], jobs: int = None, max_pairs: int = None) -> List[List[str]]:
|
||||
"""
|
||||
Split pairs into chunks based on jobs or max_pairs
|
||||
|
||||
Args:
|
||||
pairs: List of pairs to split
|
||||
jobs: Number of chunks to create (if max_pairs is specified, this will be ignored)
|
||||
max_pairs: Maximum pairs per chunk (higher priority than jobs)
|
||||
|
||||
Returns:
|
||||
List of chunks, where each chunk is a list of pairs
|
||||
"""
|
||||
total_pairs = len(pairs)
|
||||
|
||||
if max_pairs is not None and max_pairs > 0:
|
||||
# Calculate number of chunks based on max_pairs per chunk
|
||||
num_jobs = (total_pairs + max_pairs - 1) // max_pairs # 向上取整
|
||||
print(f"根据max-pairs={max_pairs}计算,需要{num_jobs}个job")
|
||||
elif jobs is not None and jobs > 0:
|
||||
# Use specified number of jobs
|
||||
num_jobs = jobs
|
||||
else:
|
||||
# Default to 1 job if neither is specified
|
||||
num_jobs = 1
|
||||
|
||||
# Split pairs into chunks
|
||||
chunks = []
|
||||
if num_jobs <= 0:
|
||||
num_jobs = 1
|
||||
|
||||
for i in range(num_jobs):
|
||||
start = i * (total_pairs // num_jobs) + min(i, total_pairs % num_jobs)
|
||||
end = start + (total_pairs // num_jobs) + (1 if i < total_pairs % num_jobs else 0)
|
||||
chunks.append(pairs[start:end])
|
||||
|
||||
return chunks
|
||||
|
||||
def submit_job(timeframe: str, timerange: str, pairs_chunk: List[str],
|
||||
strategy: str,
|
||||
start_datetime: str,
|
||||
job_id: int,
|
||||
working_url: str) -> None:
|
||||
"""Submit job function (placeholder implementation)"""
|
||||
script = generate_shell_script(timeframe, timerange,
|
||||
pairs_chunk, strategy, start_datetime, job_id, working_url)
|
||||
requests.post(
|
||||
"https://prefect.oopsapi.com/api/deployments/dac4e321-cc60-4ca2-8aba-ee389d395ae9/create_flow_run",
|
||||
json={
|
||||
"name": f"backtest-{strategy}-{job_id}",
|
||||
"parameters": {"shell_script": script},
|
||||
"tags": [strategy, timeframe, timerange]
|
||||
}
|
||||
)
|
||||
|
||||
def generate_shell_script(timeframe: str, timerange: str, pairs_chunk: List[str],
|
||||
strategy: str,
|
||||
start_datetime: str,
|
||||
job_id: int,
|
||||
working_url: str) -> str:
|
||||
"""
|
||||
Generate a shell script for the given timeframe, timerange, and pairs chunk.
|
||||
"""
|
||||
return f"""
|
||||
#!/bin/bash
|
||||
cd $WORKING_DIR
|
||||
echo "当前工作目录: $(pwd)"
|
||||
echo "下载工作空间数据"
|
||||
working_tool download {working_url}
|
||||
echo "开始回测"
|
||||
mkdir -p user_data/backtest_results/{start_datetime}
|
||||
freqtrade backtesting \
|
||||
--timeframe {timeframe} \
|
||||
--timerange {timerange} \
|
||||
--pairs {','.join(pairs_chunk)} \
|
||||
--strategy {strategy} \
|
||||
--export trades \
|
||||
--export-directory user_data/backtest_results/{start_datetime}/{job_id}.json \
|
||||
--no-color 1>user_data/backtest_results/{start_datetime}/job_{job_id}.log 2>&1
|
||||
mc cp user_data/backtest_results/{start_datetime}/job_{job_id}.log oss/backtest/{start_datetime}
|
||||
"""
|
||||
|
||||
def main():
|
||||
dotenv.load_dotenv()
|
||||
zip_password = os.environ.get("ZIP_PASSWORD", "")
|
||||
working_url = working_tool.upload_working_files(zip_password)
|
||||
parser = argparse.ArgumentParser(description="Backtest script with pair splitting")
|
||||
parser.add_argument("--timeframe", required=True, help="Timeframe for backtest (e.g., 1m, 1h)")
|
||||
parser.add_argument("--timerange", required=True, help="Timerange for backtest (e.g., 20210101-20211231)")
|
||||
parser.add_argument("--pairs", default=None, help="Comma-separated list of pairs")
|
||||
parser.add_argument("--max-pairs", type=int, help="Maximum pairs per job (higher priority than --jobs)")
|
||||
parser.add_argument("--jobs", type=int, help="Number of jobs to create (ignored if --max-pairs is specified)")
|
||||
parser.add_argument("--strategy", required=True, help="Strategy to use for backtest")
|
||||
parser.add_argument("--config", type=str, default="user_data/config.json", help="Config file path, default: user_data/config.json")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Get pairs list
|
||||
if args.pairs:
|
||||
pairs = args.pairs.split(",")
|
||||
print(f"Using provided {len(pairs)} pairs")
|
||||
else:
|
||||
pairs = get_pairs_from_data_dir(args.config, args.timeframe)
|
||||
|
||||
# Filter pairs by timerange
|
||||
filtered_pairs = filter_pairs_by_timerange(pairs, args.timerange, args.config, args.timeframe)
|
||||
|
||||
# Split pairs into chunks
|
||||
pair_chunks = split_pairs(filtered_pairs, args.jobs, args.max_pairs)
|
||||
|
||||
dt = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
|
||||
# Submit jobs
|
||||
for i, chunk in enumerate(pair_chunks):
|
||||
if chunk: # 检查分块是否为空
|
||||
submit_job(args.timeframe, args.timerange, chunk,
|
||||
args.strategy, dt, i, working_url)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user