Add process_symbol function to handle kline data for a specific symbol
This commit is contained in:
@@ -7,6 +7,7 @@ from psycopg2.extras import execute_values
|
||||
import logging
|
||||
from datetime import datetime
|
||||
import xml.etree.ElementTree as ET
|
||||
from download_unzip_csv import download_unzip_csv
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
@@ -229,6 +230,81 @@ def download_kline_data_by_url(url):
|
||||
logger.error(f"Failed to download {url}: {e}")
|
||||
return None
|
||||
|
||||
def process_symbol(symbol, interval=INTERVAL):
    """Download, parse, merge and store all monthly kline data for one pair.

    Args:
        symbol: Trading pair, e.g. "BTCUSDT".
        interval: Candle interval, e.g. "1d"; defaults to the module-level
            INTERVAL constant.

    Returns:
        pandas.DataFrame: merged, de-duplicated kline rows, or None when no
        files were listed or none could be processed.
    """
    logger.info(f"Processing symbol: {symbol}, interval: {interval}")

    # Build the S3 listing URL for this symbol/interval and enumerate files.
    listing_url = f"https://s3-ap-northeast-1.amazonaws.com/data.binance.vision?delimiter=/&prefix=data/futures/um/monthly/klines/{symbol}/{interval}/"
    urls = list_s3_files(listing_url)
    if not urls:
        logger.warning(f"No files found for {symbol}-{interval}")
        return None

    # Download and parse every .zip archive; a failure on one file only
    # skips that file, the rest are still processed.
    frames = []
    for url in (u for u in urls if u.endswith('.zip')):
        try:
            frame = download_unzip_csv(url, header=None, names=KLINE_COLUMNS)
            frame["symbol"] = symbol
            # Timestamps arrive as epoch milliseconds; convert to datetime.
            frame["open_time"] = pd.to_datetime(frame["open_time"], unit='ms')
            frame["close_time"] = pd.to_datetime(frame["close_time"], unit='ms')
            frames.append(frame)
            logger.info(f"Processed {os.path.basename(url)} with {len(frame)} rows")
        except Exception as e:
            logger.error(f"Failed to process {url}: {e}")

    if not frames:
        logger.warning(f"No data processed for {symbol}-{interval}")
        return None

    merged = pd.concat(frames, ignore_index=True)
    logger.info(f"Merged {len(frames)} files into a single DataFrame with {len(merged)} rows")

    # Drop duplicate candles (same symbol + open_time, first occurrence kept).
    merged = merged.drop_duplicates(subset=["symbol", "open_time"])
    logger.info(f"After deduplication, {len(merged)} rows remain")

    # Persist to PostgreSQL when a connection can be made; the connection is
    # always closed, even if the insert raises.
    connection = create_connection()
    if connection:
        try:
            create_table(connection)
            insert_data(connection, merged)
            logger.info(f"Successfully inserted {len(merged)} rows into database for {symbol}")
        finally:
            connection.close()

    return merged
|
||||
|
||||
|
||||
def main():
|
||||
# 创建数据库连接
|
||||
conn = create_connection()
|
||||
@@ -282,4 +358,13 @@ def main():
|
||||
logger.info("Script completed successfully")
|
||||
|
||||
if __name__ == "__main__":
    import sys

    # With a command-line argument, exercise process_symbol for that pair
    # only; with no arguments, fall back to the regular main() run.
    cli_args = sys.argv[1:]
    if cli_args:
        pair = cli_args[0]
        candle_interval = cli_args[1] if len(cli_args) > 1 else INTERVAL
        process_symbol(pair, candle_interval)
    else:
        main()
|
||||
|
||||
Reference in New Issue
Block a user