fix: per-inbound propagate check for new inbounds on existing servers

This commit is contained in:
SashegDev
2026-05-22 05:45:38 +00:00
parent 4640262a14
commit dc92b75387
+45 -45
View File
@@ -1004,28 +1004,8 @@ def delete_3xui_client(username: str, sub_id: str, inbound: dict) -> dict:
except Exception as e: except Exception as e:
return {"success": False, "error": str(e)[:100]} return {"success": False, "error": str(e)[:100]}
def user_exists_on_server(sub_id: str, srv: dict) -> bool: def fetch_server_sub_ids(srv: dict) -> dict:
for inbound in srv.get("inbounds", []): result = {}
api_host = inbound.get("api_host")
api_user = inbound.get("api_user")
api_pass = inbound.get("api_pass")
inbound_id = inbound.get("id")
if not all([api_host, api_user, api_pass, inbound_id]):
continue
try:
api = Api(host=api_host, username=api_user, password=api_pass, use_tls_verify=False)
api.login()
for ib in api.inbound.get_list():
if ib.id == inbound_id and ib.client_stats:
for client in ib.client_stats:
if getattr(client, 'sub_id', '') == sub_id:
return True
except:
pass
return False
def fetch_server_sub_ids(srv: dict) -> set:
ids = set()
srv_name = srv.get("name", "?") srv_name = srv.get("name", "?")
for inbound in srv.get("inbounds", []): for inbound in srv.get("inbounds", []):
api_host = inbound.get("api_host") api_host = inbound.get("api_host")
@@ -1040,26 +1020,26 @@ def fetch_server_sub_ids(srv: dict) -> set:
def _fetch(): def _fetch():
api = Api(host=api_host, username=api_user, password=api_pass, use_tls_verify=False) api = Api(host=api_host, username=api_user, password=api_pass, use_tls_verify=False)
api.login() api.login()
result = set() ids = set()
for ib in api.inbound.get_list(): for ib in api.inbound.get_list():
if ib.id == inbound_id and ib.client_stats: if ib.id == inbound_id and ib.client_stats:
for client in ib.client_stats: for client in ib.client_stats:
sid = getattr(client, 'sub_id', '') or '' sid = getattr(client, 'sub_id', '') or ''
if sid: if sid:
result.add(sid) ids.add(sid)
return result return ids
with ThreadPoolExecutor(max_workers=1) as ex: with ThreadPoolExecutor(max_workers=1) as ex:
future = ex.submit(_fetch) future = ex.submit(_fetch)
try: try:
fetched = future.result(timeout=20) fetched = future.result(timeout=20)
ids.update(fetched) result[inbound_id] = fetched
logger.info(f"fetch_server_sub_ids {srv_name}/{ib_name}: got {len(fetched)} sub_ids") logger.info(f"fetch_server_sub_ids {srv_name}/{ib_name}: got {len(fetched)} sub_ids")
except FutureTimeout: except FutureTimeout:
logger.warning(f"Timeout fetching clients from {srv_name}/{ib_name}") logger.warning(f"Timeout fetching clients from {srv_name}/{ib_name}")
except Exception as e: except Exception as e:
logger.warning(f"fetch_server_sub_ids {srv_name}/{ib_name} error: {e}") logger.warning(f"fetch_server_sub_ids {srv_name}/{ib_name} error: {e}")
logger.info(f"fetch_server_sub_ids {srv_name}: total {len(ids)} sub_ids") logger.info(f"fetch_server_sub_ids {srv_name}: inbound count {len(result)}")
return ids return result
def propagate_server_sync(server_name: str) -> dict: def propagate_server_sync(server_name: str) -> dict:
target_srv = next((s for s in servers if s["name"] == server_name), None) target_srv = next((s for s in servers if s["name"] == server_name), None)
@@ -1072,38 +1052,58 @@ def propagate_server_sync(server_name: str) -> dict:
conn.close() conn.close()
other_servers = [s for s in servers if s.get("is_active", True) and s["name"] != server_name] other_servers = [s for s in servers if s.get("is_active", True) and s["name"] != server_name]
threshold = max(1, len(other_servers) // 2 + 1) threshold = max(1, len(other_servers) // 2 + 1)
target_ids = fetch_server_sub_ids(target_srv)
target_ib_ids = fetch_server_sub_ids(target_srv)
other_ids = {} other_ids = {}
for s in other_servers: for s in other_servers:
other_ids[s["name"]] = fetch_server_sub_ids(s) ib_dict = fetch_server_sub_ids(s)
flat = set()
for ids in ib_dict.values():
flat.update(ids)
other_ids[s["name"]] = flat
total = len(users) total = len(users)
already = added = skipped = failed = 0 added = skipped = failed = 0
inbound_stats = {}
for ib in target_srv.get("inbounds", []):
inbound_stats[ib.get("name", f"id_{ib.get('id','?')}")] = 0
results = [] results = []
for u in users: for u in users:
sub_id = u["subscription_id"] sub_id = u["subscription_id"]
username = u["username"] username = u["username"]
if sub_id in target_ids:
already += 1
continue
count = sum(1 for s in other_servers if sub_id in other_ids.get(s["name"], set())) count = sum(1 for s in other_servers if sub_id in other_ids.get(s["name"], set()))
if count >= threshold: if count < threshold:
ok = True skipped += 1
results.append({"username": username, "skipped": True, "servers": count, "total": len(other_servers)})
continue
user_added = False
for inbound in target_srv.get("inbounds", []): for inbound in target_srv.get("inbounds", []):
ib_id = inbound.get("id")
ib_name = inbound.get("name", "?")
existing = target_ib_ids.get(ib_id, set())
if sub_id in existing:
continue
r = create_3xui_client(username, sub_id, inbound, u["traffic_limit_gb"] or 0) r = create_3xui_client(username, sub_id, inbound, u["traffic_limit_gb"] or 0)
if not r.get("success"): if r.get("success"):
ok = False user_added = True
if ok: inbound_stats[ib_name] = inbound_stats.get(ib_name, 0) + 1
added += 1
clear_cache(sub_id) clear_cache(sub_id)
else: else:
failed += 1 failed += 1
results.append({"username": username, "added": ok}) results.append({"username": username, "inbound": ib_name, "error": r.get("error", "")})
else:
skipped += 1 if user_added:
results.append({"username": username, "skipped": True, "servers": count, "total": len(other_servers)}) added += 1
results.append({"username": username, "added": True})
already = sum(1 for u in users if u["subscription_id"] in {sid for ids in target_ib_ids.values() for sid in ids})
return {"server": server_name, "total": total, "already_on_server": already, return {"server": server_name, "total": total, "already_on_server": already,
"added": added, "skipped": skipped, "failed": failed, "added": added, "skipped": skipped, "failed": failed,
"threshold": f"{threshold}/{len(other_servers)}", "results": results} "threshold": f"{threshold}/{len(other_servers)}",
"per_inbound_added": inbound_stats, "results": results}
@app.post("/admin/api/propagate/{server_name}") @app.post("/admin/api/propagate/{server_name}")
async def propagate_server(request: Request, server_name: str): async def propagate_server(request: Request, server_name: str):