Files
mexc-spot-dca-bot/tool_csv_merge.py

130 lines
4.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""CSV 文件处理模块
该模块用于合并同一证券在同一天的多笔交易记录,并生成汇总后的交易记录。
"""
import csv
import os
from collections import defaultdict
from decimal import Decimal
def _format_decimal(value):
    """Render a Decimal in plain positional notation without trailing zeros.

    ``format(Decimal, "f")`` never switches to scientific notation and keeps
    every digit of the value, so small quantities such as ``0.00012345``
    survive intact (a float ``:f`` format would round them to six decimals).
    """
    text = format(value, "f")
    # Only trim a fractional part; an integer such as "100" must stay "100".
    if "." in text:
        text = text.rstrip("0").rstrip(".")
    return text


def process_csv(input_file, output_file):
    """Merge the trades of one security on one day into a single row.

    Rows of ``input_file`` are grouped by (date part of ``日期``, ``证券代码``).
    Within each group, buy (``买入``) and sell (``卖出``) shares and net
    amounts are summed separately with exact decimal arithmetic. Each group
    is written as one row.

    - The type is ``买入`` when the net share position (buys minus sells)
      is >= 0, otherwise ``卖出``. The shares/amount are reported as positive
      magnitudes in the ``卖出`` case.
    - The time part of ``日期`` and the two account columns are taken from
      the first row of the group encountered in the file.
    - ``备注`` lists every order id of the group, sorted as strings.

    Rows whose type is neither buy nor sell still contribute their order id
    and may supply the group's first row, but add nothing to the totals.

    Args:
        input_file (str): Path of the source CSV (UTF-8, with a header row).
        output_file (str): Path of the merged CSV to write (overwritten).

    Raises:
        KeyError: If a required column is missing from the input.
        ValueError: If ``日期`` does not split into exactly two parts on "T".
        decimal.InvalidOperation: If ``份额`` or ``净额`` is not a number.
    """
    merged_records = defaultdict(
        lambda: {
            "buy_shares": Decimal(0),
            "buy_amount": Decimal(0),
            "sell_shares": Decimal(0),
            "sell_amount": Decimal(0),
            "order_ids": set(),
            "first_record": None,
            "time_part": None,
        }
    )

    with open(input_file, mode="r", newline="", encoding="utf-8") as infile:
        for row in csv.DictReader(infile):
            date_part, time_part = row["日期"].split("T")
            order_id = row["备注"].split("Order ID: ")[-1].strip()
            record = merged_records[(date_part, row["证券代码"])]
            if record["first_record"] is None:
                # The first row seen for this group supplies time and accounts.
                record["first_record"] = row
                record["time_part"] = time_part
            record["order_ids"].add(order_id)
            # Decimal parses the CSV text exactly, avoiding float drift.
            if row["类型"] == "买入":
                record["buy_shares"] += Decimal(row["份额"])
                record["buy_amount"] += Decimal(row["净额"])
            elif row["类型"] == "卖出":
                record["sell_shares"] += Decimal(row["份额"])
                record["sell_amount"] += Decimal(row["净额"])

    output_rows = []
    for (date_part, symbol), record in merged_records.items():
        first_row = record["first_record"]
        net_shares = record["buy_shares"] - record["sell_shares"]
        net_amount = record["buy_amount"] - record["sell_amount"]
        if net_shares >= 0:
            # A zero net position is also reported as a buy.
            operation_type = "买入"
        else:
            # Net position decreased: report a sell with positive magnitudes.
            operation_type = "卖出"
            net_shares, net_amount = -net_shares, -net_amount
        output_rows.append(
            {
                "日期": f"{date_part}T{record['time_part']}",
                "类型": operation_type,
                "证券代码": symbol,
                "份额": _format_decimal(net_shares),
                "净额": _format_decimal(net_amount),
                "现金账户": first_row["现金账户"],
                "目标账户": first_row["目标账户"],
                "备注": f"MEXC API - Order ID: {', '.join(sorted(record['order_ids']))}",
            }
        )

    # Write the merged rows.
    with open(output_file, mode="w", newline="", encoding="utf-8") as outfile:
        fieldnames = [
            "日期",
            "类型",
            "证券代码",
            "份额",
            "净额",
            "现金账户",
            "目标账户",
            "备注",
        ]
        writer = csv.DictWriter(outfile, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(output_rows)
def process_all_csvs(input_dir="output"):
    """Merge every raw CSV file listed directly inside ``input_dir``.

    Each entry whose name ends in ``.csv`` and does not start with
    ``merged_`` is handed to ``process_csv``, which writes
    ``merged_<name>`` into the same directory. A progress line is printed
    after each file.

    Args:
        input_dir (str): Directory holding the CSV files (default ``"output"``).
    """
    for name in os.listdir(input_dir):
        # Skip non-CSV entries and outputs produced by an earlier run.
        if not name.endswith(".csv") or name.startswith("merged_"):
            continue
        merged_name = f"merged_{name}"
        process_csv(
            os.path.join(input_dir, name),
            os.path.join(input_dir, merged_name),
        )
        print(f"处理完成: {name} -> merged_{name}")
if __name__ == "__main__":
    # Script entry point: merge all CSV files in the "output" directory.
    process_all_csvs(input_dir="output")
    print("所有文件处理完成")