81fbe028e8
- Load known-bad IPs from FireHOL blocklists on startup
- ~4400 IPs blocked by default
- Set PUBLIC_BLOCKLIST=false to disable
- Combined with the manual BLOCKED_IPS env var
177 lines
5.9 KiB
Python
177 lines
5.9 KiB
Python
# middleware.py
|
|
import logging
import re
import time
import traceback
import uuid
from collections import defaultdict
from typing import Optional
from urllib.parse import unquote

import httpx
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
logger = logging.getLogger(__name__)

# Public IPv4 blocklists (FireHOL ``.netset`` files) that load_public_blocklists()
# downloads one by one and merges into a single set.
# NOTE(review): confirm "iblocklist_isp.netset" exists upstream and is a threat
# list; if it is an ISP address-range list, it may block legitimate users.
BLOCKLIST_URLS = [
    "https://raw.githubusercontent.com/firehol/blocklist-ipsets/master/firehol_level1.netset",
    "https://raw.githubusercontent.com/firehol/blocklist-ipsets/master/iblocklist_isp.netset",
]
|
|
|
|
|
|
# One IPv4 entry per line, optionally with a CIDR suffix ("1.2.3.4" or "1.2.3.0/24").
# Compiled once at import time instead of being re-resolved for every line.
_BLOCKLIST_ENTRY_RE = re.compile(r"^\d+\.\d+\.\d+\.\d+(/\d+)?$")


def _parse_blocklist(text: str) -> set[str]:
    """Return the IPv4 address strings listed in a netset-style blocklist body.

    Blank lines and ``#`` comment lines are skipped, as is anything that is not
    a dotted-quad address with an optional ``/prefix``. For CIDR entries only
    the part before the ``/`` is kept, so ``1.2.3.0/24`` yields the single
    string ``"1.2.3.0"``.

    NOTE(review): the rest of each CIDR range is dropped, so exact-match
    lookups against the result do not cover those other addresses.
    """
    found: set[str] = set()
    for raw_line in text.splitlines():
        entry = raw_line.strip()
        if not entry or entry.startswith("#"):
            continue
        if _BLOCKLIST_ENTRY_RE.match(entry):
            found.add(entry.split("/", 1)[0])
    return found


def load_blocklist_from_url(url: str, timeout: int = 10) -> set[str]:
    """Download and parse an IPv4 blocklist from *url*.

    Best-effort: a network error, a timeout, or a non-200 response is logged
    as a warning and yields an empty set. This way an unreachable list never
    aborts startup.

    Args:
        url: Location of the blocklist file; redirects are followed.
        timeout: Request timeout, passed straight to ``httpx.get``.

    Returns:
        The set of address strings found in the list (empty on failure).
    """
    try:
        response = httpx.get(url, timeout=timeout, follow_redirects=True)
        body = response.text
    except Exception as e:
        logger.warning(f"Failed to load blocklist from {url}: {e}")
        return set()

    # Report HTTP errors (e.g. a 404 for a renamed list) instead of quietly
    # loading nothing.
    if response.status_code != 200:
        logger.warning(f"Failed to load blocklist from {url}: HTTP {response.status_code}")
        return set()

    ips = _parse_blocklist(body)
    logger.info(f"Loaded {len(ips)} IPs from blocklist: {url}")
    return ips
|
|
|
|
|
|
def load_public_blocklists() -> set[str]:
    """Fetch every list in ``BLOCKLIST_URLS`` and return the union of their entries.

    URLs are fetched one after another through ``load_blocklist_from_url``.
    The combined size is logged once all downloads have finished.
    """
    merged: set[str] = set().union(
        *(load_blocklist_from_url(source) for source in BLOCKLIST_URLS)
    )
    logger.info(f"Total blocked IPs from public lists: {len(merged)}")
    return merged
|
|
|
|
|
|
# Rate limiting config: check_rate_limit() accepts at most RATE_LIMIT_REQUESTS
# requests per client IP within any sliding window of RATE_LIMIT_WINDOW seconds.
RATE_LIMIT_REQUESTS = 60  # Max requests per window
RATE_LIMIT_WINDOW = 60  # Window in seconds

# Accepted-request timestamps, keyed by client IP string.
_ip_request_counts: dict[str, list[float]] = defaultdict(list)

# IP blocking config (set from main.py via set_ip_config).
BLOCKED_IPS: set[str] = set()
|
|
|
|
# Paths that vulnerability scanners commonly probe. Entries are written without
# a leading slash, and some use mixed case.
SUSPICIOUS_PATHS = {
    # Leaked dotenv files and their backups / editor leftovers
    ".env", ".env.local", ".env.production", ".env.development",
    ".env.bak", ".env.old", ".env.backup", ".env.orig", ".env.save",
    ".env~", ".env.swp", ".env.copy", ".env.1", ".ENV",
    # .NET / IIS configuration
    "appsettings.json", "appsettings.Development.json",
    "appsettings.Production.json", "appsettings.Staging.json", "web.config",
    # PHP info / test pages
    "phpinfo.php", "info.php", "test.php", "i.php", "phpi.php", "php.php",
    "phptest.php", "server-info.php", "phpinformation.php", "infophp.php",
    "php_info.php", "config.php",
    # Spring Boot actuator and similar env dumps
    "actuator", "actuator/env", "actuator/configprops",
    "actuator/env/aws", "actuator/env/cloud",
    "manage/env", "admin/env", "env",
    # SharePoint
    "_layouts/15/", "_layouts/15/ToolPane.aspx",
    # API documentation endpoints
    "swagger-ui", "api/docs", "openapi.json",
    # WordPress and admin panels
    "wp-admin", "wp-login.php", "wordpress", "administrator", "phpmyadmin",
    # Version-control metadata
    ".git", ".svn", ".hg",
}
|
|
|
|
def get_client_ip(request: Request) -> str:
    """Return the best-guess client IP address for *request*.

    If an ``X-Forwarded-For`` header is present and non-empty, its first
    comma-separated entry (whitespace-trimmed) wins. Otherwise the socket peer
    address is used, or ``"unknown"`` when the transport reports no client.

    NOTE(review): the first X-Forwarded-For entry is whatever the caller sent,
    unless a trusted proxy overwrites the header. A forged value would sidestep
    the IP blocklist and get a fresh rate-limit bucket. Confirm the deployment's
    proxy setup before relying on this for enforcement.
    """
    forwarded_for = request.headers.get("x-forwarded-for")
    if forwarded_for:
        first_hop, _sep, _rest = forwarded_for.partition(",")
        return first_hop.strip()
    peer = request.client
    return peer.host if peer else "unknown"
|
|
|
|
|
|
def is_ip_blocked(client_ip: str) -> bool:
    """Tell whether *client_ip* is on the configured blocklist.

    This is a plain set-membership test on the exact string. An entry written
    in CIDR form (``"1.2.3.0/24"``) therefore only matches that literal text.
    The module-level ``BLOCKED_IPS`` is looked up on every call, so a set
    installed later via ``set_ip_config`` takes effect immediately.
    """
    blocklist = BLOCKED_IPS
    return client_ip in blocklist
|
|
|
|
|
|
# Monotonic time of the last sweep of idle clients out of _ip_request_counts.
_last_rate_limit_sweep = 0.0


def _evict_idle_clients(now: float) -> None:
    """Drop clients with no accepted request inside the window (at most once per window).

    Without this, an IP seen once keeps its entry forever, so the table grows
    with every distinct client address.
    """
    global _last_rate_limit_sweep
    if now - _last_rate_limit_sweep < RATE_LIMIT_WINDOW:
        return
    _last_rate_limit_sweep = now
    # Timestamps are appended in increasing order, so the last one is the newest.
    idle = [
        ip
        for ip, stamps in _ip_request_counts.items()
        if not stamps or now - stamps[-1] >= RATE_LIMIT_WINDOW
    ]
    for ip in idle:
        del _ip_request_counts[ip]


def check_rate_limit(client_ip: str) -> bool:
    """Record a request from *client_ip* if it is within the rate limit.

    The limit is a sliding window: the request is accepted when fewer than
    ``RATE_LIMIT_REQUESTS`` accepted requests from the same IP fall within the
    last ``RATE_LIMIT_WINDOW`` seconds. Rejected requests are not recorded.

    Returns:
        True if the request is allowed (and recorded), False if it is rate limited.
    """
    # Monotonic clock: wall-clock jumps (NTP, DST) must not stretch or shrink the window.
    now = time.monotonic()
    _evict_idle_clients(now)

    recent = [t for t in _ip_request_counts[client_ip] if now - t < RATE_LIMIT_WINDOW]
    _ip_request_counts[client_ip] = recent

    if len(recent) >= RATE_LIMIT_REQUESTS:
        return False

    recent.append(now)
    return True
|
|
|
|
|
|
def is_suspicious_path(path: str) -> bool:
    """Return True when *path* looks like a vulnerability-scanner probe.

    Three checks, in order:

    1. Exact match against ``SUSPICIOUS_PATHS``. Request paths begin with "/"
       while the set entries do not, and some entries are mixed-case. Both
       sides are therefore normalized (leading "/" removed, lowercased) before
       comparing, and a trailing "/" on the request is tolerated.
    2. Substring match against a short list of scanner markers.
    3. Path traversal: ".." after up to two rounds of percent-decoding. This
       also catches "%2e%2e" and double-encoded "%252e%252e".
    """
    path_lower = path.lower()

    # 1. Exact match on the normalized path.
    candidate = path_lower.lstrip("/")
    known = {entry.lower() for entry in SUSPICIOUS_PATHS}
    if candidate in known or candidate.rstrip("/") in known:
        return True

    # 2. Marker substrings anywhere in the path.
    suspicious_patterns = (
        ".env", "phpinfo", "actuator", "wp-", "phpmyadmin",
        ".git", ".svn", "swagger", "openapi",
    )
    if any(marker in path_lower for marker in suspicious_patterns):
        return True

    # 3. Path traversal, decoding percent-escapes (twice, for double encoding)
    #    rather than deleting them.
    decoded = unquote(unquote(path_lower))
    return ".." in decoded
|
|
|
|
|
|
def set_ip_config(blocked: Optional[set[str]] = None):
    """Install *blocked* as the module-wide IP blocklist (call from main.py).

    The given set is stored by reference, not copied. Passing ``None`` (the
    default) leaves the current blocklist untouched.
    """
    global BLOCKED_IPS
    if blocked is None:
        return
    BLOCKED_IPS = blocked
|
|
|
|
|
|
class LoggingMiddleware(BaseHTTPMiddleware):
    """Screen incoming requests, then log the ones that reach the application.

    Screening order (each step short-circuits without calling the app):

    1. Client IP is in the blocklist -> empty 404, nothing logged.
    2. Client IP is over the rate limit -> 429 with a ``Retry-After`` header,
       logged as a warning.
    3. Path looks like a scanner probe -> empty 404, nothing logged.

    Requests that pass are logged on the way in and on the way out (method,
    path, status, duration) under a short random request id. That id is echoed
    back in the ``X-Request-ID`` response header. Exceptions raised by the
    application are logged with their traceback and re-raised unchanged.
    """

    async def dispatch(self, request: Request, call_next):
        # Short correlation id: the first 8 hex digits of a random UUID.
        request_id = uuid.uuid4().hex[:8]
        client_ip = get_client_ip(request)
        path = request.url.path

        if is_ip_blocked(client_ip):
            return Response(status_code=404, content="")

        if not check_rate_limit(client_ip):
            logger.warning("Rate limited: %s (%s)", client_ip, path)
            # The window length is an upper bound on how long the client must wait.
            return Response(
                status_code=429,
                content="Too many requests",
                headers={"Retry-After": str(RATE_LIMIT_WINDOW)},
            )

        # Silent 404 for scanner probes: no log line, and nothing that confirms
        # whether the path exists.
        if is_suspicious_path(path):
            return Response(status_code=404, content="")

        method = request.method
        # perf_counter is monotonic and high-resolution, which suits durations.
        start = time.perf_counter()
        logger.info("→ %s %s (IP: %s, ID: %s)", method, path, client_ip, request_id)

        try:
            response = await call_next(request)
        except Exception as exc:
            duration_ms = (time.perf_counter() - start) * 1000
            # logger.exception appends the active traceback to the message.
            logger.exception(
                "✗ %s %s → ERROR: %s (%.0fms) (ID: %s)",
                method, path, exc, duration_ms, request_id,
            )
            raise

        duration_ms = (time.perf_counter() - start) * 1000
        logger.info(
            "← %s %s → %s (%.0fms) [ID: %s]",
            method, path, response.status_code, duration_ms, request_id,
        )
        response.headers["X-Request-ID"] = request_id
        return response