\n",
"\n",
"## Every OpenEnv Environment Has 3 Components:\n",
"\n",
@@ -454,7 +435,7 @@
"│ (Action, Observation, State)\n",
"│\n",
"├── 📱 client.py ← What YOU import\n",
- "│ (HTTPEnvClient implementation)\n",
+ "│ (EnvClient implementation)\n",
"│\n",
"└── 🖥️ server/\n",
" ├── environment.py ← Game/simulation logic\n",
@@ -469,56 +450,9 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "======================================================================\n",
- " 🧩 OPENENV CORE ABSTRACTIONS\n",
- "======================================================================\n",
- "\n",
- "🖥️ SERVER SIDE (runs in Docker):\n",
- "\n",
- " class Environment(ABC):\n",
- " '''Base class for all environment implementations'''\n",
- "\n",
- " @abstractmethod\n",
- " def reset(self) -> Observation:\n",
- " '''Start new episode'''\n",
- "\n",
- " @abstractmethod\n",
- " def step(self, action: Action) -> Observation:\n",
- " '''Execute action, return observation'''\n",
- "\n",
- " @property\n",
- " def state(self) -> State:\n",
- " '''Get episode metadata'''\n",
- "\n",
- "📱 CLIENT SIDE (your training code):\n",
- "\n",
- " class EnvClient(ABC):\n",
- " '''Base class for HTTP clients'''\n",
- "\n",
- " def reset(self) -> StepResult:\n",
- " # HTTP POST /reset\n",
- "\n",
- " def step(self, action) -> StepResult:\n",
- " # HTTP POST /step\n",
- "\n",
- " def state(self) -> State:\n",
- " # HTTP GET /state\n",
- "\n",
- "======================================================================\n",
- "\n",
- "✨ Same interface on both sides - communication via HTTP!\n",
- "🎯 You focus on RL, OpenEnv handles the infrastructure.\n",
- "\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"# Import OpenEnv's core abstractions\n",
"from openenv.core.env_server import Environment, Action, Observation, State\n",
@@ -549,20 +483,20 @@
"📱 CLIENT SIDE (your training code):\n",
"\n",
" class EnvClient(ABC):\n",
- " '''Base class for HTTP clients'''\n",
+ " '''Async base class for OpenEnv clients'''\n",
" \n",
- " def reset(self) -> StepResult:\n",
- " # HTTP POST /reset\n",
+ " async def reset(self) -> StepResult:\n",
+ " # await client.reset()\n",
" \n",
- " def step(self, action) -> StepResult:\n",
- " # HTTP POST /step\n",
+ " async def step(self, action) -> StepResult:\n",
+ " # await client.step(action)\n",
" \n",
- " def state(self) -> State:\n",
- " # HTTP GET /state\n",
+ " async def state(self) -> State:\n",
+ " # await client.state()\n",
"\"\"\")\n",
"\n",
"print(\"=\"*70)\n",
- "print(\"\\n✨ Same interface on both sides - communication via HTTP!\")\n",
+ "print(\"\\n✨ Same interface on both sides - communication over WebSocket!\")\n",
"print(\"🎯 You focus on RL, OpenEnv handles the infrastructure.\\n\")"
]
},
@@ -574,7 +508,7 @@
"\n",
"# Part 5: Example Integration - OpenSpiel 🎮\n",
"\n",
- "
\n",
+ "
\n",
"\n",
"## What is OpenSpiel?\n",
"\n",
@@ -612,58 +546,27 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "======================================================================\n",
- " 🔌 HOW OPENENV WRAPS OPENSPIEL\n",
- "======================================================================\n",
- "\n",
- "class OpenSpielEnv(HTTPEnvClient[OpenSpielAction, OpenSpielObservation]):\n",
- "\n",
- " def _step_payload(self, action: OpenSpielAction) -> dict:\n",
- " '''Convert typed action to JSON for HTTP'''\n",
- " return {\n",
- " \"action_id\": action.action_id,\n",
- " \"game_name\": action.game_name,\n",
- " }\n",
- "\n",
- " def _parse_result(self, payload: dict) -> StepResult:\n",
- " '''Parse HTTP JSON response into typed observation'''\n",
- " return StepResult(\n",
- " observation=OpenSpielObservation(...),\n",
- " reward=payload['reward'],\n",
- " done=payload['done']\n",
- " )\n",
- "\n",
- "\n",
- "──────────────────────────────────────────────────────────────────────\n",
- "\n",
- "✨ Usage (works for ALL OpenEnv environments):\n",
- "\n",
- " env = OpenSpielEnv(base_url=\"http://localhost:8000\")\n",
- "\n",
- " result = env.reset()\n",
- " # Returns StepResult[OpenSpielObservation] - Type safe!\n",
- "\n",
- " result = env.step(OpenSpielAction(action_id=2, game_name=\"catch\"))\n",
- " # Type checker knows this is valid!\n",
- "\n",
- " state = env.state()\n",
- " # Returns OpenSpielState\n",
- "\n",
- "──────────────────────────────────────────────────────────────────────\n",
- "\n",
- "🎯 This pattern works for ANY environment you want to wrap!\n",
- "\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
+ "if \"ensure_openenv_paths\" not in globals():\n",
+ " import sys\n",
+ " from pathlib import Path\n",
+ "\n",
+ " def ensure_openenv_paths(start=None):\n",
+ " start = Path(start or Path.cwd()).resolve()\n",
+ " for directory in (start, *start.parents):\n",
+ " if (directory / \"src\" / \"openenv\").is_dir() and (directory / \"envs\").is_dir():\n",
+ " for path in (directory / \"src\", directory):\n",
+ " path_str = str(path)\n",
+ " if path_str not in sys.path:\n",
+ " sys.path.insert(0, path_str)\n",
+ " return directory\n",
+ " raise RuntimeError(\"Run the Part 3 setup cell first.\")\n",
+ "\n",
+ "OPENENV_ROOT = ensure_openenv_paths()\n",
+ "\n",
"from envs.openspiel_env.client import OpenSpielEnv\n",
"\n",
"print(\"=\"*70)\n",
@@ -671,22 +574,26 @@
"print(\"=\"*70)\n",
"\n",
"print(\"\"\"\n",
- "class OpenSpielEnv(HTTPEnvClient[OpenSpielAction, OpenSpielObservation]):\n",
+ "class OpenSpielEnv(EnvClient[OpenSpielAction, OpenSpielObservation, OpenSpielState]):\n",
" \n",
" def _step_payload(self, action: OpenSpielAction) -> dict:\n",
- " '''Convert typed action to JSON for HTTP'''\n",
+ " '''Convert typed action to JSON for the step message'''\n",
" return {\n",
" \"action_id\": action.action_id,\n",
" \"game_name\": action.game_name,\n",
" }\n",
" \n",
" def _parse_result(self, payload: dict) -> StepResult:\n",
- " '''Parse HTTP JSON response into typed observation'''\n",
+ " '''Parse the server response into a typed observation'''\n",
" return StepResult(\n",
" observation=OpenSpielObservation(...),\n",
" reward=payload['reward'],\n",
" done=payload['done']\n",
" )\n",
+ " \n",
+ " def _parse_state(self, payload: dict) -> OpenSpielState:\n",
+ " '''Parse the server response into a typed state'''\n",
+ " return OpenSpielState(...)\n",
"\n",
"\"\"\")\n",
"\n",
@@ -695,13 +602,13 @@
"print(\"\"\"\n",
" env = OpenSpielEnv(base_url=\"http://localhost:8000\")\n",
" \n",
- " result = env.reset()\n",
+ " result = await env.reset()\n",
" # Returns StepResult[OpenSpielObservation] - Type safe!\n",
" \n",
- " result = env.step(OpenSpielAction(action_id=2, game_name=\"catch\"))\n",
+ " result = await env.step(OpenSpielAction(action_id=2, game_name=\"catch\"))\n",
" # Type checker knows this is valid!\n",
" \n",
- " state = env.state()\n",
+ " state = await env.state()\n",
" # Returns OpenSpielState\n",
"\"\"\")\n",
"\n",
@@ -711,56 +618,9 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "======================================================================\n",
- " 🎮 OPENSPIEL INTEGRATION - TYPE-SAFE MODELS\n",
- "======================================================================\n",
- "\n",
- "📤 OpenSpielAction (what you send):\n",
- " ────────────────────────────────────────────────────────────────\n",
- " • metadata : typing.Dict[str, typing.Any]\n",
- " • action_id :
\n",
- " • game_name : \n",
- " • game_params : typing.Dict[str, typing.Any]\n",
- "\n",
- "📥 OpenSpielObservation (what you receive):\n",
- " ────────────────────────────────────────────────────────────────\n",
- " • done : \n",
- " • reward : bool | int | float | None\n",
- " • metadata : typing.Dict[str, typing.Any]\n",
- " • info_state : typing.List[float]\n",
- " • legal_actions : typing.List[int]\n",
- " • game_phase : \n",
- " • current_player_id : \n",
- " • opponent_last_action : typing.Optional[int]\n",
- "\n",
- "📊 OpenSpielState (episode metadata):\n",
- " ────────────────────────────────────────────────────────────────\n",
- " • episode_id : typing.Optional[str]\n",
- " • step_count : \n",
- " • game_name : \n",
- " • agent_player : \n",
- " • opponent_policy : \n",
- " • game_params : typing.Dict[str, typing.Any]\n",
- " • num_players : \n",
- "\n",
- "======================================================================\n",
- "\n",
- "💡 Type safety means:\n",
- " ✅ Your IDE autocompletes these fields\n",
- " ✅ Typos are caught before running\n",
- " ✅ Refactoring is safe\n",
- " ✅ Self-documenting code\n",
- "\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"from envs.openspiel_env.models import (\n",
" OpenSpielAction,\n",
@@ -803,13 +663,13 @@
"\n",
"\n",
"\n",
- "The client **inherits from HTTPEnvClient** and implements 3 methods:\n",
+ "The client **inherits from `EnvClient`** and implements 3 methods:\n",
"\n",
"1. `_step_payload()` - Convert action → JSON\n",
"2. `_parse_result()` - Parse JSON → typed observation \n",
"3. `_parse_state()` - Parse JSON → state\n",
"\n",
- "That's it! The base class handles all HTTP communication.\n",
+ "That's it! The base class handles all the async WebSocket communication.\n",
"\n",
"
"
]
@@ -901,45 +761,27 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "🎮 ================================================================ 🎮\n",
- " ✅ Importing Real OpenSpiel Environment!\n",
- "🎮 ================================================================ 🎮\n",
- "\n",
- "📦 What we just imported:\n",
- " • OpenSpielEnv - HTTP client for OpenSpiel games\n",
- " • OpenSpielAction - Type-safe actions\n",
- " • OpenSpielObservation - Type-safe observations\n",
- " • OpenSpielState - Episode metadata\n",
- "\n",
- "📋 OpenSpielObservation fields:\n",
- " ────────────────────────────────────────────────────────────\n",
- " • done : \n",
- " • reward : bool | int | float | None\n",
- " • metadata : typing.Dict[str, typing.Any]\n",
- " • info_state : typing.List[float]\n",
- " • legal_actions : typing.List[int]\n",
- " • game_phase : \n",
- " • current_player_id : \n",
- " • opponent_last_action : typing.Optional[int]\n",
- "\n",
- "======================================================================\n",
- "\n",
- "💡 This is REAL OpenEnv code - used in production!\n",
- " • Wraps 6 OpenSpiel games (Catch, Tic-Tac-Toe, Poker, etc.)\n",
- " • Type-safe actions and observations\n",
- " • Works via HTTP (we'll see that next!)\n",
- "\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
+ "if \"ensure_openenv_paths\" not in globals():\n",
+ " import sys\n",
+ " from pathlib import Path\n",
+ "\n",
+ " def ensure_openenv_paths(start=None):\n",
+ " start = Path(start or Path.cwd()).resolve()\n",
+ " for directory in (start, *start.parents):\n",
+ " if (directory / \"src\" / \"openenv\").is_dir() and (directory / \"envs\").is_dir():\n",
+ " for path in (directory / \"src\", directory):\n",
+ " path_str = str(path)\n",
+ " if path_str not in sys.path:\n",
+ " sys.path.insert(0, path_str)\n",
+ " return directory\n",
+ " raise RuntimeError(\"Run the Part 3 setup cell first.\")\n",
+ "\n",
+ "OPENENV_ROOT = ensure_openenv_paths()\n",
+ "\n",
"from envs.openspiel_env import OpenSpielEnv\n",
"from envs.openspiel_env.models import (\n",
" OpenSpielAction,\n",
@@ -952,7 +794,7 @@
"print(\"🎮 \" + \"=\" * 64 + \" 🎮\\n\")\n",
"\n",
"print(\"📦 What we just imported:\")\n",
- "print(\" • OpenSpielEnv - HTTP client for OpenSpiel games\")\n",
+ "print(\" • OpenSpielEnv - EnvClient for OpenSpiel games\")\n",
"print(\" • OpenSpielAction - Type-safe actions\")\n",
"print(\" • OpenSpielObservation - Type-safe observations\")\n",
"print(\" • OpenSpielState - Episode metadata\\n\")\n",
@@ -966,45 +808,32 @@
"print(\"\\n💡 This is REAL OpenEnv code - used in production!\")\n",
"print(\" • Wraps 6 OpenSpiel games (Catch, Tic-Tac-Toe, Poker, etc.)\")\n",
"print(\" • Type-safe actions and observations\")\n",
- "print(\" • Works via HTTP (we'll see that next!)\\n\")"
+ "print(\" • Talks to the server over WebSocket (we'll see that next!)\\n\")"
]
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "🚀 ================================================================ 🚀\n",
- " Starting OpenSpiel Server (Catch Game)\n",
- "🚀 ================================================================ 🚀\n",
- "\n",
- "✅ OpenSpiel is installed!\n",
- "\n",
- "⚡ Starting FastAPI server for OpenSpiel Catch...\n",
- " (This uses REAL OpenEnv + OpenSpiel integration)\n",
- "\n",
- "⏳ Waiting for server to start...\n",
- "\n",
- "✅ OpenSpiel server is running!\n",
- "🌐 Server URL: http://localhost:8000\n",
- "📍 Endpoints available:\n",
- " • POST /reset\n",
- " • POST /step\n",
- " • GET /state\n",
- "\n",
- "🎯 This is REAL OpenEnv + OpenSpiel in action!\n",
- " • Running actual OpenSpiel Catch game\n",
- " • Exposed via FastAPI HTTP server\n",
- " • Using OpenEnv's standard interface\n",
- "\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
+ "if \"ensure_openenv_paths\" not in globals():\n",
+ " import sys\n",
+ " from pathlib import Path\n",
+ "\n",
+ " def ensure_openenv_paths(start=None):\n",
+ " start = Path(start or Path.cwd()).resolve()\n",
+ " for directory in (start, *start.parents):\n",
+ " if (directory / \"src\" / \"openenv\").is_dir() and (directory / \"envs\").is_dir():\n",
+ " for path in (directory / \"src\", directory):\n",
+ " path_str = str(path)\n",
+ " if path_str not in sys.path:\n",
+ " sys.path.insert(0, path_str)\n",
+ " return directory\n",
+ " raise RuntimeError(\"Run the Part 3 setup cell first.\")\n",
+ "\n",
+ "OPENENV_ROOT = ensure_openenv_paths()\n",
+ "\n",
"import subprocess\n",
"import time\n",
"import sys\n",
@@ -1028,12 +857,8 @@
"print(\"⚡ Starting FastAPI server for OpenSpiel Catch...\")\n",
"print(\" (This uses REAL OpenEnv + OpenSpiel integration)\\n\")\n",
"\n",
- "# Determine the correct path\n",
- "if IN_COLAB:\n",
- " work_dir = \"/content/OpenEnv\"\n",
- "else:\n",
- " from pathlib import Path\n",
- " work_dir = str(Path.cwd().parent.absolute())\n",
+ "# Use repo root from the setup cell\n",
+ "work_dir = str(OPENENV_ROOT)\n",
"\n",
"server_process = subprocess.Popen(\n",
" [sys.executable, \"-m\", \"uvicorn\",\n",
@@ -1041,7 +866,7 @@
" \"--host\", \"0.0.0.0\",\n",
" \"--port\", \"8000\"],\n",
" env={**os.environ,\n",
- " \"PYTHONPATH\": f\"{work_dir}/src\",\n",
+ " \"PYTHONPATH\": f\"{work_dir}/src:{work_dir}\",\n",
" \"OPENSPIEL_GAME\": \"catch\",\n",
" \"OPENSPIEL_AGENT_PLAYER\": \"0\",\n",
" \"OPENSPIEL_OPPONENT_POLICY\": \"random\"},\n",
@@ -1061,13 +886,12 @@
" response = requests.get('http://localhost:8000/health', timeout=2)\n",
" print(\"\\n✅ OpenSpiel server is running!\")\n",
" print(\"🌐 Server URL: http://localhost:8000\")\n",
- " print(\"📍 Endpoints available:\")\n",
- " print(\" • POST /reset\")\n",
- " print(\" • POST /step\")\n",
- " print(\" • GET /state\")\n",
+ " print(\"📍 The client connects over WebSocket:\")\n",
+ " print(\" • ws://localhost:8000/ws (persistent session)\")\n",
+ " print(\" • GET /health (HTTP health check)\")\n",
" print(\"\\n🎯 This is REAL OpenEnv + OpenSpiel in action!\")\n",
" print(\" • Running actual OpenSpiel Catch game\")\n",
- " print(\" • Exposed via FastAPI HTTP server\")\n",
+ " print(\" • Exposed via FastAPI (WebSocket protocol)\")\n",
" print(\" • Using OpenEnv's standard interface\\n\")\n",
"except Exception as e:\n",
" print(f\"\\n❌ Server failed to start: {e}\")\n",
@@ -1084,100 +908,45 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "📱 ================================================================ 📱\n",
- " Connecting to OpenSpiel Server via HTTP\n",
- "📱 ================================================================ 📱\n",
- "\n",
- "✅ Client created!\n",
- "\n",
- "💡 What just happened:\n",
- " • OpenSpielEnv is an HTTPEnvClient subclass\n",
- " • It knows how to talk to OpenSpiel servers\n",
- " • All communication is type-safe and over HTTP\n",
- " • Same client works for ALL OpenSpiel games!\n",
- "\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"print(\"📱 \" + \"=\"*64 + \" 📱\")\n",
- "print(\" Connecting to OpenSpiel Server via HTTP\")\n",
+ "print(\" Connecting to OpenSpiel Server over WebSocket\")\n",
"print(\"📱 \" + \"=\"*64 + \" 📱\\n\")\n",
"\n",
- "# Create HTTP client for OpenSpiel\n",
+ "# Create the OpenSpiel client\n",
"client = OpenSpielEnv(base_url=\"http://localhost:8000\")\n",
"\n",
"print(\"✅ Client created!\")\n",
"print(\"\\n💡 What just happened:\")\n",
- "print(\" • OpenSpielEnv is an HTTPEnvClient subclass\")\n",
+ "print(\" • OpenSpielEnv is an EnvClient subclass (async)\")\n",
"print(\" • It knows how to talk to OpenSpiel servers\")\n",
- "print(\" • All communication is type-safe and over HTTP\")\n",
- "print(\" • Same client works for ALL OpenSpiel games!\\n\")"
+ "print(\" • All communication is type-safe and over WebSocket\")\n",
+ "print(\" • Same client works for ALL OpenSpiel games!\\n\")\n",
+ "\n",
+ "# 📝 EnvClient uses async WebSocket I/O, so notebook calls are awaited.\n",
+ "# This notebook relies on Jupyter's top-level `await`. In a plain script,\n",
+ "# calling reset()/step()/state() without `await` only creates a coroutine;\n",
+ "# run them inside an `async def main()` started with `asyncio.run(main())`.\n"
]
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "🎮 ================================================================ 🎮\n",
- " Testing Connection - Playing One Step\n",
- "🎮 ================================================================ 🎮\n",
- "\n",
- "📤 Calling client.reset()...\n",
- " Under the hood: HTTP POST to http://localhost:8000/reset\n",
- "\n",
- "📥 Received OpenSpielObservation:\n",
- " • info_state: [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]... (first 10 values)\n",
- " • number of info_state: 50\n",
- " • legal_actions: [0, 1, 2]\n",
- " • game_phase: initial\n",
- " • done: False\n",
- "\n",
- "📤 Calling client.step(OpenSpielAction(action_id=1, game_name='catch'))...\n",
- " Under the hood: HTTP POST to http://localhost:8000/step\n",
- "\n",
- "📥 Received response:\n",
- " • Reward: 0.0\n",
- " • Done: False\n",
- " • legal_actions: [0, 1, 2]\n",
- "\n",
- "📊 Episode state:\n",
- " • episode_id: aec2e599-5633-45d6-a7a3-30f46e8796ab\n",
- " • step_count: 1\n",
- " • game_name: catch\n",
- "\n",
- "======================================================================\n",
- "\n",
- "🎉 IT WORKS! We're using REAL OpenSpiel via HTTP!\n",
- " ✅ Type-safe communication\n",
- " ✅ Same interface as any OpenEnv environment\n",
- " ✅ Production-ready architecture\n",
- "\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"print(\"🎮 \" + \"=\"*64 + \" 🎮\")\n",
"print(\" Testing Connection - Playing One Step\")\n",
"print(\"🎮 \" + \"=\"*64 + \" 🎮\\n\")\n",
"\n",
- "# Reset the environment (HTTP POST /reset)\n",
+ "# Reset the environment (reset message over WebSocket)\n",
"print(\"📤 Calling client.reset()...\")\n",
- "print(\" Under the hood: HTTP POST to http://localhost:8000/reset\\n\")\n",
+ "print(\" Under the hood: a reset message over the WebSocket session\\n\")\n",
"\n",
- "result = client.reset()\n",
+ "result = await client.reset()\n",
"\n",
"print(\"📥 Received OpenSpielObservation:\")\n",
"print(f\" • info_state: {result.observation.info_state[:10]}... (first 10 values)\")\n",
@@ -1186,27 +955,27 @@
"print(f\" • game_phase: {result.observation.game_phase}\")\n",
"print(f\" • done: {result.done}\")\n",
"\n",
- "# Take an action (HTTP POST /step)\n",
+ "# Take an action (step message over WebSocket)\n",
"print(\"\\n📤 Calling client.step(OpenSpielAction(action_id=1, game_name=\\'catch\\'))...\")\n",
- "print(\" Under the hood: HTTP POST to http://localhost:8000/step\\n\")\n",
+ "print(\" Under the hood: a step message over the WebSocket session\\n\")\n",
"\n",
"action = OpenSpielAction(action_id=1, game_name=\"catch\") # STAY\n",
- "result = client.step(action)\n",
+ "result = await client.step(action)\n",
"\n",
"print(\"📥 Received response:\")\n",
"print(f\" • Reward: {result.reward}\")\n",
"print(f\" • Done: {result.done}\")\n",
"print(f\" • legal_actions: {result.observation.legal_actions}\")\n",
"\n",
- "# Get state (HTTP GET /state)\n",
- "state = client.state()\n",
+ "# Get state (state message over WebSocket)\n",
+ "state = await client.state()\n",
"print(f\"\\n📊 Episode state:\")\n",
"print(f\" • episode_id: {state.episode_id}\")\n",
"print(f\" • step_count: {state.step_count}\")\n",
"print(f\" • game_name: {state.game_name}\")\n",
"\n",
"print(\"\\n\" + \"=\"*70)\n",
- "print(\"\\n🎉 IT WORKS! We\\'re using REAL OpenSpiel via HTTP!\")\n",
+ "print(\"\\n🎉 IT WORKS! We\\'re using REAL OpenSpiel over WebSocket!\")\n",
"print(\" ✅ Type-safe communication\")\n",
"print(\" ✅ Same interface as any OpenEnv environment\")\n",
"print(\" ✅ Production-ready architecture\\n\")"
@@ -1369,13 +1138,14 @@
"metadata": {},
"outputs": [],
"source": [
- "import time\n",
+ "import asyncio\n",
+ "\n",
"\n",
- "def run_episode(env, policy, visualize=True, delay=0.3):\n",
+ "async def run_episode(env, policy, visualize=True, delay=0.3):\n",
" \"\"\"Run one episode with a policy against OpenSpiel environment.\"\"\"\n",
"\n",
" # RESET\n",
- " result = env.reset()\n",
+ " result = await env.reset()\n",
" obs = result.observation\n",
"\n",
" if visualize:\n",
@@ -1383,7 +1153,7 @@
" print(f\" 🎮 {policy.name}\")\n",
" print(f\" 🎲 Playing against OpenSpiel Catch\")\n",
" print('='*60 + '\\n')\n",
- " time.sleep(delay)\n",
+ " await asyncio.sleep(delay)\n",
"\n",
" total_reward = 0\n",
" step = 0\n",
@@ -1394,9 +1164,9 @@
" # 1. Policy chooses action\n",
" action_id = policy.select_action(obs)\n",
"\n",
- " # 2. Environment executes (via HTTP!)\n",
+ " # 2. Environment executes (via WebSocket!)\n",
" action = OpenSpielAction(action_id=action_id, game_name=\"catch\")\n",
- " result = env.step(action)\n",
+ " result = await env.step(action)\n",
" obs = result.observation\n",
"\n",
" # 3. Collect reward\n",
@@ -1405,7 +1175,7 @@
"\n",
" if visualize:\n",
" print(f\"📍 Step {step + 1}: {action_names[action_id]} → Reward: {result.reward}\")\n",
- " time.sleep(delay)\n",
+ " await asyncio.sleep(delay)\n",
"\n",
" step += 1\n",
"\n",
@@ -1424,10 +1194,10 @@
"\n",
"# Demo: Watch Smart Policy in action\n",
"policy = SmartPolicy()\n",
- "run_episode(client, policy, visualize=True, delay=0.5)\n",
+ "await run_episode(client, policy, visualize=True, delay=0.5)\n",
"\n",
"print(\"\\n💡 You just watched REAL OpenSpiel Catch being played!\")\n",
- "print(\" • Every action was an HTTP call\")\n",
+ "print(\" • Every action was a WebSocket message to the server\")\n",
"print(\" • Game logic runs in the server\")\n",
"print(\" • Client only sends actions and receives observations\\n\")"
]
@@ -1444,7 +1214,7 @@
"\n",
"Let's run **50 episodes** for each policy against **REAL OpenSpiel** and see who wins!\n",
"\n",
- "This is production code - every action is an HTTP call to the OpenSpiel server!\n",
+ "This is production code - every action is a WebSocket message to the OpenSpiel server!\n",
"\n",
""
]
@@ -1455,7 +1225,7 @@
"metadata": {},
"outputs": [],
"source": [
- "def evaluate_policies(env, num_episodes=50):\n",
+ "async def evaluate_policies(env, num_episodes=50):\n",
" \"\"\"Compare all policies over many episodes using real OpenSpiel.\"\"\"\n",
" policies = [\n",
" RandomPolicy(),\n",
@@ -1472,8 +1242,9 @@
" results = []\n",
" for policy in policies:\n",
" print(f\"⚡ Testing {policy.name}...\", end=\" \")\n",
- " successes = sum(run_episode(env, policy, visualize=False)\n",
- " for _ in range(num_episodes))\n",
+ " successes = 0\n",
+ " for _ in range(num_episodes):\n",
+ " successes += await run_episode(env, policy, visualize=False)\n",
" success_rate = (successes / num_episodes) * 100\n",
" results.append((policy.name, success_rate, successes))\n",
" print(f\"✓ Done!\")\n",
@@ -1501,13 +1272,13 @@
" print(\" • Learning (~85%): Improves over time 📈\")\n",
" print(\"\\n🎓 This is Reinforcement Learning + OpenEnv in action:\")\n",
" print(\" 1. We USED existing OpenSpiel environment (didn\\'t build it)\")\n",
- " print(\" 2. Type-safe communication over HTTP\")\n",
+ " print(\" 2. Type-safe communication over WebSocket\")\n",
" print(\" 3. Same code works for ANY OpenSpiel game\")\n",
" print(\" 4. Production-ready architecture\\n\")\n",
"\n",
"# Run the epic competition!\n",
"print(\"🎮 Starting the showdown against REAL OpenSpiel...\\n\")\n",
- "evaluate_policies(client, num_episodes=50)"
+ "await evaluate_policies(client, num_episodes=50)"
]
},
{
@@ -1540,7 +1311,7 @@
"\n",
"
\n",
"| Connected | \n",
- "HTTP client to server | \n",
+ "Async client over WebSocket | \n",
"
\n",
"
\n",
"| Played | \n",
@@ -1548,7 +1319,7 @@
"
\n",
"\n",
"\n",
- "**🎯 This is production code!** Every action was an HTTP call to a real OpenSpiel environment.\n",
+ "**🎯 This is production code!** Every action was a WebSocket message to a real OpenSpiel environment.\n",
"\n",
"
\n",
"\n",
@@ -1593,7 +1364,7 @@
" \"--host\", \"0.0.0.0\",\n",
" \"--port\", \"8000\"],\n",
" env={**os.environ,\n",
- " \"PYTHONPATH\": f\"{work_dir}/src\",\n",
+ " \"PYTHONPATH\": f\"{work_dir}/src:{work_dir}\",\n",
" \"OPENSPIEL_GAME\": \"tic_tac_toe\", # Changed!\n",
" \"OPENSPIEL_AGENT_PLAYER\": \"0\",\n",
" \"OPENSPIEL_OPPONENT_POLICY\": \"random\"},\n",
@@ -1602,7 +1373,7 @@
"\n",
"# Same client works!\n",
"client = OpenSpielEnv(base_url=\"http://localhost:8000\")\n",
- "result = client.reset() # Now playing Tic-Tac-Toe!\n",
+ "result = await client.reset() # Now playing Tic-Tac-Toe! (top-level await works in notebooks; in a script, run this inside asyncio.run(...))\n",
"```\n",
"\n",
"**💡 Key Insight**: You don't rebuild anything - you just USE different games with the same client!\n"
@@ -1627,27 +1398,28 @@
"\n",
"### Step 1: Define Types (`models.py`)\n",
"\n",
+ "Actions, observations, and state are **Pydantic models**. The base classes already\n",
+ "provide the common fields — `Observation` has `done`, `reward`, `metadata`, and\n",
+ "`State` has `episode_id`, `step_count` — so you only add what's specific to your env.\n",
+ "\n",
"```python\n",
- "from dataclasses import dataclass\n",
+ "from typing import List\n",
"from openenv.core.env_server import Action, Observation, State\n",
+ "from pydantic import Field\n",
+ "\n",
"\n",
- "@dataclass\n",
"class YourAction(Action):\n",
- " action_value: int\n",
- " # Add your action fields\n",
+ " action_value: int # add your action fields\n",
+ "\n",
"\n",
- "@dataclass\n",
"class YourObservation(Observation):\n",
- " state_data: List[float]\n",
- " done: bool\n",
- " reward: float\n",
- " # Add your observation fields\n",
+ " # `done`, `reward`, and `metadata` are inherited from Observation\n",
+ " state_data: List[float] # add your observation fields\n",
+ "\n",
"\n",
- "@dataclass\n",
"class YourState(State):\n",
- " episode_id: str\n",
- " step_count: int\n",
- " # Add your state fields\n",
+ " # `episode_id` and `step_count` are inherited from State\n",
+ " score: int = 0 # add your state fields\n",
"```\n",
"\n",
"### Step 2: Implement Environment (`server/environment.py`)\n",
@@ -1655,53 +1427,63 @@
"```python\n",
"from openenv.core.env_server import Environment\n",
"\n",
+ "\n",
"class YourEnvironment(Environment):\n",
- " def reset(self) -> Observation:\n",
+ " def reset(self, seed=None, episode_id=None, **kwargs) -> YourObservation:\n",
" # Initialize your game/simulation\n",
" return YourObservation(...)\n",
- " \n",
- " def step(self, action: Action) -> Observation:\n",
+ "\n",
+ " def step(self, action, **kwargs) -> YourObservation:\n",
" # Execute action, update state\n",
" return YourObservation(...)\n",
- " \n",
+ "\n",
" @property\n",
- " def state(self) -> State:\n",
+ " def state(self) -> YourState:\n",
" return self._state\n",
"```\n",
"\n",
"### Step 3: Create Client (`client.py`)\n",
"\n",
+ "The client subclasses `EnvClient` and implements 3 hooks. Callers\n",
+ "`await` `reset()`/`step()`/`state()` — directly in a notebook cell, or inside a coroutine started with `asyncio.run()` in a script.\n",
+ "\n",
"```python\n",
- "from openenv.core.http_env_client import HTTPEnvClient\n",
- "from openenv.core.types import StepResult\n",
+ "from openenv.core.env_client import EnvClient\n",
+ "from openenv.core.client_types import StepResult\n",
"\n",
- "class YourEnv(HTTPEnvClient[YourAction, YourObservation]):\n",
+ "\n",
+ "class YourEnv(EnvClient[YourAction, YourObservation, YourState]):\n",
" def _step_payload(self, action: YourAction) -> dict:\n",
- " \"\"\"Convert action to JSON\"\"\"\n",
+ " \"\"\"Convert action to the JSON step message.\"\"\"\n",
" return {\"action_value\": action.action_value}\n",
- " \n",
+ "\n",
" def _parse_result(self, payload: dict) -> StepResult:\n",
- " \"\"\"Parse JSON to observation\"\"\"\n",
+ " \"\"\"Parse the server response into a typed observation.\"\"\"\n",
" return StepResult(\n",
" observation=YourObservation(...),\n",
- " reward=payload['reward'],\n",
- " done=payload['done']\n",
+ " reward=payload[\"reward\"],\n",
+ " done=payload[\"done\"],\n",
" )\n",
- " \n",
+ "\n",
" def _parse_state(self, payload: dict) -> YourState:\n",
" return YourState(...)\n",
"```\n",
"\n",
"### Step 4: Create Server (`server/app.py`)\n",
"\n",
+ "Pass the environment **class** (a factory) plus the action and observation types.\n",
+ "OpenEnv builds the WebSocket + HTTP endpoints for you.\n",
+ "\n",
"```python\n",
- "from openenv.core.env_server import create_fastapi_app\n",
+ "from openenv.core.env_server import create_app\n",
- "from .your_environment import YourEnvironment\n",
+ "from .environment import YourEnvironment\n",
"\n",
- "env = YourEnvironment()\n",
- "app = create_fastapi_app(env)\n",
- "\n",
- "# That's it! OpenEnv creates all endpoints for you.\n",
+ "app = create_app(\n",
+ " YourEnvironment,\n",
+ " YourAction,\n",
+ " YourObservation,\n",
+ " env_name=\"your_env\",\n",
+ ")\n",
"```\n",
"\n",
"### Step 5: Dockerize (`server/Dockerfile`)\n",
@@ -1739,7 +1521,7 @@
"\n",
"**💡 Study these to understand the patterns!**\n",
"\n",
- "