Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 62 additions & 6 deletions src/strands_tools/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
import logging
import os
import random
import re
import time
import traceback
import uuid
Expand Down Expand Up @@ -129,6 +130,43 @@
WORKFLOW_DIR = Path(os.getenv("STRANDS_WORKFLOW_DIR", Path.home() / ".strands" / "workflows"))
os.makedirs(WORKFLOW_DIR, exist_ok=True)

# Allowed characters for a workflow_id. Workflow ids map directly to file names in
# WORKFLOW_DIR, so they are restricted to a conservative set and a bounded length.
_SAFE_WORKFLOW_ID = re.compile(r"^[A-Za-z0-9._-]{1,128}$")


def is_valid_workflow_id(workflow_id: str) -> bool:
"""Check that a workflow_id is a safe, single file name within WORKFLOW_DIR.

A valid id matches the allowlist pattern and is not a relative directory
reference ("." or ".."). This keeps workflow ids confined to file names
inside WORKFLOW_DIR rather than paths that point elsewhere.
"""
if not isinstance(workflow_id, str):
return False
if workflow_id in (".", ".."):
return False
return bool(_SAFE_WORKFLOW_ID.match(workflow_id))


def resolve_workflow_path(workflow_id: str) -> Path:
"""Resolve the JSON file path for a workflow_id and confine it to WORKFLOW_DIR.

Returns the resolved path of ``WORKFLOW_DIR/<workflow_id>.json`` and raises
ValueError if the result would fall outside WORKFLOW_DIR. This is the check
that provides traversal safety at every filesystem sink, independent of the
create-time allowlist at the tool boundary. The explicit "." and ".." reject
is cheap and is never a legitimate id.
"""
if not isinstance(workflow_id, str) or workflow_id in (".", ".."):
raise ValueError(f"Invalid workflow_id: {workflow_id!r}")
workflow_root = WORKFLOW_DIR.resolve()
file_path = (workflow_root / f"{workflow_id}.json").resolve()
if not file_path.is_relative_to(workflow_root):
raise ValueError(f"Resolved workflow path escapes the workflow directory: {workflow_id!r}")
return file_path


# Default thread pool settings
MIN_THREADS = int(os.getenv("STRANDS_WORKFLOW_MIN_THREADS", "2"))
MAX_THREADS = int(os.getenv("STRANDS_WORKFLOW_MAX_THREADS", "8"))
Expand Down Expand Up @@ -282,7 +320,7 @@ def _load_all_workflows(self):
def load_workflow(self, workflow_id: str) -> Optional[Dict]:
"""Load a workflow from its JSON file."""
try:
file_path = WORKFLOW_DIR / f"{workflow_id}.json"
file_path = resolve_workflow_path(workflow_id)
if file_path.exists():
with open(file_path, "r") as f:
self._workflows[workflow_id] = json.load(f)
Expand All @@ -298,7 +336,7 @@ def store_workflow(self, workflow_id: str, workflow_data: Dict) -> Dict:
self._workflows[workflow_id] = workflow_data

# Store to file
file_path = WORKFLOW_DIR / f"{workflow_id}.json"
file_path = resolve_workflow_path(workflow_id)
with open(file_path, "w") as f:
json.dump(workflow_data, f, indent=2)

Expand Down Expand Up @@ -891,7 +929,7 @@ def delete_workflow(self, workflow_id: str) -> Dict:
del self._workflows[workflow_id]

# Remove file if exists
file_path = WORKFLOW_DIR / f"{workflow_id}.json"
file_path = resolve_workflow_path(workflow_id)
if file_path.exists():
file_path.unlink()
return {
Expand Down Expand Up @@ -1118,15 +1156,33 @@ def workflow(

# Route to appropriate handler
if action == "create":
# Auto-generate an id when one is not supplied, then apply the strict
# charset allowlist to new ids only. The allowlist is intentionally
# scoped to create so pre-existing workflows whose ids contain spaces,
# unicode, or are longer than 128 characters can still be read and
# deleted. Traversal safety for every action comes from
# resolve_workflow_path, which confines each filesystem sink to
# WORKFLOW_DIR.
if not workflow_id:
workflow_id = str(uuid.uuid4())
if not is_valid_workflow_id(workflow_id):
return {
"status": "error",
"content": [
{
"text": (
"❌ Invalid workflow_id. Use 1-128 characters limited to "
"letters, digits, '.', '_' and '-'."
)
}
],
}
if not tasks:
return {
"status": "error",
"content": [{"text": "❌ Tasks are required for create action"}],
}

if not workflow_id:
workflow_id = str(uuid.uuid4())

return _manager.create_workflow(workflow_id, tasks)

elif action == "start":
Expand Down
135 changes: 135 additions & 0 deletions tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,141 @@ def test_delete_workflow_missing_id(self, mock_parent_agent):
assert "workflow_id is required for delete action" in result["content"][0]["text"]


class TestWorkflowIdValidation:
"""Test that workflow_id stays confined to the workflow directory."""

BAD_IDS = [
"../../etc/passwd",
"../../../../tmp/pwned",
"..",
".",
"foo/bar",
"/etc/passwd",
"with space",
"tab\tname",
"a" * 129,
"",
]

@pytest.mark.parametrize("bad_id", BAD_IDS)
def test_is_valid_workflow_id_rejects(self, bad_id):
"""The allowlist rejects traversal sequences, separators, and overlong ids."""
assert workflow_module.is_valid_workflow_id(bad_id) is False

@pytest.mark.parametrize("good_id", ["abc", "my_workflow", "data-pipeline", "v1.2.3", "A" * 128])
def test_is_valid_workflow_id_accepts(self, good_id):
"""The allowlist accepts ordinary ids."""
assert workflow_module.is_valid_workflow_id(good_id) is True

def test_create_boundary_rejects_traversal(self, mock_parent_agent):
"""create applies the strict allowlist and rejects a traversal id before any sink."""
with patch("strands_tools.workflow.WorkflowManager") as mock_manager_class:
mock_manager = MagicMock()
mock_manager_class.return_value = mock_manager

result = workflow_module.workflow(
action="create",
workflow_id="../../etc/passwd",
tasks=[{"task_id": "t1", "description": "d"}],
agent=mock_parent_agent,
)

assert result["status"] == "error"
assert "Invalid workflow_id" in result["content"][0]["text"]
mock_manager.create_workflow.assert_not_called()

@pytest.mark.parametrize("legacy_id", ["my workflow", "café", "a" * 200])
def test_non_create_actions_accept_legacy_ids(self, mock_parent_agent, legacy_id):
"""read/start/status/delete do not apply the create-time charset allowlist.

Pre-existing ids with spaces, unicode, or over 128 characters are no longer
rejected at the boundary; they reach the sink, which confines them to
WORKFLOW_DIR rather than failing the invalid-id check.
"""
with patch("strands_tools.workflow.WorkflowManager") as mock_manager_class:
mock_manager = MagicMock()
mock_manager_class.return_value = mock_manager

for action, sink in (
("start", mock_manager.start_workflow),
("status", mock_manager.get_workflow_status),
("delete", mock_manager.delete_workflow),
):
mock_manager.reset_mock()
result = workflow_module.workflow(action=action, workflow_id=legacy_id, agent=mock_parent_agent)
# The boundary did not reject with the invalid-id allowlist message.
assert not (
isinstance(result, dict)
and result.get("status") == "error"
and "Invalid workflow_id" in result["content"][0]["text"]
)
sink.assert_called_once_with(legacy_id)

def test_resolve_workflow_path_rejects_traversal(self):
"""The defense-in-depth resolver refuses paths that escape WORKFLOW_DIR."""
with tempfile.TemporaryDirectory() as temp_dir:
with patch.object(workflow_module, "WORKFLOW_DIR", Path(temp_dir)):
with pytest.raises(ValueError):
workflow_module.resolve_workflow_path("../../etc/passwd")

def test_resolve_workflow_path_allows_valid_id(self):
"""A valid id resolves to a file directly inside WORKFLOW_DIR."""
with tempfile.TemporaryDirectory() as temp_dir:
with patch.object(workflow_module, "WORKFLOW_DIR", Path(temp_dir)):
resolved = workflow_module.resolve_workflow_path("safe_id")
assert resolved == (Path(temp_dir).resolve() / "safe_id.json")

def test_sinks_confine_traversal_id(self):
"""store/load/delete sinks all refuse a traversal id at runtime."""
with tempfile.TemporaryDirectory() as temp_dir:
with patch.object(workflow_module, "WORKFLOW_DIR", Path(temp_dir)):
outside = Path(temp_dir).parent / "pwned.json"
manager = object.__new__(workflow_module.WorkflowManager)
manager._workflows = {}

# store_workflow returns an error status and writes nothing outside the dir.
store_result = workflow_module.WorkflowManager.store_workflow(manager, "../pwned", {"data": "x"})
assert store_result["status"] == "error"
assert not outside.exists()

# load_workflow returns None rather than reading outside the dir.
assert workflow_module.WorkflowManager.load_workflow(manager, "../pwned") is None

# delete_workflow surfaces an error rather than unlinking outside the dir.
delete_result = workflow_module.WorkflowManager.delete_workflow(manager, "../pwned")
assert delete_result["status"] == "error"

@pytest.mark.parametrize("legacy_id", ["my workflow", "café"])
def test_sinks_accept_benign_legacy_id(self, legacy_id):
"""A spaced/unicode id loads and deletes again and resolves inside WORKFLOW_DIR.

The create-time allowlist would reject these ids, but read/delete sinks only
require containment, so pre-existing workflows remain reachable.
"""
with tempfile.TemporaryDirectory() as temp_dir:
with patch.object(workflow_module, "WORKFLOW_DIR", Path(temp_dir)):
# The id resolves to a file directly inside WORKFLOW_DIR.
resolved = workflow_module.resolve_workflow_path(legacy_id)
assert resolved == (Path(temp_dir).resolve() / f"{legacy_id}.json")
assert resolved.is_relative_to(Path(temp_dir).resolve())

manager = object.__new__(workflow_module.WorkflowManager)
manager._workflows = {}

# Seed a pre-existing workflow file with this legacy id.
resolved.write_text(json.dumps({"workflow_id": legacy_id, "data": "x"}))

# load_workflow reads it back without raising the invalid-id error.
loaded = workflow_module.WorkflowManager.load_workflow(manager, legacy_id)
assert loaded is not None
assert loaded["workflow_id"] == legacy_id

# delete_workflow removes it and reports success rather than an id error.
delete_result = workflow_module.WorkflowManager.delete_workflow(manager, legacy_id)
assert delete_result["status"] == "success"
assert not resolved.exists()


class TestWorkflowManager:
"""Test WorkflowManager class functionality."""

Expand Down
Loading