Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions src/google/adk/cli/cli_tools_click.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from datetime import datetime
import functools
import hashlib
import importlib
import json
import logging
import os
Expand Down Expand Up @@ -1844,6 +1845,18 @@ async def _lifespan(app: FastAPI):
server.run()


def _load_lifespan_handler(lifespan_path: str) -> Any:
"""Dynamically import a lifespan handler from a string path."""
try:
module_name, func_name = lifespan_path.rsplit(".", 1)
module = importlib.import_module(module_name)
return getattr(module, func_name)
except Exception as e:
raise click.ClickException(
f"Failed to load lifespan handler '{lifespan_path}': {e}"
) from e


@main.command("api_server")
@feature_options()
# The directory of agents, where each subdirectory is a single agent.
Expand All @@ -1866,6 +1879,15 @@ async def _lifespan(app: FastAPI):
"Automatically create a session if it doesn't exist when calling /run."
),
)
@click.option(
"--lifespan",
type=str,
default=None,
help=(
"Optional. The import path to a lifespan context manager (e.g.,"
" 'path.to.module.lifespan_handler')."
),
)
def cli_api_server(
agents_dir: str,
eval_storage_uri: Optional[str] = None,
Expand All @@ -1888,6 +1910,7 @@ def cli_api_server(
extra_plugins: Optional[list[str]] = None,
auto_create_session: bool = False,
trigger_sources: Optional[list[str]] = None,
lifespan: Optional[str] = None,
):
"""Starts a FastAPI server for agents.

Expand All @@ -1902,6 +1925,11 @@ def cli_api_server(
artifact_service_uri = artifact_service_uri or artifact_storage_uri
logs.setup_adk_logger(getattr(logging, log_level.upper()))

if agents_dir and agents_dir not in sys.path:
sys.path.insert(0, agents_dir)

lifespan_handler = _load_lifespan_handler(lifespan) if lifespan else None

config = uvicorn.Config(
get_fast_api_app(
agents_dir=agents_dir,
Expand All @@ -1922,6 +1950,7 @@ def cli_api_server(
extra_plugins=extra_plugins,
auto_create_session=auto_create_session,
trigger_sources=trigger_sources,
lifespan=lifespan_handler,
),
host=host,
port=port,
Expand Down
32 changes: 32 additions & 0 deletions tests/unittests/cli/utils/test_cli_tools_click.py
Original file line number Diff line number Diff line change
Expand Up @@ -1161,6 +1161,38 @@ def test_cli_api_server_invokes_uvicorn(
assert _patch_uvicorn.calls, "uvicorn.Server.run must be called"


def test_cli_api_server_passes_lifespan(
tmp_path: Path, _patch_uvicorn: _Recorder, monkeypatch: pytest.MonkeyPatch
) -> None:
"""`adk api_server` should pass loaded lifespan handler to get_fast_api_app."""
agents_dir = tmp_path / "agents_api_lifespan"
agents_dir.mkdir()
lifespan_file = agents_dir / "dummy_lifespan.py"
lifespan_file.write_text("""
from contextlib import asynccontextmanager
@asynccontextmanager
async def dummy_handler(app):
yield
""")
mock_get_app = _Recorder()
monkeypatch.setattr(cli_tools_click, "get_fast_api_app", mock_get_app)
runner = CliRunner()
result = runner.invoke(
cli_tools_click.main,
[
"api_server",
str(agents_dir),
"--lifespan",
"dummy_lifespan.dummy_handler",
],
)
assert result.exit_code == 0, f"Output: {result.output}"
assert mock_get_app.calls
called_kwargs = mock_get_app.calls[0][1]
assert called_kwargs.get("lifespan") is not None
assert called_kwargs.get("lifespan").__name__ == "dummy_handler"


def test_cli_web_passes_service_uris(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch, _patch_uvicorn: _Recorder
) -> None:
Expand Down