diff --git a/tutorial/01-environments.md b/tutorial/01-environments.md index 038331f6d..6531ad60f 100644 --- a/tutorial/01-environments.md +++ b/tutorial/01-environments.md @@ -245,8 +245,8 @@ Think of it like this: You don't run your database in the same process as your w │ │ └─────────────────┬──────────────────────────────────────────┘ │ - │ HTTP/JSON (Language-Agnostic) - │ POST /reset, POST /step, GET /state + │ OpenEnv protocol over WebSocket + │ (persistent session, language-agnostic) │ ┌─────────────────▼──────────────────────────────────────────┐ │ DOCKER CONTAINER │ @@ -262,15 +262,17 @@ Think of it like this: You don't run your database in the same process as your w ``` !!! info "Key Insight" - You never see HTTP details - just clean Python methods! + You never see the wire protocol - just clean Python methods! ```python - env.reset() # Under the hood: HTTP POST to /reset - env.step(...) # Under the hood: HTTP POST to /step - env.state() # Under the hood: HTTP GET to /state + await env.reset() # Under the hood: a reset message over WebSocket + await env.step(...) # Under the hood: a step message over WebSocket + await env.state() # Under the hood: a state message over WebSocket ``` - The magic? OpenEnv handles all the plumbing. You focus on RL! ✨ + In a notebook you can `await` directly; in a plain script, the same calls + block automatically. The magic? OpenEnv handles all the plumbing. You focus + on RL! 
✨ --- @@ -332,7 +334,7 @@ src/envs/your_env/ │ (Action, Observation, State) │ ├── 📱 client.py ← What YOU import -│ (HTTPEnvClient implementation) +│ (EnvClient implementation) │ └── 🖥️ server/ ├── environment.py ← Game/simulation logic @@ -396,7 +398,7 @@ The client (`envs/openspiel_env/client.py`) inherits from `EnvClient` and declar ```python class OpenSpielEnv(EnvClient[OpenSpielAction, OpenSpielObservation, OpenSpielState]): def _step_payload(self, action: OpenSpielAction) -> dict: - """Convert typed action to JSON for HTTP.""" + """Convert typed action to JSON for the step message.""" return { "action_id": action.action_id, "game_name": action.game_name, @@ -404,15 +406,21 @@ class OpenSpielEnv(EnvClient[OpenSpielAction, OpenSpielObservation, OpenSpielSta } def _parse_result(self, payload: dict) -> StepResult: - """Parse HTTP JSON response into typed observation.""" + """Parse the server response into a typed observation.""" return StepResult( observation=OpenSpielObservation(...), reward=payload["reward"], done=payload["done"], ) + + def _parse_state(self, payload: dict) -> OpenSpielState: + """Parse the server response into a typed state.""" + return OpenSpielState(...) ``` -Usage is the same as any OpenEnv environment: +Usage is the same as any OpenEnv environment. 
In a plain script, these calls +block automatically (in a notebook you can `await` the same client directly +instead): ```python env = OpenSpielEnv(base_url="http://localhost:8000") @@ -433,7 +441,6 @@ from envs.openspiel_env.models import ( OpenSpielObservation, OpenSpielState ) -from dataclasses import fields print("="*70) print(" 🎮 OPENSPIEL INTEGRATION - TYPE-SAFE MODELS") @@ -441,18 +448,18 @@ print("="*70) print("\n📤 OpenSpielAction (what you send):") print(" " + "─" * 64) -for field in fields(OpenSpielAction): - print(f" • {field.name:20s} : {field.type}") +for name, field in OpenSpielAction.model_fields.items(): + print(f" • {name:20s} : {field.annotation}") print("\n📥 OpenSpielObservation (what you receive):") print(" " + "─" * 64) -for field in fields(OpenSpielObservation): - print(f" • {field.name:20s} : {field.type}") +for name, field in OpenSpielObservation.model_fields.items(): + print(f" • {name:20s} : {field.annotation}") print("\n📊 OpenSpielState (episode metadata):") print(" " + "─" * 64) -for field in fields(OpenSpielState): - print(f" • {field.name:20s} : {field.type}") +for name, field in OpenSpielState.model_fields.items(): + print(f" • {name:20s} : {field.annotation}") print("\n" + "="*70) print("\n💡 Type safety means:") @@ -507,13 +514,13 @@ print(" ✅ Self-documenting code\n") ### How the Client Works -The client **inherits from HTTPEnvClient** and implements 3 methods: +The client **inherits from `EnvClient`** and implements 3 methods: 1. `_step_payload()` - Convert action → JSON 2. `_parse_result()` - Parse JSON → typed observation 3. `_parse_state()` - Parse JSON → state -That's it! The base class handles all HTTP communication. +That's it! The base class handles all the async WebSocket communication. 
--- @@ -585,28 +592,27 @@ from envs.openspiel_env.models import ( OpenSpielObservation, OpenSpielState ) -from dataclasses import fields print("🎮 " + "="*64 + " 🎮") print(" ✅ Importing Real OpenSpiel Environment!") print("🎮 " + "="*64 + " 🎮\n") print("📦 What we just imported:") -print(" • OpenSpielEnv - HTTP client for OpenSpiel games") +print(" • OpenSpielEnv - EnvClient for OpenSpiel games") print(" • OpenSpielAction - Type-safe actions") print(" • OpenSpielObservation - Type-safe observations") print(" • OpenSpielState - Episode metadata\n") print("📋 OpenSpielObservation fields:") print(" " + "─" * 60) -for field in fields(OpenSpielObservation): - print(f" • {field.name:25s} : {field.type}") +for name, field in OpenSpielObservation.model_fields.items(): + print(f" • {name:25s} : {field.annotation}") print("\n" + "="*70) print("\n💡 This is REAL OpenEnv code - used in production!") print(" • Wraps 6 OpenSpiel games (Catch, Tic-Tac-Toe, Poker, etc.)") print(" • Type-safe actions and observations") -print(" • Works via HTTP (we'll see that next!)\n") +print(" • Talks to the server over WebSocket (we'll see that next!)\n") ``` **Output:** @@ -616,7 +622,7 @@ print(" • Works via HTTP (we'll see that next!)\n") 🎮 ================================================================ 🎮 📦 What we just imported: - • OpenSpielEnv - HTTP client for OpenSpiel games + • OpenSpielEnv - EnvClient for OpenSpiel games • OpenSpielAction - Type-safe actions • OpenSpielObservation - Type-safe observations • OpenSpielState - Episode metadata @@ -637,7 +643,7 @@ print(" • Works via HTTP (we'll see that next!)\n") 💡 This is REAL OpenEnv code - used in production! • Wraps 6 OpenSpiel games (Catch, Tic-Tac-Toe, Poker, etc.) • Type-safe actions and observations - • Works via HTTP (we'll see that next!) + • Talks to the server over WebSocket (we'll see that next!) 
``` --- @@ -772,7 +778,7 @@ print(" • Work with ANY OpenSpiel game that exposes these!\n") Let's run **50 episodes** for each policy against **REAL OpenSpiel** and see who wins! -This is production code - every action is an HTTP call to the OpenSpiel server! +This is production code - every action is a WebSocket message to the OpenSpiel server! ```python def evaluate_policies(env, num_episodes=50): @@ -821,7 +827,7 @@ def evaluate_policies(env, num_episodes=50): print(" • Learning (~85%): Improves over time 📈") print("\n🎓 This is Reinforcement Learning + OpenEnv in action:") print(" 1. We USED existing OpenSpiel environment (didn't build it)") - print(" 2. Type-safe communication over HTTP") + print(" 2. Type-safe communication over WebSocket") print(" 3. Same code works for ANY OpenSpiel game") print(" 4. Production-ready architecture\n") @@ -842,10 +848,10 @@ In Parts 6-8, we **USED** the existing OpenSpiel Catch environment: |-------------|--------------| | **Imported** | OpenSpielEnv client (pre-built) | | **Started** | OpenSpiel server via uvicorn | -| **Connected** | HTTP client to server | +| **Connected** | Async client over WebSocket | | **Played** | Real OpenSpiel Catch game | -**🎯 This is production code!** Every action was an HTTP call to a real OpenSpiel environment. +**🎯 This is production code!** Every action was a WebSocket message to a real OpenSpiel environment. ### 🎮 6 Games Available - Same Interface! @@ -908,79 +914,88 @@ Want to wrap your own environment in OpenEnv? Here's how: ### Step 1: Define Types (`models.py`) +Actions, observations, and state are **Pydantic models**. The base classes already +provide the common fields — `Observation` has `done`, `reward`, `metadata`, and +`State` has `episode_id`, `step_count` — so you only add what's specific to your env. 
+ ```python -from dataclasses import dataclass -from core.env_server import Action, Observation, State +from typing import List +from openenv.core.env_server import Action, Observation, State +from pydantic import Field -@dataclass class YourAction(Action): - action_value: int - # Add your action fields + action_value: int # add your action fields -@dataclass class YourObservation(Observation): - state_data: List[float] - done: bool - reward: float - # Add your observation fields + # `done`, `reward`, and `metadata` are inherited from Observation + state_data: List[float] # add your observation fields -@dataclass class YourState(State): - episode_id: str - step_count: int - # Add your state fields + # `episode_id` and `step_count` are inherited from State + score: int = 0 # add your state fields ``` ### Step 2: Implement Environment (`server/environment.py`) ```python -from core.env_server import Environment +from openenv.core.env_server import Environment class YourEnvironment(Environment): - def reset(self) -> Observation: + def reset(self, seed=None, episode_id=None, **kwargs) -> YourObservation: # Initialize your game/simulation return YourObservation(...) - - def step(self, action: Action) -> Observation: + + def step(self, action, **kwargs) -> YourObservation: # Execute action, update state return YourObservation(...) - + @property - def state(self) -> State: + def state(self) -> YourState: return self._state ``` ### Step 3: Create Client (`client.py`) +The client subclasses `EnvClient` and implements 3 hooks. Callers can `await` +`reset()`/`step()`/`state()` in async contexts, or call them directly from +synchronous scripts. 
+ ```python -from core.http_env_client import HTTPEnvClient -from core.types import StepResult +from openenv.core.env_client import EnvClient +from openenv.core.client_types import StepResult -class YourEnv(HTTPEnvClient[YourAction, YourObservation]): +class YourEnv(EnvClient[YourAction, YourObservation, YourState]): def _step_payload(self, action: YourAction) -> dict: - """Convert action to JSON""" + """Convert action to the JSON step message.""" return {"action_value": action.action_value} - + def _parse_result(self, payload: dict) -> StepResult: - """Parse JSON to observation""" + """Parse the server response into a typed observation.""" return StepResult( observation=YourObservation(...), reward=payload['reward'], done=payload['done'] ) - + def _parse_state(self, payload: dict) -> YourState: return YourState(...) ``` ### Step 4: Create Server (`server/app.py`) +Pass the environment **class** (a factory) plus the action and observation types +from Step 1's `models.py` (import them here too). OpenEnv builds the WebSocket + HTTP endpoints for you. + ```python -from core.env_server import create_fastapi_app +from openenv.core.env_server import create_app from .your_environment import YourEnvironment -env = YourEnvironment() -app = create_fastapi_app(env) +app = create_app( + YourEnvironment, + YourAction, + YourObservation, + env_name="your_env", +) # That's it! OpenEnv creates all endpoints for you. 
``` @@ -1040,7 +1055,7 @@ OpenEnv includes 3 complete examples: - Client-server separation - Type-safe contracts -- HTTP communication layer +- WebSocket communication layer ✅ **Production Patterns** @@ -1063,7 +1078,7 @@ OpenEnv includes 3 complete examples: - Define type-safe models - Implement Environment class -- Create HTTPEnvClient +- Subclass EnvClient ✅ **Testing & Debugging** diff --git a/tutorial/02-deployment.md b/tutorial/02-deployment.md index 7a7a9f616..59690a42b 100644 --- a/tutorial/02-deployment.md +++ b/tutorial/02-deployment.md @@ -33,8 +33,8 @@ async with EchoEnv(base_url="https://openenv-echo-env.hf.space") as client: result = await client.reset() result = await client.step(EchoAction(message="Hello")) -# Sync (using .sync() wrapper): -with EchoEnv(base_url="https://openenv-echo-env.hf.space").sync() as client: +# Sync: +with EchoEnv(base_url="https://openenv-echo-env.hf.space") as client: result = client.reset() result = client.step(EchoAction(message="Hello")) ``` @@ -132,8 +132,8 @@ async def main(): asyncio.run(main()) -# For sync usage, use the .sync() wrapper: -with EchoEnv(base_url="http://localhost:8001").sync() as client: +# Sync usage: +with EchoEnv(base_url="http://localhost:8001") as client: result = client.reset() ``` @@ -242,8 +242,8 @@ async def main(): asyncio.run(main()) -# Sync usage (using .sync() wrapper) -with EchoEnv(base_url="http://localhost:8000").sync() as client: +# Sync usage +with EchoEnv(base_url="http://localhost:8000") as client: result = client.reset() result = client.step(EchoAction(message="Hello")) print(result.observation) @@ -345,7 +345,7 @@ my_env/ │ ├── environment.py # Your environment logic │ └── Dockerfile ├── models.py # Action/Observation types -├── client.py # HTTP client +├── client.py # Client (EnvClient subclass) ├── openenv.yaml # Manifest └── pyproject.toml ``` @@ -418,8 +418,8 @@ async def main(): asyncio.run(main()) -# Sync (using .sync() wrapper) -with 
EchoEnv(base_url="http://localhost:7860").sync() as env: +# Sync +with EchoEnv(base_url="http://localhost:7860") as env: result = env.reset() print(result.observation) result = env.step(EchoAction(message="Hello")) diff --git a/tutorial/04-training.md b/tutorial/04-training.md index cc6c49338..a6c3b11b2 100644 --- a/tutorial/04-training.md +++ b/tutorial/04-training.md @@ -69,6 +69,8 @@ For more information, refer to the [TRL-OpenEnv documentation](https://huggingfa from envs.textarena_env import TextArenaEnv textarena_url = "https://burtenshaw-textarena.hf.space" # Duplicate the Space and update this! +# GRPOTrainer calls rollout_func synchronously; EnvClient calls block automatically +# in synchronous code. env = TextArenaEnv(base_url=textarena_url) ``` diff --git a/tutorial/examples/OpenEnv_Tutorial.ipynb b/tutorial/examples/OpenEnv_Tutorial.ipynb index 9771ae2e2..a136cf0e4 100644 --- a/tutorial/examples/OpenEnv_Tutorial.ipynb +++ b/tutorial/examples/OpenEnv_Tutorial.ipynb @@ -182,34 +182,9 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "🎲 ========================================================== 🎲\n", - " Number Guessing Game - The Simplest RL Example\n", - "🎲 ========================================================== 🎲\n", - "\n", - "🎯 I'm thinking of a number between 1 and 10...\n", - "💭 You have 3 guesses. Let's see how random guessing works!\n", - "\n", - "💭 Guess #1: 4 → ❄️ Cold! (far)\n", - "💭 Guess #2: 4 → ❄️ Cold! (far)\n", - "💭 Guess #3: 3 → ❄️ Cold! (far)\n", - "\n", - "💔 Out of guesses. The number was 8.\n", - "\n", - "==============================================================\n", - "💡 This is RL: Observe → Act → Reward → Repeat\n", - " But this policy is terrible! 
It doesn't learn from rewards.\n", - "==============================================================\n", - "\n" - ] - } - ], + "outputs": [], "source": [ "import random\n", "\n", @@ -304,7 +279,7 @@ "\n", "\n", "\n", - "
\n", + "
\n", "\n", "## 💡 The OpenEnv Philosophy\n", "\n", @@ -338,8 +313,8 @@ "│ │\n", "└─────────────────┬──────────────────────────────────────────┘\n", " │\n", - " │ HTTP/JSON (Language-Agnostic)\n", - " │ POST /reset, POST /step, GET /state\n", + " │ OpenEnv protocol over WebSocket\n", + " │ (persistent session, language-agnostic)\n", " │\n", "┌─────────────────▼──────────────────────────────────────────┐\n", "│ DOCKER CONTAINER │\n", @@ -354,14 +329,14 @@ "└────────────────────────────────────────────────────────────┘\n", "```\n", "\n", - "
\n", + "
\n", "\n", - "**🎯 Key Insight**: You never see HTTP details - just clean Python methods!\n", + "**🎯 Key Insight**: You never see the wire protocol — just clean Python methods!\n", "\n", "```python\n", - "env.reset() # Under the hood: HTTP POST to /reset\n", - "env.step(...) # Under the hood: HTTP POST to /step\n", - "env.state() # Under the hood: HTTP GET to /state\n", + "await env.reset() # Under the hood: a reset message over WebSocket\n", + "await env.step(...) # Under the hood: a step message over WebSocket\n", + "await env.state() # Under the hood: a state message over WebSocket\n", "```\n", "\n", "The magic? OpenEnv handles all the plumbing. You focus on RL! ✨\n", @@ -388,23 +363,30 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "💻 Running locally - Nice!\n", - "✅ Using local OpenEnv installation\n", - "\n", - "🚀 Ready to explore OpenEnv and build amazing things!\n", - "💡 Tip: Run cells top-to-bottom for the best experience.\n", - "\n" - ] - } - ], + "outputs": [], "source": [ + "import sys\n", + "from pathlib import Path\n", + "\n", + "\n", + "def ensure_openenv_paths(start=None):\n", + " \"\"\"Add repo root + src to sys.path; return repo root. Safe to re-run.\"\"\"\n", + " start = Path(start or Path.cwd()).resolve()\n", + " for directory in (start, *start.parents):\n", + " if (directory / \"src\" / \"openenv\").is_dir() and (directory / \"envs\").is_dir():\n", + " for path in (directory / \"src\", directory):\n", + " path_str = str(path)\n", + " if path_str not in sys.path:\n", + " sys.path.insert(0, path_str)\n", + " return directory\n", + " raise RuntimeError(\n", + " \"Could not find OpenEnv repo root. 
\"\n", + " \"Run this notebook from inside the OpenEnv repository.\"\n", + " )\n", + "\n", + "\n", "# Detect environment\n", "try:\n", " import google.colab\n", @@ -418,21 +400,20 @@ " print(\"\\n📦 Cloning OpenEnv repository...\")\n", " !git clone https://github.com/huggingface/OpenEnv.git > /dev/null 2>&1\n", " %cd OpenEnv\n", - " \n", + "\n", " print(\"📚 Installing dependencies (this takes ~10 seconds)...\")\n", - " !pip install -q fastapi uvicorn requests\n", - " \n", - " import sys\n", - " sys.path.insert(0, './src')\n", + " !pip install -q fastapi uvicorn requests typing_extensions open_spiel\n", + "\n", + " OPENENV_ROOT = Path(\"/content/OpenEnv\")\n", + " ensure_openenv_paths(OPENENV_ROOT)\n", " print(\"\\n✅ Setup complete! Everything is ready to go! 🎉\")\n", "else:\n", - " import sys\n", - " from pathlib import Path\n", - " sys.path.insert(0, str(Path.cwd().parent))\n", + " OPENENV_ROOT = ensure_openenv_paths()\n", " print(\"✅ Using local OpenEnv installation\")\n", "\n", + "print(f\"📂 OpenEnv root: {OPENENV_ROOT}\")\n", "print(\"\\n🚀 Ready to explore OpenEnv and build amazing things!\")\n", - "print(\"💡 Tip: Run cells top-to-bottom for the best experience.\\n\")" + "print(\"💡 Tip: Re-run this cell after changing kernels, then run top-to-bottom.\\n\")" ] }, { @@ -444,7 +425,7 @@ "\n", "# Part 4: The OpenEnv Pattern 🏗️\n", "\n", - "
\n", + "
\n", "\n", "## Every OpenEnv Environment Has 3 Components:\n", "\n", @@ -454,7 +435,7 @@ "│ (Action, Observation, State)\n", "│\n", "├── 📱 client.py ← What YOU import\n", - "│ (HTTPEnvClient implementation)\n", + "│ (EnvClient implementation)\n", "│\n", "└── 🖥️ server/\n", " ├── environment.py ← Game/simulation logic\n", @@ -469,56 +450,9 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "======================================================================\n", - " 🧩 OPENENV CORE ABSTRACTIONS\n", - "======================================================================\n", - "\n", - "🖥️ SERVER SIDE (runs in Docker):\n", - "\n", - " class Environment(ABC):\n", - " '''Base class for all environment implementations'''\n", - "\n", - " @abstractmethod\n", - " def reset(self) -> Observation:\n", - " '''Start new episode'''\n", - "\n", - " @abstractmethod\n", - " def step(self, action: Action) -> Observation:\n", - " '''Execute action, return observation'''\n", - "\n", - " @property\n", - " def state(self) -> State:\n", - " '''Get episode metadata'''\n", - "\n", - "📱 CLIENT SIDE (your training code):\n", - "\n", - " class EnvClient(ABC):\n", - " '''Base class for HTTP clients'''\n", - "\n", - " def reset(self) -> StepResult:\n", - " # HTTP POST /reset\n", - "\n", - " def step(self, action) -> StepResult:\n", - " # HTTP POST /step\n", - "\n", - " def state(self) -> State:\n", - " # HTTP GET /state\n", - "\n", - "======================================================================\n", - "\n", - "✨ Same interface on both sides - communication via HTTP!\n", - "🎯 You focus on RL, OpenEnv handles the infrastructure.\n", - "\n" - ] - } - ], + "outputs": [], "source": [ "# Import OpenEnv's core abstractions\n", "from openenv.core.env_server import Environment, Action, Observation, State\n", @@ -549,20 +483,20 @@ "📱 CLIENT SIDE (your training code):\n", 
"\n", " class EnvClient(ABC):\n", - " '''Base class for HTTP clients'''\n", + " '''Async base class for OpenEnv clients'''\n", " \n", - " def reset(self) -> StepResult:\n", - " # HTTP POST /reset\n", + " async def reset(self) -> StepResult:\n", + " # await client.reset()\n", " \n", - " def step(self, action) -> StepResult:\n", - " # HTTP POST /step\n", + " async def step(self, action) -> StepResult:\n", + " # await client.step(action)\n", " \n", - " def state(self) -> State:\n", - " # HTTP GET /state\n", + " async def state(self) -> State:\n", + " # await client.state()\n", "\"\"\")\n", "\n", "print(\"=\"*70)\n", - "print(\"\\n✨ Same interface on both sides - communication via HTTP!\")\n", + "print(\"\\n✨ Same interface on both sides - communication over WebSocket!\")\n", "print(\"🎯 You focus on RL, OpenEnv handles the infrastructure.\\n\")" ] }, @@ -574,7 +508,7 @@ "\n", "# Part 5: Example Integration - OpenSpiel 🎮\n", "\n", - "
\n", + "
\n", "\n", "## What is OpenSpiel?\n", "\n", @@ -612,58 +546,27 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "======================================================================\n", - " 🔌 HOW OPENENV WRAPS OPENSPIEL\n", - "======================================================================\n", - "\n", - "class OpenSpielEnv(HTTPEnvClient[OpenSpielAction, OpenSpielObservation]):\n", - "\n", - " def _step_payload(self, action: OpenSpielAction) -> dict:\n", - " '''Convert typed action to JSON for HTTP'''\n", - " return {\n", - " \"action_id\": action.action_id,\n", - " \"game_name\": action.game_name,\n", - " }\n", - "\n", - " def _parse_result(self, payload: dict) -> StepResult:\n", - " '''Parse HTTP JSON response into typed observation'''\n", - " return StepResult(\n", - " observation=OpenSpielObservation(...),\n", - " reward=payload['reward'],\n", - " done=payload['done']\n", - " )\n", - "\n", - "\n", - "──────────────────────────────────────────────────────────────────────\n", - "\n", - "✨ Usage (works for ALL OpenEnv environments):\n", - "\n", - " env = OpenSpielEnv(base_url=\"http://localhost:8000\")\n", - "\n", - " result = env.reset()\n", - " # Returns StepResult[OpenSpielObservation] - Type safe!\n", - "\n", - " result = env.step(OpenSpielAction(action_id=2, game_name=\"catch\"))\n", - " # Type checker knows this is valid!\n", - "\n", - " state = env.state()\n", - " # Returns OpenSpielState\n", - "\n", - "──────────────────────────────────────────────────────────────────────\n", - "\n", - "🎯 This pattern works for ANY environment you want to wrap!\n", - "\n" - ] - } - ], + "outputs": [], "source": [ + "if \"ensure_openenv_paths\" not in globals():\n", + " import sys\n", + " from pathlib import Path\n", + "\n", + " def ensure_openenv_paths(start=None):\n", + " start = Path(start or Path.cwd()).resolve()\n", + " for directory in 
(start, *start.parents):\n", + " if (directory / \"src\" / \"openenv\").is_dir() and (directory / \"envs\").is_dir():\n", + " for path in (directory / \"src\", directory):\n", + " path_str = str(path)\n", + " if path_str not in sys.path:\n", + " sys.path.insert(0, path_str)\n", + " return directory\n", + " raise RuntimeError(\"Run the Part 3 setup cell first.\")\n", + "\n", + "OPENENV_ROOT = ensure_openenv_paths()\n", + "\n", "from envs.openspiel_env.client import OpenSpielEnv\n", "\n", "print(\"=\"*70)\n", @@ -671,22 +574,26 @@ "print(\"=\"*70)\n", "\n", "print(\"\"\"\n", - "class OpenSpielEnv(HTTPEnvClient[OpenSpielAction, OpenSpielObservation]):\n", + "class OpenSpielEnv(EnvClient[OpenSpielAction, OpenSpielObservation, OpenSpielState]):\n", " \n", " def _step_payload(self, action: OpenSpielAction) -> dict:\n", - " '''Convert typed action to JSON for HTTP'''\n", + " '''Convert typed action to JSON for the step message'''\n", " return {\n", " \"action_id\": action.action_id,\n", " \"game_name\": action.game_name,\n", " }\n", " \n", " def _parse_result(self, payload: dict) -> StepResult:\n", - " '''Parse HTTP JSON response into typed observation'''\n", + " '''Parse the server response into a typed observation'''\n", " return StepResult(\n", " observation=OpenSpielObservation(...),\n", " reward=payload['reward'],\n", " done=payload['done']\n", " )\n", + " \n", + " def _parse_state(self, payload: dict) -> OpenSpielState:\n", + " '''Parse the server response into a typed state'''\n", + " return OpenSpielState(...)\n", "\n", "\"\"\")\n", "\n", @@ -695,13 +602,13 @@ "print(\"\"\"\n", " env = OpenSpielEnv(base_url=\"http://localhost:8000\")\n", " \n", - " result = env.reset()\n", + " result = await env.reset()\n", " # Returns StepResult[OpenSpielObservation] - Type safe!\n", " \n", - " result = env.step(OpenSpielAction(action_id=2, game_name=\"catch\"))\n", + " result = await env.step(OpenSpielAction(action_id=2, game_name=\"catch\"))\n", " # Type checker knows this is 
valid!\n", " \n", - " state = env.state()\n", + " state = await env.state()\n", " # Returns OpenSpielState\n", "\"\"\")\n", "\n", @@ -711,56 +618,9 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "======================================================================\n", - " 🎮 OPENSPIEL INTEGRATION - TYPE-SAFE MODELS\n", - "======================================================================\n", - "\n", - "📤 OpenSpielAction (what you send):\n", - " ────────────────────────────────────────────────────────────────\n", - " • metadata : typing.Dict[str, typing.Any]\n", - " • action_id : \n", - " • game_name : \n", - " • game_params : typing.Dict[str, typing.Any]\n", - "\n", - "📥 OpenSpielObservation (what you receive):\n", - " ────────────────────────────────────────────────────────────────\n", - " • done : \n", - " • reward : bool | int | float | None\n", - " • metadata : typing.Dict[str, typing.Any]\n", - " • info_state : typing.List[float]\n", - " • legal_actions : typing.List[int]\n", - " • game_phase : \n", - " • current_player_id : \n", - " • opponent_last_action : typing.Optional[int]\n", - "\n", - "📊 OpenSpielState (episode metadata):\n", - " ────────────────────────────────────────────────────────────────\n", - " • episode_id : typing.Optional[str]\n", - " • step_count : \n", - " • game_name : \n", - " • agent_player : \n", - " • opponent_policy : \n", - " • game_params : typing.Dict[str, typing.Any]\n", - " • num_players : \n", - "\n", - "======================================================================\n", - "\n", - "💡 Type safety means:\n", - " ✅ Your IDE autocompletes these fields\n", - " ✅ Typos are caught before running\n", - " ✅ Refactoring is safe\n", - " ✅ Self-documenting code\n", - "\n" - ] - } - ], + "outputs": [], "source": [ "from envs.openspiel_env.models import (\n", " OpenSpielAction,\n", @@ -803,13 +663,13 @@ 
"\n", "
\n", "\n", - "The client **inherits from HTTPEnvClient** and implements 3 methods:\n", + "The client **inherits from `EnvClient`** and implements 3 methods:\n", "\n", "1. `_step_payload()` - Convert action → JSON\n", "2. `_parse_result()` - Parse JSON → typed observation \n", "3. `_parse_state()` - Parse JSON → state\n", "\n", - "That's it! The base class handles all HTTP communication.\n", + "That's it! The base class handles all the async WebSocket communication.\n", "\n", "
" ] @@ -901,45 +761,27 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "🎮 ================================================================ 🎮\n", - " ✅ Importing Real OpenSpiel Environment!\n", - "🎮 ================================================================ 🎮\n", - "\n", - "📦 What we just imported:\n", - " • OpenSpielEnv - HTTP client for OpenSpiel games\n", - " • OpenSpielAction - Type-safe actions\n", - " • OpenSpielObservation - Type-safe observations\n", - " • OpenSpielState - Episode metadata\n", - "\n", - "📋 OpenSpielObservation fields:\n", - " ────────────────────────────────────────────────────────────\n", - " • done : \n", - " • reward : bool | int | float | None\n", - " • metadata : typing.Dict[str, typing.Any]\n", - " • info_state : typing.List[float]\n", - " • legal_actions : typing.List[int]\n", - " • game_phase : \n", - " • current_player_id : \n", - " • opponent_last_action : typing.Optional[int]\n", - "\n", - "======================================================================\n", - "\n", - "💡 This is REAL OpenEnv code - used in production!\n", - " • Wraps 6 OpenSpiel games (Catch, Tic-Tac-Toe, Poker, etc.)\n", - " • Type-safe actions and observations\n", - " • Works via HTTP (we'll see that next!)\n", - "\n" - ] - } - ], + "outputs": [], "source": [ + "if \"ensure_openenv_paths\" not in globals():\n", + " import sys\n", + " from pathlib import Path\n", + "\n", + " def ensure_openenv_paths(start=None):\n", + " start = Path(start or Path.cwd()).resolve()\n", + " for directory in (start, *start.parents):\n", + " if (directory / \"src\" / \"openenv\").is_dir() and (directory / \"envs\").is_dir():\n", + " for path in (directory / \"src\", directory):\n", + " path_str = str(path)\n", + " if path_str not in sys.path:\n", + " sys.path.insert(0, path_str)\n", + " return directory\n", + " raise RuntimeError(\"Run the Part 3 
setup cell first.\")\n", + "\n", + "OPENENV_ROOT = ensure_openenv_paths()\n", + "\n", "from envs.openspiel_env import OpenSpielEnv\n", "from envs.openspiel_env.models import (\n", " OpenSpielAction,\n", @@ -952,7 +794,7 @@ "print(\"🎮 \" + \"=\" * 64 + \" 🎮\\n\")\n", "\n", "print(\"📦 What we just imported:\")\n", - "print(\" • OpenSpielEnv - HTTP client for OpenSpiel games\")\n", + "print(\" • OpenSpielEnv - EnvClient for OpenSpiel games\")\n", "print(\" • OpenSpielAction - Type-safe actions\")\n", "print(\" • OpenSpielObservation - Type-safe observations\")\n", "print(\" • OpenSpielState - Episode metadata\\n\")\n", @@ -966,45 +808,32 @@ "print(\"\\n💡 This is REAL OpenEnv code - used in production!\")\n", "print(\" • Wraps 6 OpenSpiel games (Catch, Tic-Tac-Toe, Poker, etc.)\")\n", "print(\" • Type-safe actions and observations\")\n", - "print(\" • Works via HTTP (we'll see that next!)\\n\")" + "print(\" • Talks to the server over WebSocket (we'll see that next!)\\n\")" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "🚀 ================================================================ 🚀\n", - " Starting OpenSpiel Server (Catch Game)\n", - "🚀 ================================================================ 🚀\n", - "\n", - "✅ OpenSpiel is installed!\n", - "\n", - "⚡ Starting FastAPI server for OpenSpiel Catch...\n", - " (This uses REAL OpenEnv + OpenSpiel integration)\n", - "\n", - "⏳ Waiting for server to start...\n", - "\n", - "✅ OpenSpiel server is running!\n", - "🌐 Server URL: http://localhost:8000\n", - "📍 Endpoints available:\n", - " • POST /reset\n", - " • POST /step\n", - " • GET /state\n", - "\n", - "🎯 This is REAL OpenEnv + OpenSpiel in action!\n", - " • Running actual OpenSpiel Catch game\n", - " • Exposed via FastAPI HTTP server\n", - " • Using OpenEnv's standard interface\n", - "\n" - ] - } - ], + "outputs": [], "source": [ + "if 
\"ensure_openenv_paths\" not in globals():\n", + " import sys\n", + " from pathlib import Path\n", + "\n", + " def ensure_openenv_paths(start=None):\n", + " start = Path(start or Path.cwd()).resolve()\n", + " for directory in (start, *start.parents):\n", + " if (directory / \"src\" / \"openenv\").is_dir() and (directory / \"envs\").is_dir():\n", + " for path in (directory / \"src\", directory):\n", + " path_str = str(path)\n", + " if path_str not in sys.path:\n", + " sys.path.insert(0, path_str)\n", + " return directory\n", + " raise RuntimeError(\"Run the Part 3 setup cell first.\")\n", + "\n", + "OPENENV_ROOT = ensure_openenv_paths()\n", + "\n", "import subprocess\n", "import time\n", "import sys\n", @@ -1028,12 +857,8 @@ "print(\"⚡ Starting FastAPI server for OpenSpiel Catch...\")\n", "print(\" (This uses REAL OpenEnv + OpenSpiel integration)\\n\")\n", "\n", - "# Determine the correct path\n", - "if IN_COLAB:\n", - " work_dir = \"/content/OpenEnv\"\n", - "else:\n", - " from pathlib import Path\n", - " work_dir = str(Path.cwd().parent.absolute())\n", + "# Use repo root from the setup cell\n", + "work_dir = str(OPENENV_ROOT)\n", "\n", "server_process = subprocess.Popen(\n", " [sys.executable, \"-m\", \"uvicorn\",\n", @@ -1041,7 +866,7 @@ " \"--host\", \"0.0.0.0\",\n", " \"--port\", \"8000\"],\n", " env={**os.environ,\n", - " \"PYTHONPATH\": f\"{work_dir}/src\",\n", + " \"PYTHONPATH\": f\"{work_dir}/src:{work_dir}\",\n", " \"OPENSPIEL_GAME\": \"catch\",\n", " \"OPENSPIEL_AGENT_PLAYER\": \"0\",\n", " \"OPENSPIEL_OPPONENT_POLICY\": \"random\"},\n", @@ -1061,13 +886,12 @@ " response = requests.get('http://localhost:8000/health', timeout=2)\n", " print(\"\\n✅ OpenSpiel server is running!\")\n", " print(\"🌐 Server URL: http://localhost:8000\")\n", - " print(\"📍 Endpoints available:\")\n", - " print(\" • POST /reset\")\n", - " print(\" • POST /step\")\n", - " print(\" • GET /state\")\n", + " print(\"📍 The client connects over WebSocket:\")\n", + " print(\" • 
ws://localhost:8000/ws (persistent session)\")\n", + " print(\" • GET /health (HTTP health check)\")\n", " print(\"\\n🎯 This is REAL OpenEnv + OpenSpiel in action!\")\n", " print(\" • Running actual OpenSpiel Catch game\")\n", - " print(\" • Exposed via FastAPI HTTP server\")\n", + " print(\" • Exposed via FastAPI (WebSocket protocol)\")\n", " print(\" • Using OpenEnv's standard interface\\n\")\n", "except Exception as e:\n", " print(f\"\\n❌ Server failed to start: {e}\")\n", @@ -1084,100 +908,45 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "📱 ================================================================ 📱\n", - " Connecting to OpenSpiel Server via HTTP\n", - "📱 ================================================================ 📱\n", - "\n", - "✅ Client created!\n", - "\n", - "💡 What just happened:\n", - " • OpenSpielEnv is an HTTPEnvClient subclass\n", - " • It knows how to talk to OpenSpiel servers\n", - " • All communication is type-safe and over HTTP\n", - " • Same client works for ALL OpenSpiel games!\n", - "\n" - ] - } - ], + "outputs": [], "source": [ "print(\"📱 \" + \"=\"*64 + \" 📱\")\n", - "print(\" Connecting to OpenSpiel Server via HTTP\")\n", + "print(\" Connecting to OpenSpiel Server over WebSocket\")\n", "print(\"📱 \" + \"=\"*64 + \" 📱\\n\")\n", "\n", - "# Create HTTP client for OpenSpiel\n", + "# Create the OpenSpiel client\n", "client = OpenSpielEnv(base_url=\"http://localhost:8000\")\n", "\n", "print(\"✅ Client created!\")\n", "print(\"\\n💡 What just happened:\")\n", - "print(\" • OpenSpielEnv is an HTTPEnvClient subclass\")\n", + "print(\" • OpenSpielEnv is an EnvClient subclass (async)\")\n", "print(\" • It knows how to talk to OpenSpiel servers\")\n", - "print(\" • All communication is type-safe and over HTTP\")\n", - "print(\" • Same client works for ALL OpenSpiel games!\\n\")" + "print(\" • All communication is 
type-safe and over WebSocket\")\n", + "print(\" • Same client works for ALL OpenSpiel games!\\n\")\n", + "\n", + "# 📝 EnvClient uses async WebSocket I/O, so notebook calls are awaited.\n", + "# This notebook relies on Jupyter's top-level `await`. Outside a\n", + "# notebook, plain scripts can call reset()/step()/state() directly;\n", + "# those calls block automatically.\n" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "🎮 ================================================================ 🎮\n", - " Testing Connection - Playing One Step\n", - "🎮 ================================================================ 🎮\n", - "\n", - "📤 Calling client.reset()...\n", - " Under the hood: HTTP POST to http://localhost:8000/reset\n", - "\n", - "📥 Received OpenSpielObservation:\n", - " • info_state: [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]... (first 10 values)\n", - " • number of info_state: 50\n", - " • legal_actions: [0, 1, 2]\n", - " • game_phase: initial\n", - " • done: False\n", - "\n", - "📤 Calling client.step(OpenSpielAction(action_id=1, game_name='catch'))...\n", - " Under the hood: HTTP POST to http://localhost:8000/step\n", - "\n", - "📥 Received response:\n", - " • Reward: 0.0\n", - " • Done: False\n", - " • legal_actions: [0, 1, 2]\n", - "\n", - "📊 Episode state:\n", - " • episode_id: aec2e599-5633-45d6-a7a3-30f46e8796ab\n", - " • step_count: 1\n", - " • game_name: catch\n", - "\n", - "======================================================================\n", - "\n", - "🎉 IT WORKS! 
We're using REAL OpenSpiel via HTTP!\n", - " ✅ Type-safe communication\n", - " ✅ Same interface as any OpenEnv environment\n", - " ✅ Production-ready architecture\n", - "\n" - ] - } - ], + "outputs": [], "source": [ "print(\"🎮 \" + \"=\"*64 + \" 🎮\")\n", "print(\" Testing Connection - Playing One Step\")\n", "print(\"🎮 \" + \"=\"*64 + \" 🎮\\n\")\n", "\n", - "# Reset the environment (HTTP POST /reset)\n", + "# Reset the environment (reset message over WebSocket)\n", "print(\"📤 Calling client.reset()...\")\n", - "print(\" Under the hood: HTTP POST to http://localhost:8000/reset\\n\")\n", + "print(\" Under the hood: a reset message over the WebSocket session\\n\")\n", "\n", - "result = client.reset()\n", + "result = await client.reset()\n", "\n", "print(\"📥 Received OpenSpielObservation:\")\n", "print(f\" • info_state: {result.observation.info_state[:10]}... (first 10 values)\")\n", @@ -1186,27 +955,27 @@ "print(f\" • game_phase: {result.observation.game_phase}\")\n", "print(f\" • done: {result.done}\")\n", "\n", - "# Take an action (HTTP POST /step)\n", + "# Take an action (step message over WebSocket)\n", "print(\"\\n📤 Calling client.step(OpenSpielAction(action_id=1, game_name=\\'catch\\'))...\")\n", - "print(\" Under the hood: HTTP POST to http://localhost:8000/step\\n\")\n", + "print(\" Under the hood: a step message over the WebSocket session\\n\")\n", "\n", "action = OpenSpielAction(action_id=1, game_name=\"catch\") # STAY\n", - "result = client.step(action)\n", + "result = await client.step(action)\n", "\n", "print(\"📥 Received response:\")\n", "print(f\" • Reward: {result.reward}\")\n", "print(f\" • Done: {result.done}\")\n", "print(f\" • legal_actions: {result.observation.legal_actions}\")\n", "\n", - "# Get state (HTTP GET /state)\n", - "state = client.state()\n", + "# Get state (state message over WebSocket)\n", + "state = await client.state()\n", "print(f\"\\n📊 Episode state:\")\n", "print(f\" • episode_id: {state.episode_id}\")\n", "print(f\" • 
step_count: {state.step_count}\")\n", "print(f\" • game_name: {state.game_name}\")\n", "\n", "print(\"\\n\" + \"=\"*70)\n", - "print(\"\\n🎉 IT WORKS! We\\'re using REAL OpenSpiel via HTTP!\")\n", + "print(\"\\n🎉 IT WORKS! We\\'re using REAL OpenSpiel over WebSocket!\")\n", "print(\" ✅ Type-safe communication\")\n", "print(\" ✅ Same interface as any OpenEnv environment\")\n", "print(\" ✅ Production-ready architecture\\n\")" @@ -1369,13 +1138,14 @@ "metadata": {}, "outputs": [], "source": [ - "import time\n", + "import asyncio\n", + "\n", "\n", - "def run_episode(env, policy, visualize=True, delay=0.3):\n", + "async def run_episode(env, policy, visualize=True, delay=0.3):\n", " \"\"\"Run one episode with a policy against OpenSpiel environment.\"\"\"\n", "\n", " # RESET\n", - " result = env.reset()\n", + " result = await env.reset()\n", " obs = result.observation\n", "\n", " if visualize:\n", @@ -1383,7 +1153,7 @@ " print(f\" 🎮 {policy.name}\")\n", " print(f\" 🎲 Playing against OpenSpiel Catch\")\n", " print('='*60 + '\\n')\n", - " time.sleep(delay)\n", + " await asyncio.sleep(delay)\n", "\n", " total_reward = 0\n", " step = 0\n", @@ -1394,9 +1164,9 @@ " # 1. Policy chooses action\n", " action_id = policy.select_action(obs)\n", "\n", - " # 2. Environment executes (via HTTP!)\n", + " # 2. Environment executes (via WebSocket!)\n", " action = OpenSpielAction(action_id=action_id, game_name=\"catch\")\n", - " result = env.step(action)\n", + " result = await env.step(action)\n", " obs = result.observation\n", "\n", " # 3. 
Collect reward\n", @@ -1405,7 +1175,7 @@ "\n", " if visualize:\n", " print(f\"📍 Step {step + 1}: {action_names[action_id]} → Reward: {result.reward}\")\n", - " time.sleep(delay)\n", + " await asyncio.sleep(delay)\n", "\n", " step += 1\n", "\n", @@ -1424,10 +1194,10 @@ "\n", "# Demo: Watch Smart Policy in action\n", "policy = SmartPolicy()\n", - "run_episode(client, policy, visualize=True, delay=0.5)\n", + "await run_episode(client, policy, visualize=True, delay=0.5)\n", "\n", "print(\"\\n💡 You just watched REAL OpenSpiel Catch being played!\")\n", - "print(\" • Every action was an HTTP call\")\n", + "print(\" • Every action was a WebSocket message to the server\")\n", "print(\" • Game logic runs in the server\")\n", "print(\" • Client only sends actions and receives observations\\n\")" ] @@ -1444,7 +1214,7 @@ "\n", "Let's run **50 episodes** for each policy against **REAL OpenSpiel** and see who wins!\n", "\n", - "This is production code - every action is an HTTP call to the OpenSpiel server!\n", + "This is production code - every action is a WebSocket message to the OpenSpiel server!\n", "\n", "
" ] @@ -1455,7 +1225,7 @@ "metadata": {}, "outputs": [], "source": [ - "def evaluate_policies(env, num_episodes=50):\n", + "async def evaluate_policies(env, num_episodes=50):\n", " \"\"\"Compare all policies over many episodes using real OpenSpiel.\"\"\"\n", " policies = [\n", " RandomPolicy(),\n", @@ -1472,8 +1242,9 @@ " results = []\n", " for policy in policies:\n", " print(f\"⚡ Testing {policy.name}...\", end=\" \")\n", - " successes = sum(run_episode(env, policy, visualize=False)\n", - " for _ in range(num_episodes))\n", + " successes = 0\n", + " for _ in range(num_episodes):\n", + " successes += await run_episode(env, policy, visualize=False)\n", " success_rate = (successes / num_episodes) * 100\n", " results.append((policy.name, success_rate, successes))\n", " print(f\"✓ Done!\")\n", @@ -1501,13 +1272,13 @@ " print(\" • Learning (~85%): Improves over time 📈\")\n", " print(\"\\n🎓 This is Reinforcement Learning + OpenEnv in action:\")\n", " print(\" 1. We USED existing OpenSpiel environment (didn\\'t build it)\")\n", - " print(\" 2. Type-safe communication over HTTP\")\n", + " print(\" 2. Type-safe communication over WebSocket\")\n", " print(\" 3. Same code works for ANY OpenSpiel game\")\n", " print(\" 4. Production-ready architecture\\n\")\n", "\n", "# Run the epic competition!\n", "print(\"🎮 Starting the showdown against REAL OpenSpiel...\\n\")\n", - "evaluate_policies(client, num_episodes=50)" + "await evaluate_policies(client, num_episodes=50)" ] }, { @@ -1540,7 +1311,7 @@ "\n", "\n", "Connected\n", - "HTTP client to server\n", + "Async client over WebSocket\n", "\n", "\n", "Played\n", @@ -1548,7 +1319,7 @@ "\n", "\n", "\n", - "**🎯 This is production code!** Every action was an HTTP call to a real OpenSpiel environment.\n", + "**🎯 This is production code!** Every action was a WebSocket message to a real OpenSpiel environment.\n", "\n", "
\n", "\n", @@ -1593,7 +1364,7 @@ " \"--host\", \"0.0.0.0\",\n", " \"--port\", \"8000\"],\n", " env={**os.environ,\n", - " \"PYTHONPATH\": f\"{work_dir}/src\",\n", + " \"PYTHONPATH\": f\"{work_dir}/src:{work_dir}\",\n", " \"OPENSPIEL_GAME\": \"tic_tac_toe\", # Changed!\n", " \"OPENSPIEL_AGENT_PLAYER\": \"0\",\n", " \"OPENSPIEL_OPPONENT_POLICY\": \"random\"},\n", @@ -1602,7 +1373,7 @@ "\n", "# Same client works!\n", "client = OpenSpielEnv(base_url=\"http://localhost:8000\")\n", - "result = client.reset() # Now playing Tic-Tac-Toe!\n", + "result = await client.reset() # in a notebook; plain scripts can call client.reset() directly\n", "```\n", "\n", "**💡 Key Insight**: You don't rebuild anything - you just USE different games with the same client!\n" @@ -1627,27 +1398,28 @@ "\n", "### Step 1: Define Types (`models.py`)\n", "\n", + "Actions, observations, and state are **Pydantic models**. The base classes already\n", + "provide the common fields — `Observation` has `done`, `reward`, `metadata`, and\n", + "`State` has `episode_id`, `step_count` — so you only add what's specific to your env.\n", + "\n", "```python\n", - "from dataclasses import dataclass\n", + "from typing import List\n", "from openenv.core.env_server import Action, Observation, State\n", + "from pydantic import Field\n", + "\n", "\n", - "@dataclass\n", "class YourAction(Action):\n", - " action_value: int\n", - " # Add your action fields\n", + " action_value: int # add your action fields\n", + "\n", "\n", - "@dataclass\n", "class YourObservation(Observation):\n", - " state_data: List[float]\n", - " done: bool\n", - " reward: float\n", - " # Add your observation fields\n", + " # `done`, `reward`, and `metadata` are inherited from Observation\n", + " state_data: List[float] # add your observation fields\n", + "\n", "\n", - "@dataclass\n", "class YourState(State):\n", - " episode_id: str\n", - " step_count: int\n", - " # Add your state fields\n", + " # `episode_id` and `step_count` are inherited from 
State\n", + " score: int = 0 # add your state fields\n", "```\n", "\n", "### Step 2: Implement Environment (`server/environment.py`)\n", @@ -1655,53 +1427,63 @@ "```python\n", "from openenv.core.env_server import Environment\n", "\n", + "\n", "class YourEnvironment(Environment):\n", - " def reset(self) -> Observation:\n", + " def reset(self, seed=None, episode_id=None, **kwargs) -> YourObservation:\n", " # Initialize your game/simulation\n", " return YourObservation(...)\n", - " \n", - " def step(self, action: Action) -> Observation:\n", + "\n", + " def step(self, action, **kwargs) -> YourObservation:\n", " # Execute action, update state\n", " return YourObservation(...)\n", - " \n", + "\n", " @property\n", - " def state(self) -> State:\n", + " def state(self) -> YourState:\n", " return self._state\n", "```\n", "\n", "### Step 3: Create Client (`client.py`)\n", "\n", + "The client subclasses `EnvClient` and implements 3 hooks. Callers\n", + "`await` `reset()`/`step()`/`state()` in async contexts, or call them directly from synchronous scripts.\n", + "\n", "```python\n", - "from openenv.core.http_env_client import HTTPEnvClient\n", - "from openenv.core.types import StepResult\n", + "from openenv.core.env_client import EnvClient\n", + "from openenv.core.client_types import StepResult\n", "\n", - "class YourEnv(HTTPEnvClient[YourAction, YourObservation]):\n", + "\n", + "class YourEnv(EnvClient[YourAction, YourObservation, YourState]):\n", " def _step_payload(self, action: YourAction) -> dict:\n", - " \"\"\"Convert action to JSON\"\"\"\n", + " \"\"\"Convert action to the JSON step message.\"\"\"\n", " return {\"action_value\": action.action_value}\n", - " \n", + "\n", " def _parse_result(self, payload: dict) -> StepResult:\n", - " \"\"\"Parse JSON to observation\"\"\"\n", + " \"\"\"Parse the server response into a typed observation.\"\"\"\n", " return StepResult(\n", " observation=YourObservation(...),\n", - " reward=payload['reward'],\n", - " done=payload['done']\n", 
+ " reward=payload[\"reward\"],\n", + " done=payload[\"done\"],\n", " )\n", - " \n", + "\n", " def _parse_state(self, payload: dict) -> YourState:\n", " return YourState(...)\n", "```\n", "\n", "### Step 4: Create Server (`server/app.py`)\n", "\n", + "Pass the environment **class** (a factory) plus the action and observation types.\n", + "OpenEnv builds the WebSocket + HTTP endpoints for you.\n", + "\n", "```python\n", - "from openenv.core.env_server import create_fastapi_app\n", + "from openenv.core.env_server import create_app\n", "from .your_environment import YourEnvironment\n", "\n", - "env = YourEnvironment()\n", - "app = create_fastapi_app(env)\n", - "\n", - "# That's it! OpenEnv creates all endpoints for you.\n", + "app = create_app(\n", + " YourEnvironment,\n", + " YourAction,\n", + " YourObservation,\n", + " env_name=\"your_env\",\n", + ")\n", "```\n", "\n", "### Step 5: Dockerize (`server/Dockerfile`)\n", @@ -1739,7 +1521,7 @@ "\n", "**💡 Study these to understand the patterns!**\n", "\n", - "
" + "
\n" ] }, { @@ -1775,7 +1557,7 @@ "✅ **OpenEnv Architecture**\n", "- Client-server separation\n", "- Type-safe contracts\n", - "- HTTP communication layer\n", + "- WebSocket communication layer\n", "\n", "✅ **Production Patterns**\n", "- Docker isolation\n", @@ -1795,7 +1577,7 @@ "✅ **Building Environments**\n", "- Define type-safe models\n", "- Implement Environment class\n", - "- Create HTTPEnvClient\n", + "- Subclass EnvClient\n", "\n", "✅ **Testing & Debugging**\n", "- Compare policies\n", @@ -1900,15 +1682,15 @@ "### 🎓 Community & Support\n", "\n", "**Supported by amazing organizations:**\n", - "- 🔥 Meta PyTorch\n", - "- 🌟 Reflection AI\n", - "- ⚡ Unsloth AI\n", - "- ☁️ Modal\n", - "- 🧠 Prime Intellect\n", - "- 🟢 Nvidia\n", - "- 💼 Mercor\n", - "- 🚀 Fleet AI\n", - "- 🤗 Hugging Face\n", + "- 🔥 Meta PyTorch\n", + "- 🌟 Reflection AI\n", + "- ⚡ Unsloth AI\n", + "- ☁️ Modal\n", + "- 🧠 Prime Intellect\n", + "- 🟢 Nvidia\n", + "- 💼 Mercor\n", + "- 🚀 Fleet AI\n", + "- 🤗 Hugging Face\n", "- 🚀 And many more!\n", "\n", "**License**: BSD 3-Clause License\n", @@ -1931,7 +1713,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -1945,7 +1727,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.12" + "version": "3.9.6" } }, "nbformat": 4, diff --git a/tutorial/examples/unsloth_2048.ipynb b/tutorial/examples/unsloth_2048.ipynb index 1070da34c..53bb610bf 100644 --- a/tutorial/examples/unsloth_2048.ipynb +++ b/tutorial/examples/unsloth_2048.ipynb @@ -551,6 +551,8 @@ "OPENSPIEL_URL = \"https://openenv-openspiel-env.hf.space\"\n", "# For local: OPENSPIEL_URL = \"http://localhost:8000\"\n", "\n", + "# GRPOTrainer calls rollout_func synchronously; EnvClient calls block\n", + "# automatically in synchronous code.\n", "env = OpenSpielEnv(base_url=OPENSPIEL_URL)" ] },