diff --git a/.env.example b/.env.example index ba69529..63b99f7 100644 --- a/.env.example +++ b/.env.example @@ -25,3 +25,7 @@ GODFORGE_ADMIN_PASSWORD=change_me_for_dashboard # Single-tenant routing — replace with your own Discord IDs before deploying GODFORGE_OWNER_USER_ID=0 # Discord user ID with elevated bot privileges (0 = disabled) GODFORGE_REPORTS_CHANNELS= # Comma-separated guild_id:channel_id pairs, e.g. 111:222,333:444 + +# Dashboard security — required for production +GODFORGE_ALLOWED_ORIGIN= # Origin allowed to call the API, e.g. https://your-app.railway.app (unset = reflects any origin) +GODFORGE_SESSION_SECRET= # Random secret for signing session tokens (falls back to insecure default if unset) diff --git a/bot.py b/bot.py index f25eaa0..bb88df8 100644 --- a/bot.py +++ b/bot.py @@ -71,6 +71,32 @@ _tracked_messages = {} _custom_command_cooldowns: dict[tuple[str, str, int], float] = {} +_ACTIVE_DRAFTS_FILE = os.path.join("data", "active_local_drafts.json") + + +def _load_active_drafts() -> dict: + try: + with open(_ACTIVE_DRAFTS_FILE, encoding="utf-8") as f: + return json.load(f) + except (FileNotFoundError, json.JSONDecodeError): + return {} + + +def _save_active_draft(channel_id: int, draft_id: str): + data = _load_active_drafts() + data[str(channel_id)] = draft_id + os.makedirs("data", exist_ok=True) + with open(_ACTIVE_DRAFTS_FILE, "w", encoding="utf-8") as f: + json.dump(data, f) + + +def _remove_active_draft(channel_id: int): + data = _load_active_drafts() + if str(channel_id) in data: + data.pop(str(channel_id)) + with open(_ACTIVE_DRAFTS_FILE, "w", encoding="utf-8") as f: + json.dump(data, f) + # Activity backend draft tracking (in-memory, resets on restart). _match_ids: dict[int, str] = {} # channel_id -> match_id _match_channels: dict[str, int] = {} # match_id -> channel_id @@ -255,6 +281,24 @@ async def on_ready(): cleanup_task.start() log.info("Economy commands are deprecated in GodForge; ForgeLens owns betting and ledgers.") + orphaned = _load_active_drafts() + if orphaned: + try: + os.remove(_ACTIVE_DRAFTS_FILE) + except OSError: + pass + for channel_id_str in orphaned: + ch = client.get_channel(int(channel_id_str)) + if ch: + try: + await ch.send( + "⚠️ GodForge restarted — the active draft was lost. " + "Please start a new one with `.draft start`." + ) + except (discord.Forbidden, discord.HTTPException): + pass + log.info(f"Notified {len(orphaned)} channel(s) of lost draft(s) after restart") + @tasks.loop(minutes=5) async def cleanup_task(): @@ -271,6 +315,12 @@ async def cleanup_task(): log.info(f"Cleaned up {len(expired_drafts)} expired local draft(s)") +@cleanup_task.error +async def cleanup_task_error(exc: Exception): + log.error(f"cleanup_task crashed: {exc!r} — restarting") + cleanup_task.restart() + + @client.event async def on_message(message: discord.Message): if message.author == client.user or message.author.bot: @@ -566,6 +616,7 @@ async def _handle_draft_local(intent: dict, message: discord.Message): embed = formatter.format_draft_board(draft) sent = await message.channel.send(embed=embed) draft.board_message_id = sent.id + _save_active_draft(channel_id, draft.draft_id) log.info(f"Draft {draft.draft_id} started in channel {channel_id}: " f"🔵 {blue_user.display_name} vs 🔴 {red_user.display_name}") return None @@ -616,6 +667,7 @@ async def _handle_draft_local(intent: dict, message: discord.Message): log.info(f"Draft {draft.draft_id} report posted to reports channel") except (discord.Forbidden, discord.HTTPException) as e: log.warning(f"Failed to post to reports channel: {e}") + _remove_active_draft(channel_id) log.info(f"Draft {draft.draft_id} ended: {len(export['games'])} game(s)") return None diff --git a/requirements.txt b/requirements.txt index cf7ac28..3916188 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ -discord.py>=2.3.0 -python-dotenv>=1.0.0 -aiohttp>=3.9.0 +discord.py==2.7.1 +python-dotenv==1.2.2 +aiohttp==3.13.5 + diff --git a/tests/integration/test_web_api_admin_bridge.py b/tests/integration/test_web_api_admin_bridge.py index d6ae7c8..6040e05 100644 --- a/tests/integration/test_web_api_admin_bridge.py +++ b/tests/integration/test_web_api_admin_bridge.py @@ -12,6 +12,7 @@ import urllib.error import urllib.parse import urllib.request +from collections import namedtuple from utils import ledger as ledger_utils from utils import wallet as wallet_utils @@ -19,6 +20,8 @@ from utils import dashboard_store from web_api import server as web_server +_Session = namedtuple("_Session", ["cookie", "csrf"]) + def _start_server(): httpd = web_server.create_server("127.0.0.1", 0) @@ -42,7 +45,10 @@ def _request(method, url, payload=None, cookie=None, follow_redirects=True): request = urllib.request.Request(url, data=data, method=method) if payload is not None: request.add_header("Content-Type", "application/json") - if cookie: + if isinstance(cookie, _Session): + request.add_header("Cookie", cookie.cookie) + request.add_header("X-CSRF-Token", cookie.csrf) + elif cookie: request.add_header("Cookie", cookie) opener = urllib.request.build_opener() if follow_redirects else urllib.request.build_opener(_NoRedirect) @@ -58,6 +64,17 @@ def _request(method, url, payload=None, cookie=None, follow_redirects=True): return exc.code, parsed, exc.headers +def _parse_set_cookies(headers) -> dict[str, str]: + """Extract all Set-Cookie values keyed by cookie name.""" + cookies = {} + for value in headers.get_all("Set-Cookie") or []: + name_val = value.split(";", 1)[0].strip() + if "=" in name_val: + name, val = name_val.split("=", 1) + cookies[name.strip()] = val.strip() + return cookies + + def _login(base): status, payload, headers = _request( "POST", @@ -66,7 +83,13 @@ def _login(base): ) assert status == 200 assert payload["authenticated"] is True - return headers["Set-Cookie"].split(";", 1)[0] + cookies = _parse_set_cookies(headers) + session_val = cookies[web_server.SESSION_COOKIE] + csrf_val = cookies[web_server.CSRF_COOKIE] + return _Session( + cookie=f"{web_server.SESSION_COOKIE}={session_val}; {web_server.CSRF_COOKIE}={csrf_val}", + csrf=csrf_val, + ) def test_public_health_and_tool_endpoints_do_not_require_auth(monkeypatch, tmp_ledger, tmp_wallets): diff --git a/web_api/server.py b/web_api/server.py index 39c07f2..4de2ce8 100644 --- a/web_api/server.py +++ b/web_api/server.py @@ -17,6 +17,7 @@ import mimetypes import os import re +import secrets import sys import time from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer @@ -41,6 +42,7 @@ } WEB_ROOT = ROOT / "web" SESSION_COOKIE = "godforge_admin" +CSRF_COOKIE = "godforge_csrf" OAUTH_STATE_COOKIE = "godforge_oauth_state" SESSION_MAX_AGE = 60 * 60 * 12 DISCORD_API_BASE = "https://discord.com/api/" @@ -267,12 +269,15 @@ def do_POST(self): try: body = self._read_body() + if parsed.path in PROTECTED_POST_PATHS: + if not self._require_auth(): + return + if not self._require_csrf(): + return if parsed.path == "/api/auth/login": self._login(body) elif parsed.path == "/api/auth/logout": self._logout() - elif parsed.path in PROTECTED_POST_PATHS and not self._require_auth(): - return elif parsed.path == "/api/command": self._send_json({"ok": True, "result": _execute_intent(body.get("message", ""))}) elif parsed.path == "/api/commands/custom": @@ -461,6 +466,7 @@ def _login(self, body: dict): return token = _sign_session(int(time.time()) + SESSION_MAX_AGE) + csrf_token = secrets.token_hex(32) payload = {"ok": True, "authenticated": True} data = json.dumps(payload).encode("utf-8") self.send_response(200) @@ -468,6 +474,7 @@ def _login(self, body: dict): self.send_header("Access-Control-Allow-Origin", _allowed_origin(self.headers.get("Origin", ""))) self.send_header("Access-Control-Allow-Credentials", "true") self.send_header("Set-Cookie", _cookie_header(token)) + self.send_header("Set-Cookie", _csrf_cookie_header(csrf_token)) self.send_header("Content-Length", str(len(data))) self.end_headers() self.wfile.write(data) @@ -524,8 +531,10 @@ def _discord_oauth_callback(self, query: dict): session = _sign_session(int(time.time()) + SESSION_MAX_AGE) self.send_response(302) + csrf_token = secrets.token_hex(32) self.send_header("Location", "/#dashboard?auth=discord") self.send_header("Set-Cookie", _cookie_header(session)) + self.send_header("Set-Cookie", _csrf_cookie_header(csrf_token)) self.send_header("Set-Cookie", f"{OAUTH_STATE_COOKIE}=; Path=/; Max-Age=0; HttpOnly; SameSite=Lax") self.send_header("Content-Length", "0") self.end_headers() @@ -546,6 +555,7 @@ def _logout(self): self.send_header("Access-Control-Allow-Origin", _allowed_origin(self.headers.get("Origin", ""))) self.send_header("Access-Control-Allow-Credentials", "true") self.send_header("Set-Cookie", f"{SESSION_COOKIE}=; Path=/; Max-Age=0; HttpOnly; SameSite=Lax") + self.send_header("Set-Cookie", f"{CSRF_COOKIE}=; Path=/; Max-Age=0; SameSite=Strict") self.send_header("Content-Length", str(len(data))) self.end_headers() self.wfile.write(data) @@ -559,6 +569,14 @@ def _require_auth(self) -> bool: self._send_error(401, "Admin login required.") return False + def _require_csrf(self) -> bool: + header_token = self.headers.get("X-CSRF-Token", "") + cookie_token = _cookies(self.headers.get("Cookie", "")).get(CSRF_COOKIE, "") + if header_token and cookie_token and hmac.compare_digest(header_token, cookie_token): + return True + self._send_error(403, "CSRF validation failed.") + return False + def log_message(self, format, *args): # noqa: A002 print(f"{self.address_string()} - {format % args}") @@ -1027,6 +1045,10 @@ def _cookie_header(token: str, name: str = SESSION_COOKIE, max_age: int = SESSIO return f"{name}={token}; Path=/; Max-Age={max_age}; HttpOnly; SameSite=Lax" +def _csrf_cookie_header(token: str) -> str: + return f"{CSRF_COOKIE}={token}; Path=/; Max-Age={SESSION_MAX_AGE}; SameSite=Strict" + + def _cookies(raw_cookie: str) -> dict[str, str]: cookies = {} for part in raw_cookie.split(";"):