From 110892b772d68a40ce90daddf9d935a5660e5d0b Mon Sep 17 00:00:00 2001 From: Zichao Lin Date: Sun, 8 Mar 2026 00:15:38 +0800 Subject: [PATCH] init --- .gitignore | 2 + .pylintrc | 4 + db.py | 53 +++++++++++++ main.py | 198 +++++++++++++++++++++++++++++++++++++++++++++++ requirements.txt | 1 + 5 files changed, 258 insertions(+) create mode 100644 .gitignore create mode 100644 .pylintrc create mode 100644 db.py create mode 100644 main.py create mode 100644 requirements.txt diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..468134a --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +output/ +config/ \ No newline at end of file diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000..a46c402 --- /dev/null +++ b/.pylintrc @@ -0,0 +1,4 @@ +[MASTER] +disable= + W0123, + C0301, \ No newline at end of file diff --git a/db.py b/db.py new file mode 100644 index 0000000..8515e86 --- /dev/null +++ b/db.py @@ -0,0 +1,53 @@ +"""运行时数据库模块""" + +import json +import os + +class Database: + """数据库类""" + def __init__(self, config_name): + self.db_file = f'output/{config_name}_db.json' + self.load() + + def load(self): + """加载数据库""" + if not os.path.exists(self.db_file): + with open(self.db_file, 'w', encoding='utf-8') as f: + json.dump({}, f) + with open(self.db_file, 'r', encoding='utf-8') as f: + self.data = json.load(f) + + def save(self): + """保存数据库""" + with open(self.db_file, 'w', encoding='utf-8') as f: + json.dump(self.data, f, indent=4) + + def sync_trades(self, trades): + """ + 同步数据库中的trade_id,确保所有trade_id都存在,且没有多余的trade_id + trades: 当前config中有效的trade_id列表 + """ + self.load() + schema = { + "order_id": { + "open": [], + "filled": [], + "canceled": [] + }, + "counter": { + "open": 0, + "filled": 0, + "canceled": 0 + } + } + for db_trade_id in list(self.data.keys()): + if db_trade_id not in trades: + del self.data[db_trade_id] + for trade_id in trades: + if trade_id not in self.data: + self.data[trade_id] = schema + for trade_id in trades: + self.data[trade_id]["counter"]["open"] = len(self.data[trade_id]["order_id"]["open"]) + self.data[trade_id]["counter"]["filled"] = len(self.data[trade_id]["order_id"]["filled"]) + self.data[trade_id]["counter"]["canceled"] = len(self.data[trade_id]["order_id"]["canceled"]) + self.save() diff --git a/main.py b/main.py new file mode 100644 index 0000000..9d45c96 --- /dev/null +++ b/main.py @@ -0,0 +1,198 @@ +"""条件订单机器人""" + +import os +import csv +import json +import time +import argparse +import logging +from datetime import datetime, timedelta, timezone + +import ccxt + +from db import Database + +def config_load(file_path): + """加载配置文件""" + if not os.path.exists(file_path): + raise FileNotFoundError(file_path) + with open(file_path, "r", encoding="utf-8") as f: + config = json.load(f) + return config + + +def config_validate(config): + """验证配置文件有效性""" + required_fields = ["interval", "trades", "accounts"] + + for field in required_fields: + if field not in config: + raise ValueError(f"缺少必要字段 {field}") + + if not isinstance(config["trades"], dict): + raise ValueError("字段 'trades' 必须是字典") + for trade_id, trade_data in config["trades"].items(): + required_keys = ["account", "condition_place", "condition_cancel", "order"] + for key in required_keys: + if key not in trade_data: + raise ValueError(f"交易 {trade_id} 配置不完整,必须包含 '{key}'") + if "amount" not in trade_data["order"] and "cost" not in trade_data["order"]: + raise ValueError(f"交易 {trade_id} 的订单配置必须包含 'amount' 或 'cost'") + + for account_id, account_data in config["accounts"].items(): + required_keys = ["exchange", "api_key", "secret_key"] + for key in required_keys: + if key not in account_data: + raise ValueError(f"账户 {account_id} 配置不完整,必须包含 '{key}'") + if account_data["exchange"] not in ccxt.exchanges: + raise ValueError(f"账户 {account_id} 的交易所 {account_data['exchange']} 不受支持") + +def csv_record(file_name, order, exchange, csv_symbol): + """记录订单信息到csv""" + file_path = os.path.join("output", file_name+".csv") + csv_data = [ + # TODO 先用UTC+8,后续可以考虑让用户自定义时区 + (datetime.fromtimestamp(order["timestamp"] / 1000, timezone.utc) + timedelta(hours=8)).strftime("%Y-%m-%dT%H:%M"), + "买入" if order["side"] == "buy" else "卖出", + csv_symbol if csv_symbol else order["symbol"], + order["amount"], + order["cost"], + f"ccxt - {exchange} - 订单ID: {order['id']}" + ] + if not os.path.exists(file_path): + with open(file_path, "w", newline="", encoding="utf-8") as f: + writer = csv.writer(f) + writer.writerow(["日期", "类型", "证券代码", "份额", "净额", "备注"]) + with open(file_path, "a", newline="", encoding="utf-8") as f: + writer = csv.writer(f) + writer.writerow(csv_data) + +def main(): + """主函数""" + + arg_parser = argparse.ArgumentParser() + arg_parser.add_argument("--config", "-c", help="配置文件路径", required=True) + args = arg_parser.parse_args() + config_file_path = args.config + config_file_name = os.path.splitext(os.path.basename(config_file_path))[0] + + os.makedirs("output", exist_ok=True) + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[ + logging.FileHandler(f"output/{config_file_name}.log", encoding="utf-8"), + logging.StreamHandler(), + ], + ) + logger = logging.getLogger("main") + + db = Database(config_file_name) + next_exec_time = time.time() + config = config_load(config_file_path) + + while True: + logger.info("开始检查") + try: + old_config = config + config = config_load(config_file_path) + if config != old_config: + logger.warning("配置文件已更新") + config_validate(config) + except FileNotFoundError as e: + logger.error("配置文件不存在: %s", e) + return + except ValueError as e: + logger.error("配置文件验证失败: %s", e) + return + db.sync_trades(config["trades"].keys()) + accounts = {} + for account_id, account_data in config["accounts"].items(): + accounts[account_id] = getattr(ccxt, account_data["exchange"])( + { + "apiKey": account_data["api_key"], + "secret": account_data["secret_key"], + } + ) + if account_data.get("demo", False): + accounts[account_id].enable_demo_trading(True) + + for trade_id, trade_data in config["trades"].items(): + if trade_data["account"] not in accounts: + logger.error("交易 %s 配置的账户 %s 在 accounts 中未找到", trade_id, trade_data["account"]) + continue + # 更新open订单情况 + for order_id in db.data[trade_id]["order_id"]["open"]: + order_detail = accounts[trade_data["account"]].fetch_order(order_id, trade_data["order"]["symbol"]) + if order_detail["status"] == "closed": + # 完全成交 + logger.info("%s: 订单成交", order_id) + csv_record(config_file_name, order_detail, accounts[trade_data["account"]].id, trade_data.get("csv_symbol", None)) + db.data[trade_id]["order_id"]["open"].remove(order_id) + db.data[trade_id]["order_id"]["filled"].append(order_id) + db.data[trade_id]["counter"]["open"] -= 1 + db.data[trade_id]["counter"]["filled"] += 1 + elif order_detail["status"] == "canceled": + # 可能是用户手动取消了,那么放弃追踪 + logger.info("%s: 订单被取消,放弃追踪", order_id) + db.data[trade_id]["order_id"]["open"].remove(order_id) + db.data[trade_id]["counter"]["open"] -= 1 + elif order_detail["status"] == "open": + logger.info("%s: 订单未(完全)成交", order_id) + else: + logger.warning("%s: 订单状态为 %s,无法识别", order_id, order_detail["status"]) + + # 判断是否要取消订单 + if db.data[trade_id]["counter"]["open"] > 0: + if eval(trade_data["condition_cancel"]): + logger.info("交易 %s 符合取消条件,正在取消订单", trade_id) + accounts[trade_data["account"]].cancel_orders(db.data[trade_id]["order_id"]["open"]) + for order_id in db.data[trade_id]["order_id"]["open"]: + db.data[trade_id]["order_id"]["open"].remove(order_id) + db.data[trade_id]["order_id"]["canceled"].append(order_id) + db.data[trade_id]["counter"]["open"] -= 1 + db.data[trade_id]["counter"]["canceled"] += 1 + + # 判断是否要下单 + if eval(trade_data["condition_place"]): + logger.info("交易 %s 符合下单条件,正在下单", trade_id) + if trade_data["order"]["type"] == "market" and not accounts[trade_data["account"]].has["createMarketOrder"]: + logger.error("账户 %s 的交易所 %s 不支持市价单", trade_data["account"], accounts[trade_data["account"]].id) + return + + price = accounts[trade_data["account"]].fetch_ticker(trade_data["order"]["symbol"])["close"] + if "amount" not in trade_data["order"]: + # 只有cost没有amount,根据价格算amount + trade_data["order"]["amount"] = trade_data["order"]["cost"] / price + if "price" not in trade_data["order"]: + # 没有price,则市价 + if trade_data["order"]["side"] == "buy": + trade_data["order"]["price"] = price * 1.02 + else: + trade_data["order"]["price"] = price * 0.98 + + order_detail = accounts[trade_data["account"]].create_order( + trade_data["order"]["symbol"], + trade_data["order"]["type"], + trade_data["order"]["side"], + trade_data["order"]["amount"], + trade_data["order"].get("price", None), + ) + if order_detail["status"] == "open": + db.data[trade_id]["order_id"]["open"].append(order_detail["id"]) + db.data[trade_id]["counter"]["open"] += 1 + elif order_detail["status"] == "closed": + logger.info("%s: 订单成交", order_detail["id"]) + csv_record(config_file_name, order_detail, accounts[trade_data["account"]].id, trade_data.get("csv_symbol", None)) + db.data[trade_id]["order_id"]["filled"].append(order_detail["id"]) + db.data[trade_id]["counter"]["filled"] += 1 + else: + logger.warning("%s: 订单状态为 %s,未加入追踪", order_detail["id"], order_detail["status"]) + + db.save() + next_exec_time += config["interval"] + time.sleep(max(0, next_exec_time - time.time())) + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..51eec27 --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +ccxt \ No newline at end of file