feat(tool): merge CSV trades for the same pair on the same day

2025-07-22 22:48:03 +08:00
parent c80e01a831
commit 3c250dc01b
2 changed files with 129 additions and 0 deletions

tool_csv_merge.py Normal file

@@ -0,0 +1,129 @@
"""CSV 文件处理模块
该模块用于合并同一证券在同一天的多笔交易记录,并生成汇总后的交易记录。
"""
import csv
import os
from collections import defaultdict


def process_csv(input_file, output_file):
    """Merge same-day trades for each security in a CSV file.

    Args:
        input_file (str): Path to the input CSV file.
        output_file (str): Path to the output CSV file.
    """
    # Aggregate buckets keyed by (trade date, security symbol); each bucket
    # accumulates buy/sell totals and collects the merged order IDs.
    merged_records = defaultdict(
lambda: {
"buy_shares": 0.0,
"buy_amount": 0.0,
"sell_shares": 0.0,
"sell_amount": 0.0,
"order_ids": set(),
"first_record": None,
"time_part": None,
}
)
with open(input_file, mode="r", newline="", encoding="utf-8") as infile:
reader = csv.DictReader(infile)
        for row in reader:
            # "日期" is an ISO-style timestamp with the date and time joined
            # by "T"; trades are grouped by trade date and security symbol.
            datetime_str = row["日期"]
            date_part, time_part = datetime_str.split("T")
            # The order ID is embedded at the end of the "备注" field.
            order_id = row["备注"].split("Order ID: ")[-1].strip()
            merge_key = (date_part, row["证券代码"])
            record = merged_records[merge_key]
            if record["first_record"] is None:
                # Remember the first row of the day; its accounts and time
                # of day are reused for the merged output row.
                record["first_record"] = row
                record["time_part"] = time_part
            record["order_ids"].add(order_id)
if row["类型"] == "买入":
record["buy_shares"] += float(row["份额"])
record["buy_amount"] += float(row["净额"])
elif row["类型"] == "卖出":
record["sell_shares"] += float(row["份额"])
record["sell_amount"] += float(row["净额"])
output_rows = []
    # Collapse each bucket into a single output row; the sign of the net
    # share count decides whether the merged trade is a buy or a sell.
    for key, record in merged_records.items():
date_part, symbol = key
first_row = record["first_record"]
net_shares = record["buy_shares"] - record["sell_shares"]
net_amount = record["buy_amount"] - record["sell_amount"]
if net_shares >= 0:
operation_type = "买入"
display_shares = net_shares
display_amount = net_amount
else:
operation_type = "卖出"
display_shares = -net_shares
display_amount = -net_amount
            # Format as a plain decimal string (no scientific notation) and
            # strip trailing zeros; "%f" always yields a decimal point, so the
            # rstrip cannot remove significant digits.
            formatted_shares = f"{display_shares:f}".rstrip("0").rstrip(".")
            formatted_amount = f"{display_amount:f}".rstrip("0").rstrip(".")
merged_row = {
"日期": f"{date_part}T{record['time_part']}",
"类型": operation_type,
"证券代码": symbol,
"份额": formatted_shares,
"净额": formatted_amount,
"现金账户": first_row["现金账户"],
"目标账户": first_row["目标账户"],
"备注": f"MEXC API - Order ID: {', '.join(sorted(record['order_ids']))}",
}
output_rows.append(merged_row)
    # Write the merged rows to the output file.
with open(output_file, mode="w", newline="", encoding="utf-8") as outfile:
fieldnames = [
"日期",
"类型",
"证券代码",
"份额",
"净额",
"现金账户",
"目标账户",
"备注",
]
writer = csv.DictWriter(outfile, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(output_rows)
def process_all_csvs(input_dir="output"):
"""处理指定目录下的所有CSV文件
Args:
input_dir (str): 包含CSV文件的目录路径
"""
for filename in os.listdir(input_dir):
if filename.endswith(".csv") and not filename.startswith("merged_"):
input_path = os.path.join(input_dir, filename)
output_path = os.path.join(input_dir, f"merged_{filename}")
process_csv(input_path, output_path)
print(f"处理完成: {filename} -> merged_{filename}")


if __name__ == "__main__":
    process_all_csvs()
    print("All files processed")