From 1c5dcc2e7bae1299d7207f29de5fced293e1f530 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Mon, 11 May 2026 09:33:17 -0400 Subject: [PATCH 1/9] Implement plugins for MultiAgent --- src/strands/multiagent/graph.py | 20 + src/strands/multiagent/swarm.py | 9 + src/strands/plugins/__init__.py | 7 +- src/strands/plugins/_discovery.py | 119 +++++ src/strands/plugins/multiagent_plugin.py | 119 +++++ src/strands/plugins/multiagent_registry.py | 97 ++++ src/strands/plugins/plugin.py | 40 +- src/strands/plugins/registry.py | 19 +- .../multiagent/test_multiagent_plugins.py | 230 ++++++++ .../strands/plugins/test_multiagent_plugin.py | 498 ++++++++++++++++++ 10 files changed, 1111 insertions(+), 47 deletions(-) create mode 100644 src/strands/plugins/_discovery.py create mode 100644 src/strands/plugins/multiagent_plugin.py create mode 100644 src/strands/plugins/multiagent_registry.py create mode 100644 tests/strands/multiagent/test_multiagent_plugins.py create mode 100644 tests/strands/plugins/test_multiagent_plugin.py diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 8da8314ea..c34554f78 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -37,6 +37,8 @@ ) from ..hooks.registry import HookProvider, HookRegistry from ..interrupt import Interrupt, _InterruptState +from ..plugins.multiagent_plugin import MultiAgentPlugin +from ..plugins.multiagent_registry import _MultiAgentPluginRegistry from ..session import SessionManager from ..telemetry import get_tracer from ..types._events import ( @@ -253,6 +255,7 @@ def __init__(self) -> None: self._id: str = _DEFAULT_GRAPH_ID self._session_manager: SessionManager | None = None self._hooks: list[HookProvider] | None = None + self._plugins: list[MultiAgentPlugin] | None = None def add_node(self, executor: AgentBase | MultiAgentBase, node_id: str | None = None) -> GraphNode: """Add an AgentBase or MultiAgentBase instance as a node to the graph.""" @@ -370,6 +373,15 @@ def set_hook_providers(self, hooks: list[HookProvider]) -> "GraphBuilder": self._hooks = hooks return self + def set_plugins(self, plugins: list[MultiAgentPlugin]) -> "GraphBuilder": + """Set plugins for the graph. + + Args: + plugins: List of multi-agent plugins for extending graph behavior + """ + self._plugins = plugins + return self + def build(self) -> "Graph": """Build and validate the graph with configured settings.""" if not self.nodes: @@ -397,6 +409,7 @@ def build(self) -> "Graph": reset_on_revisit=self._reset_on_revisit, session_manager=self._session_manager, hooks=self._hooks, + plugins=self._plugins, id=self._id, ) @@ -427,6 +440,7 @@ def __init__( reset_on_revisit: bool = False, session_manager: SessionManager | None = None, hooks: list[HookProvider] | None = None, + plugins: list[MultiAgentPlugin] | None = None, id: str = _DEFAULT_GRAPH_ID, trace_attributes: Mapping[str, AttributeValue] | None = None, ) -> None: @@ -442,6 +456,7 @@ def __init__( reset_on_revisit: Whether to reset node state when revisited (default: False) session_manager: Session manager for persisting graph state and execution history (default: None) hooks: List of hook providers for monitoring and extending graph execution behavior (default: None) + plugins: List of multi-agent plugins for extending graph behavior (default: None) id: Unique graph id (default: None) trace_attributes: Custom trace attributes to apply to the agent's trace span (default: None) """ @@ -469,6 +484,11 @@ def __init__( for hook in hooks: self.hooks.add_hook(hook) + self._plugin_registry = _MultiAgentPluginRegistry(self) + if plugins: + for plugin in plugins: + self._plugin_registry.add_and_init(plugin) + self._resume_next_nodes: list[GraphNode] = [] self._resume_from_session = False self.id = id diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index f5731a371..67fdefb70 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -36,6 +36,8 @@ ) from ..hooks.registry import HookProvider, HookRegistry from ..interrupt import Interrupt, _InterruptState +from ..plugins.multiagent_plugin import MultiAgentPlugin +from ..plugins.multiagent_registry import _MultiAgentPluginRegistry from ..session import SessionManager from ..telemetry import get_tracer from ..tools.decorator import tool @@ -247,6 +249,7 @@ def __init__( repetitive_handoff_min_unique_agents: int = 0, session_manager: SessionManager | None = None, hooks: list[HookProvider] | None = None, + plugins: list[MultiAgentPlugin] | None = None, id: str = _DEFAULT_SWARM_ID, trace_attributes: Mapping[str, AttributeValue] | None = None, ) -> None: @@ -266,6 +269,7 @@ def __init__( Disabled by default (default: 0) session_manager: Session manager for persisting graph state and execution history (default: None) hooks: List of hook providers for monitoring and extending graph execution behavior (default: None) + plugins: List of multi-agent plugins for extending swarm behavior (default: None) trace_attributes: Custom trace attributes to apply to the agent's trace span (default: None) """ super().__init__() @@ -299,6 +303,11 @@ def __init__( if self.session_manager: self.hooks.add_hook(self.session_manager) + self._plugin_registry = _MultiAgentPluginRegistry(self) + if plugins: + for plugin in plugins: + self._plugin_registry.add_and_init(plugin) + self._resume_from_session = False self._setup_swarm(nodes) diff --git a/src/strands/plugins/__init__.py b/src/strands/plugins/__init__.py index c4b7c72c7..7a3d5fa17 100644 --- a/src/strands/plugins/__init__.py +++ b/src/strands/plugins/__init__.py @@ -1,13 +1,16 @@ -"""Plugin system for extending agent functionality. +"""Plugin system for extending agent and orchestrator functionality. This module provides a composable mechanism for building objects that can -extend agent behavior through automatic hook and tool registration. +extend agent and multi-agent orchestrator behavior through automatic hook +and tool registration. """ from .decorator import hook +from .multiagent_plugin import MultiAgentPlugin from .plugin import Plugin __all__ = [ + "MultiAgentPlugin", "Plugin", "hook", ] diff --git a/src/strands/plugins/_discovery.py b/src/strands/plugins/_discovery.py new file mode 100644 index 000000000..021176fe2 --- /dev/null +++ b/src/strands/plugins/_discovery.py @@ -0,0 +1,119 @@ +"""Shared utility for discovering decorated methods on plugin instances. + +This module provides helper functions used by both Plugin and MultiAgentPlugin +to scan for @hook (and optionally @tool) decorated methods, and shared registry +utilities for plugin initialization and hook registration. +""" + +import inspect +import logging +from collections.abc import Awaitable, Callable +from typing import Any, cast + +from .._async import run_async +from ..hooks.registry import HookCallback, HookRegistry +from ..tools.decorator import DecoratedFunctionTool + +logger = logging.getLogger(__name__) + + +def discover_hooks(instance: object, plugin_name: str) -> list[HookCallback]: + """Scan an instance's class hierarchy for @hook decorated methods. + + Walks the MRO in reverse so parent class hooks come first, but child + overrides win (only the child's version is included). + + Args: + instance: The plugin instance to scan. + plugin_name: The plugin name (used for debug logging). + + Returns: + List of bound hook callback methods in declaration order. + """ + hooks: list[HookCallback] = [] + seen: set[str] = set() + + for cls in reversed(type(instance).__mro__): + for attr_name in cls.__dict__: + if attr_name in seen: + continue + seen.add(attr_name) + + try: + bound = getattr(instance, attr_name) + except Exception: + continue + + if hasattr(bound, "_hook_event_types") and callable(bound): + hooks.append(bound) + logger.debug("plugin=<%s>, hook=<%s> | discovered hook method", plugin_name, attr_name) + + return hooks + + +def discover_tools(instance: object, plugin_name: str) -> list[DecoratedFunctionTool]: + """Scan an instance's class hierarchy for @tool decorated methods. + + Walks the MRO in reverse so parent class tools come first, but child + overrides win (only the child's version is included). + + Args: + instance: The plugin instance to scan. + plugin_name: The plugin name (used for debug logging). + + Returns: + List of DecoratedFunctionTool instances in declaration order. + """ + tools: list[DecoratedFunctionTool] = [] + seen: set[str] = set() + + for cls in reversed(type(instance).__mro__): + for attr_name in cls.__dict__: + if attr_name in seen: + continue + seen.add(attr_name) + + try: + bound = getattr(instance, attr_name) + except Exception: + continue + + if isinstance(bound, DecoratedFunctionTool): + tools.append(bound) + logger.debug("plugin=<%s>, tool=<%s> | discovered tool method", plugin_name, attr_name) + + return tools + + +def call_init_method(init_method: Callable[..., Any], target: Any) -> None: + """Call a plugin's init method, handling both sync and async implementations. + + Args: + init_method: The init_agent or init_multi_agent method to call. + target: The agent or orchestrator instance to pass to the init method. + """ + if inspect.iscoroutinefunction(init_method): + async_init = cast(Callable[..., Awaitable[None]], init_method) + run_async(lambda: async_init(target)) + else: + init_method(target) + + +def register_hooks(plugin_name: str, hooks: list[HookCallback], registry: HookRegistry) -> None: + """Register discovered hook callbacks with a hook registry. + + Args: + plugin_name: The plugin name (used for debug logging). + hooks: List of hook callbacks to register. + registry: The HookRegistry to register callbacks with. + """ + for hook_callback in hooks: + event_types = getattr(hook_callback, "_hook_event_types", []) + for event_type in event_types: + registry.add_callback(event_type, hook_callback) + logger.debug( + "plugin=<%s>, hook=<%s>, event_type=<%s> | registered hook", + plugin_name, + getattr(hook_callback, "__name__", repr(hook_callback)), + event_type.__name__, + ) diff --git a/src/strands/plugins/multiagent_plugin.py b/src/strands/plugins/multiagent_plugin.py new file mode 100644 index 000000000..89bd9e0e5 --- /dev/null +++ b/src/strands/plugins/multiagent_plugin.py @@ -0,0 +1,119 @@ +"""MultiAgentPlugin base class for extending multi-agent orchestrator functionality. + +This module defines the MultiAgentPlugin base class, which provides a composable way to +add behavior changes to multi-agent orchestrators (Swarm, Graph) through automatic hook +registration and custom initialization. + +MultiAgentPlugin is the orchestrator-level counterpart to Plugin (which targets individual agents). +A class can implement both Plugin and MultiAgentPlugin to provide functionality at both levels. +""" + +from abc import ABC, abstractmethod +from collections.abc import Awaitable +from typing import TYPE_CHECKING + +from ..hooks.registry import HookCallback +from ._discovery import discover_hooks + +if TYPE_CHECKING: + from ..multiagent.base import MultiAgentBase + + +class MultiAgentPlugin(ABC): + """Base class for objects that extend multi-agent orchestrator functionality. + + MultiAgentPlugins provide a composable way to add behavior changes to orchestrators + (Swarm, Graph). They support automatic discovery and registration of methods decorated + with @hook. + + Unlike agent-level Plugin, MultiAgentPlugin does not support @tool decorated methods + since orchestrators do not have tool registries. + + Attributes: + name: A stable string identifier for the plugin (must be provided by subclass) + hooks: Hooks attached to the orchestrator, auto-discovered from @hook decorated methods + + Example using decorators (recommended): + ```python + from strands.plugins import MultiAgentPlugin, hook + from strands.hooks import BeforeNodeCallEvent, AfterNodeCallEvent + + class MonitoringPlugin(MultiAgentPlugin): + name = "monitoring" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + print(f"Node {event.node_id} starting") + + @hook + def on_after_node(self, event: AfterNodeCallEvent): + print(f"Node {event.node_id} completed") + ``` + + Example with custom initialization: + ```python + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + def init_multi_agent(self, orchestrator: MultiAgentBase) -> None: + # Custom initialization logic + pass + ``` + + Dual-use example (both agent and orchestrator): + ```python + from strands.plugins import Plugin, MultiAgentPlugin, hook + from strands.hooks import BeforeInvocationEvent, BeforeNodeCallEvent + + class ObservabilityPlugin(Plugin, MultiAgentPlugin): + name = "observability" + + @hook + def on_agent_invocation(self, event: BeforeInvocationEvent): + print("Agent invocation started") + + @hook + def on_node_call(self, event: BeforeNodeCallEvent): + print(f"Node {event.node_id} starting") + + def init_agent(self, agent): + pass # Agent-level setup + + def init_multi_agent(self, orchestrator): + pass # Orchestrator-level setup + ``` + """ + + @property + @abstractmethod + def name(self) -> str: + """A stable string identifier for the plugin.""" + ... + + def __init__(self) -> None: + """Initialize the plugin and discover decorated hook methods. + + Scans the class for methods decorated with @hook and stores references + for later registration when the plugin is attached to an orchestrator. + + Uses a guard to prevent double-discovery when used with multiple inheritance + (e.g., a class that inherits from both Plugin and MultiAgentPlugin). + """ + if not hasattr(self, "_hooks"): + self._hooks: list[HookCallback] = discover_hooks(self, self.name) + + @property + def hooks(self) -> list[HookCallback]: + """List of hooks the plugin provides, auto-discovered from @hook decorated methods.""" + return self._hooks + + def init_multi_agent(self, orchestrator: "MultiAgentBase") -> None | Awaitable[None]: + """Initialize the plugin with the orchestrator instance. + + Override this method to add custom initialization logic. Decorated + hooks are automatically registered by the plugin registry. + + Args: + orchestrator: The multi-agent orchestrator instance to initialize with. + """ + return None diff --git a/src/strands/plugins/multiagent_registry.py b/src/strands/plugins/multiagent_registry.py new file mode 100644 index 000000000..d2b00ceef --- /dev/null +++ b/src/strands/plugins/multiagent_registry.py @@ -0,0 +1,97 @@ +"""MultiAgentPlugin registry for managing plugins attached to a multi-agent orchestrator. + +This module provides the _MultiAgentPluginRegistry class for tracking and managing +plugins that have been initialized with an orchestrator instance. +""" + +import logging +import weakref +from typing import Any + +from ..hooks.registry import HookRegistry +from ._discovery import call_init_method, register_hooks +from .multiagent_plugin import MultiAgentPlugin + +logger = logging.getLogger(__name__) + + +class _MultiAgentPluginRegistry: + """Registry for managing plugins attached to a multi-agent orchestrator. + + The _MultiAgentPluginRegistry tracks plugins that have been initialized with an + orchestrator, providing methods to add plugins and invoke their initialization. + + The registry handles: + 1. Calling the plugin's init_multi_agent() method for custom initialization + 2. Auto-registering discovered @hook decorated methods with the orchestrator + + Example: + ```python + registry = _MultiAgentPluginRegistry(orchestrator) + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def on_event(self, event: BeforeNodeCallEvent): + pass # Auto-registered by registry + + def init_multi_agent(self, orchestrator: MultiAgentBase) -> None: + # Custom logic + pass + + plugin = MyPlugin() + registry.add_and_init(plugin) + ``` + """ + + def __init__(self, orchestrator: Any) -> None: + """Initialize a plugin registry with an orchestrator reference. + + Args: + orchestrator: The orchestrator instance that plugins will be initialized with. + Must have a ``hooks`` attribute of type ``HookRegistry``. + """ + self._orchestrator_ref = weakref.ref(orchestrator) + self._plugins: dict[str, MultiAgentPlugin] = {} + + @property + def _orchestrator(self) -> Any: + """Return the orchestrator, raising ReferenceError if it has been garbage collected.""" + orchestrator = self._orchestrator_ref() + if orchestrator is None: + raise ReferenceError("Orchestrator has been garbage collected") + return orchestrator + + @property + def _hook_registry(self) -> HookRegistry: + """Return the orchestrator's hook registry.""" + return self._orchestrator.hooks # type: ignore[no-any-return] + + def add_and_init(self, plugin: MultiAgentPlugin) -> None: + """Add and initialize a plugin with the orchestrator. + + This method: + 1. Registers the plugin in the registry + 2. Calls the plugin's init_multi_agent method for custom initialization + 3. Auto-registers all discovered @hook methods with the orchestrator's hook registry + + Handles both sync and async init_multi_agent implementations automatically. + + Args: + plugin: The plugin to add and initialize. + + Raises: + ValueError: If a plugin with the same name is already registered. + """ + if plugin.name in self._plugins: + raise ValueError(f"plugin_name=<{plugin.name}> | plugin already registered") + + logger.debug("plugin_name=<%s> | registering and initializing multi-agent plugin", plugin.name) + self._plugins[plugin.name] = plugin + + # Call user's init_multi_agent for custom initialization + call_init_method(plugin.init_multi_agent, self._orchestrator) + + # Auto-register discovered hooks with the orchestrator's hook registry + register_hooks(plugin.name, plugin.hooks, self._hook_registry) diff --git a/src/strands/plugins/plugin.py b/src/strands/plugins/plugin.py index b670de297..35633a30e 100644 --- a/src/strands/plugins/plugin.py +++ b/src/strands/plugins/plugin.py @@ -4,19 +4,17 @@ add behavior changes to agents through automatic hook and tool registration. """ -import logging from abc import ABC, abstractmethod from collections.abc import Awaitable from typing import TYPE_CHECKING from ..hooks.registry import HookCallback from ..tools.decorator import DecoratedFunctionTool +from ._discovery import discover_hooks, discover_tools if TYPE_CHECKING: from ..agent import Agent -logger = logging.getLogger(__name__) - class Plugin(ABC): """Base class for objects that extend agent functionality. @@ -79,10 +77,14 @@ def __init__(self) -> None: Scans the class for methods decorated with @hook and @tool and stores references for later registration when the plugin is attached to an agent. + + Uses a guard to prevent double-discovery when used with multiple inheritance + (e.g., a class that inherits from both Plugin and MultiAgentPlugin). """ - self._hooks: list[HookCallback] = [] - self._tools: list[DecoratedFunctionTool] = [] - self._discover_decorated_methods() + if not hasattr(self, "_hooks"): + self._hooks: list[HookCallback] = discover_hooks(self, self.name) + if not hasattr(self, "_tools"): + self._tools: list[DecoratedFunctionTool] = discover_tools(self, self.name) @property def hooks(self) -> list[HookCallback]: @@ -94,32 +96,6 @@ def tools(self) -> list[DecoratedFunctionTool]: """List of tools the plugin provides, auto-discovered from @tool decorated methods.""" return self._tools - def _discover_decorated_methods(self) -> None: - """Scan class for @hook and @tool decorated methods in declaration order.""" - seen: set[str] = set() - # Walk MRO so parent class hooks come first, child overrides win - for cls in reversed(type(self).__mro__): - for name in cls.__dict__: - if name in seen: - continue - seen.add(name) - - # Get the bound method from self - try: - bound = getattr(self, name) - except Exception: - continue - - # Check for @hook decorated methods - if hasattr(bound, "_hook_event_types") and callable(bound): - self._hooks.append(bound) - logger.debug("plugin=<%s>, hook=<%s> | discovered hook method", self.name, name) - - # Check for @tool decorated methods (DecoratedFunctionTool instances) - if isinstance(bound, DecoratedFunctionTool): - self._tools.append(bound) - logger.debug("plugin=<%s>, tool=<%s> | discovered tool method", self.name, name) - def init_agent(self, agent: "Agent") -> None | Awaitable[None]: """Initialize the agent instance. diff --git a/src/strands/plugins/registry.py b/src/strands/plugins/registry.py index e994b5591..ca5d654c9 100644 --- a/src/strands/plugins/registry.py +++ b/src/strands/plugins/registry.py @@ -4,13 +4,11 @@ plugins that have been initialized with an agent instance. """ -import inspect import logging import weakref -from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING -from .._async import run_async +from ._discovery import call_init_method from .plugin import Plugin if TYPE_CHECKING: @@ -91,13 +89,9 @@ def add_and_init(self, plugin: Plugin) -> None: self._plugins[plugin.name] = plugin # Call user's init_agent for custom initialization - if inspect.iscoroutinefunction(plugin.init_agent): - async_plugin_init = cast(Callable[..., Awaitable[None]], plugin.init_agent) - run_async(lambda: async_plugin_init(self._agent)) - else: - plugin.init_agent(self._agent) + call_init_method(plugin.init_agent, self._agent) - # Auto-register discovered hooks with the agent's hook registry + # Auto-register discovered hooks with the agent self._register_hooks(plugin) # Auto-register discovered tools with the agent's tool registry @@ -106,9 +100,8 @@ def add_and_init(self, plugin: Plugin) -> None: def _register_hooks(self, plugin: Plugin) -> None: """Register all discovered hooks from the plugin with the agent. - Warns if a hook callback is already registered for an event type, - which can happen when init_agent() manually registers a hook that - is also decorated with @hook. + Uses agent.add_hook() rather than the hook registry directly, so that + the agent can track registrations through its public API. Args: plugin: The plugin whose hooks should be registered. diff --git a/tests/strands/multiagent/test_multiagent_plugins.py b/tests/strands/multiagent/test_multiagent_plugins.py new file mode 100644 index 000000000..2052c471e --- /dev/null +++ b/tests/strands/multiagent/test_multiagent_plugins.py @@ -0,0 +1,230 @@ +"""Tests for MultiAgentPlugin integration with Swarm and Graph.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from strands.hooks import BeforeNodeCallEvent +from strands.hooks.registry import HookProvider +from strands.multiagent import GraphBuilder, Swarm +from strands.multiagent.graph import Graph, GraphNode +from strands.plugins import MultiAgentPlugin, hook + +# --- Fixtures --- + + +@pytest.fixture +def mock_swarm_agent(): + """Create a mock agent suitable for Swarm construction.""" + agent = MagicMock() + agent.name = "agent1" + agent.description = "Test agent" + agent.messages = [] + agent.state = MagicMock() + agent.state.get.return_value = {} + agent._model_state = {} + agent._session_manager = None + agent.tool_registry = MagicMock() + agent.tool_registry.get_all_tools_config.return_value = {} + return agent + + +@pytest.fixture +def mock_graph_agent(): + """Create a mock agent suitable for Graph construction.""" + agent = MagicMock() + agent.name = "agent1" + agent.messages = [] + agent.state = MagicMock() + agent.state.get.return_value = {} + agent._model_state = {} + agent._session_manager = None + return agent + + +def _make_swarm(agent, **kwargs): + """Helper to construct a Swarm with tracer patched out.""" + with patch("strands.multiagent.swarm.get_tracer"): + return Swarm(nodes=[agent], **kwargs) + + +def _make_graph(agent, **kwargs): + """Helper to construct a Graph with tracer patched out.""" + with patch("strands.multiagent.graph.get_tracer"): + node = GraphNode(node_id="agent1", executor=agent) + return Graph(nodes={"agent1": node}, edges=set(), entry_points={node}, **kwargs) + + +# --- Swarm plugin integration tests --- + + +def test_swarm_accepts_plugins_parameter(mock_swarm_agent): + """Test that Swarm constructor accepts a plugins parameter.""" + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + swarm = _make_swarm(mock_swarm_agent, plugins=[MyPlugin()]) + assert swarm._plugin_registry is not None + + +def test_swarm_initializes_plugins(mock_swarm_agent): + """Test that Swarm calls init_multi_agent on plugins during construction.""" + init_called = False + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + def init_multi_agent(self, orchestrator): + nonlocal init_called + init_called = True + + _make_swarm(mock_swarm_agent, plugins=[MyPlugin()]) + assert init_called + + +def test_swarm_registers_plugin_hooks(mock_swarm_agent): + """Test that Swarm registers plugin hooks with its hook registry.""" + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + swarm = _make_swarm(mock_swarm_agent, plugins=[MyPlugin()]) + assert len(swarm.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + + +def test_swarm_plugins_coexist_with_hooks(mock_swarm_agent): + """Test that plugins and legacy hooks parameter work together.""" + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + class MyHookProvider(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BeforeNodeCallEvent, self.on_before_node) + + def on_before_node(self, event): + pass + + swarm = _make_swarm(mock_swarm_agent, plugins=[MyPlugin()], hooks=[MyHookProvider()]) + assert len(swarm.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 2 + + +def test_swarm_duplicate_plugin_raises_error(mock_swarm_agent): + """Test that duplicate plugin names raise an error in Swarm.""" + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + with pytest.raises(ValueError, match="plugin already registered"): + _make_swarm(mock_swarm_agent, plugins=[MyPlugin(), MyPlugin()]) + + +def test_swarm_no_plugins_parameter(mock_swarm_agent): + """Test that Swarm works without plugins parameter (backward compat).""" + swarm = _make_swarm(mock_swarm_agent) + assert swarm._plugin_registry is not None + + +# --- Graph plugin integration tests --- + + +def test_graph_builder_accepts_plugins(): + """Test that GraphBuilder has a set_plugins method.""" + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + builder = GraphBuilder() + result = builder.set_plugins([MyPlugin()]) + assert result is builder + + +def test_graph_accepts_plugins_parameter(mock_graph_agent): + """Test that Graph constructor accepts a plugins parameter.""" + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + graph = _make_graph(mock_graph_agent, plugins=[MyPlugin()]) + assert graph._plugin_registry is not None + + +def test_graph_initializes_plugins(mock_graph_agent): + """Test that Graph calls init_multi_agent on plugins during construction.""" + init_called = False + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + def init_multi_agent(self, orchestrator): + nonlocal init_called + init_called = True + + _make_graph(mock_graph_agent, plugins=[MyPlugin()]) + assert init_called + + +def test_graph_registers_plugin_hooks(mock_graph_agent): + """Test that Graph registers plugin hooks with its hook registry.""" + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + graph = _make_graph(mock_graph_agent, plugins=[MyPlugin()]) + assert len(graph.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + + +def test_graph_plugins_coexist_with_hooks(mock_graph_agent): + """Test that plugins and legacy hooks parameter work together in Graph.""" + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + class MyHookProvider(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BeforeNodeCallEvent, self.on_before_node) + + def on_before_node(self, event): + pass + + graph = _make_graph(mock_graph_agent, plugins=[MyPlugin()], hooks=[MyHookProvider()]) + assert len(graph.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 2 + + +def test_graph_builder_passes_plugins_to_graph(mock_graph_agent): + """Test that GraphBuilder.build() passes plugins to the Graph constructor.""" + init_called = False + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + def init_multi_agent(self, orchestrator): + nonlocal init_called + init_called = True + + with patch("strands.multiagent.graph.get_tracer"): + builder = GraphBuilder() + builder.add_node(mock_graph_agent, node_id="agent1") + builder.set_entry_point("agent1") + builder.set_plugins([MyPlugin()]) + graph = builder.build() + + assert init_called + assert graph._plugin_registry is not None diff --git a/tests/strands/plugins/test_multiagent_plugin.py b/tests/strands/plugins/test_multiagent_plugin.py new file mode 100644 index 000000000..d8bf06005 --- /dev/null +++ b/tests/strands/plugins/test_multiagent_plugin.py @@ -0,0 +1,498 @@ +"""Tests for the MultiAgentPlugin base class and registry.""" + +import gc +import unittest.mock + +import pytest + +from strands.hooks import AfterNodeCallEvent, BeforeNodeCallEvent, HookRegistry +from strands.plugins import Plugin, hook +from strands.plugins.multiagent_plugin import MultiAgentPlugin +from strands.plugins.multiagent_registry import _MultiAgentPluginRegistry +from strands.plugins.registry import _PluginRegistry + +# --- Fixtures --- + + +@pytest.fixture +def mock_orchestrator(): + """Create a mock orchestrator with a working hook registry.""" + orch = unittest.mock.MagicMock() + orch.hooks = HookRegistry() + orch.add_hook = unittest.mock.Mock( + side_effect=lambda callback, event_type=None: orch.hooks.add_callback(event_type, callback) + ) + return orch + + +@pytest.fixture +def registry(mock_orchestrator): + """Create a _MultiAgentPluginRegistry backed by the mock orchestrator.""" + return _MultiAgentPluginRegistry(mock_orchestrator) + + +@pytest.fixture +def mock_agent(): + """Create a mock agent with a working hook registry for dual-plugin tests.""" + agent = unittest.mock.MagicMock() + agent.hooks = HookRegistry() + agent.add_hook = unittest.mock.Mock( + side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback) + ) + agent.tool_registry = unittest.mock.MagicMock() + return agent + + +# --- MultiAgentPlugin base class tests --- + + +def test_multiagent_plugin_is_class_not_protocol(): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + assert isinstance(MyPlugin(), MultiAgentPlugin) + + +def test_multiagent_plugin_requires_name_attribute(): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + assert MyPlugin().name == "my-plugin" + + +def test_multiagent_plugin_name_as_property(): + class MyPlugin(MultiAgentPlugin): + @property + def name(self) -> str: + return "property-plugin" + + assert MyPlugin().name == "property-plugin" + + +def test_multiagent_plugin_requires_name(): + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + + class PluginWithoutName(MultiAgentPlugin): + def init_multi_agent(self, orchestrator): + pass + + PluginWithoutName() + + +def test_multiagent_plugin_provides_default_init_multi_agent(): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + assert MyPlugin().init_multi_agent(unittest.mock.MagicMock()) is None + + +# --- Auto-discovery tests --- + + +def test_discovers_hook_decorated_methods(): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + plugin = MyPlugin() + assert len(plugin.hooks) == 1 + assert plugin.hooks[0].__name__ == "on_before_node" + + +def test_discovers_multiple_hooks(): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def hook1(self, event: BeforeNodeCallEvent): + pass + + @hook + def hook2(self, event: AfterNodeCallEvent): + pass + + plugin = MyPlugin() + assert len(plugin.hooks) == 2 + assert {h.__name__ for h in plugin.hooks} == {"hook1", "hook2"} + + +def test_hooks_preserve_definition_order(): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def z_last(self, event: BeforeNodeCallEvent): + pass + + @hook + def a_first(self, event: BeforeNodeCallEvent): + pass + + plugin = MyPlugin() + assert [h.__name__ for h in plugin.hooks] == ["z_last", "a_first"] + + +def test_ignores_non_decorated_methods(): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + def regular_method(self): + pass + + @hook + def decorated_hook(self, event: BeforeNodeCallEvent): + pass + + plugin = MyPlugin() + assert len(plugin.hooks) == 1 + assert plugin.hooks[0].__name__ == "decorated_hook" + + +def test_no_tool_support(): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + assert not hasattr(MyPlugin(), "tools") + + +# --- Registry tests --- + + +def test_registry_add_and_init_calls_init_multi_agent(registry): + class TestPlugin(MultiAgentPlugin): + name = "test-plugin" + + def __init__(self): + super().__init__() + self.initialized = False + + def init_multi_agent(self, orchestrator): + self.initialized = True + + plugin = TestPlugin() + registry.add_and_init(plugin) + assert plugin.initialized + + +def test_registry_add_duplicate_raises_error(registry): + class TestPlugin(MultiAgentPlugin): + name = "test-plugin" + + registry.add_and_init(TestPlugin()) + with pytest.raises(ValueError, match="plugin_name= | plugin already registered"): + registry.add_and_init(TestPlugin()) + + +def test_registry_registers_discovered_hooks(mock_orchestrator, registry): + class TestPlugin(MultiAgentPlugin): + name = "test-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + registry.add_and_init(TestPlugin()) + assert len(mock_orchestrator.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + + +def test_registry_registers_multiple_hooks(mock_orchestrator, registry): + class TestPlugin(MultiAgentPlugin): + name = "test-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + @hook + def on_after_node(self, event: AfterNodeCallEvent): + pass + + registry.add_and_init(TestPlugin()) + assert len(mock_orchestrator.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + assert len(mock_orchestrator.hooks._registered_callbacks.get(AfterNodeCallEvent, [])) == 1 + + +def test_registry_async_init_multi_agent_supported(registry): + async_init_called = False + + class AsyncPlugin(MultiAgentPlugin): + name = "async-plugin" + + async def init_multi_agent(self, orchestrator): + nonlocal async_init_called + async_init_called = True + + registry.add_and_init(AsyncPlugin()) + assert async_init_called + + +def test_registry_hooks_are_bound_to_instance(mock_orchestrator, registry): + class TestPlugin(MultiAgentPlugin): + name = "test-plugin" + + def __init__(self): + super().__init__() + self.events_received = [] + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + self.events_received.append(event) + + plugin = TestPlugin() + registry.add_and_init(plugin) + + mock_event = unittest.mock.MagicMock(spec=BeforeNodeCallEvent) + mock_orchestrator.hooks._registered_callbacks[BeforeNodeCallEvent][0](mock_event) + + assert plugin.events_received == [mock_event] + + +def test_registry_raises_reference_error_after_orchestrator_collected(): + orch = unittest.mock.MagicMock() + orch.hooks = HookRegistry() + reg = _MultiAgentPluginRegistry(orch) + del orch + gc.collect() + + with pytest.raises(ReferenceError, match="Orchestrator has been garbage collected"): + _ = reg._orchestrator + + +def test_registry_init_multi_agent_called_before_hook_registration(mock_orchestrator): + call_order = [] + + class TestPlugin(MultiAgentPlugin): + name = "test-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + def init_multi_agent(self, orchestrator): + call_order.append("init") + + original = mock_orchestrator.hooks.add_callback + + def tracking(event_type, callback): + call_order.append("hook") + return original(event_type, callback) + + mock_orchestrator.hooks.add_callback = tracking + + registry = _MultiAgentPluginRegistry(mock_orchestrator) + registry.add_and_init(TestPlugin()) + + assert call_order == ["init", "hook"] + + +# --- Union type tests --- + + +def test_registers_hook_for_union_types(mock_orchestrator, registry): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def on_node_events(self, event: BeforeNodeCallEvent | AfterNodeCallEvent): + pass + + registry.add_and_init(MyPlugin()) + assert len(mock_orchestrator.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + assert len(mock_orchestrator.hooks._registered_callbacks.get(AfterNodeCallEvent, [])) == 1 + + +# --- Subclass override tests --- + + +def test_subclass_can_override_init_multi_agent(mock_orchestrator, registry): + custom_init_called = False + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + def init_multi_agent(self, orchestrator): + nonlocal custom_init_called + custom_init_called = True + + registry.add_and_init(MyPlugin()) + assert custom_init_called + assert len(mock_orchestrator.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + + +def test_subclass_can_add_manual_hooks_in_init(mock_orchestrator, registry): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def auto_hook(self, event: BeforeNodeCallEvent): + pass + + def manual_hook(self, event: AfterNodeCallEvent): + pass + + def init_multi_agent(self, orchestrator): + orchestrator.hooks.add_callback(AfterNodeCallEvent, self.manual_hook) + + registry.add_and_init(MyPlugin()) + assert len(mock_orchestrator.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + assert len(mock_orchestrator.hooks._registered_callbacks.get(AfterNodeCallEvent, [])) == 1 + + +# --- Inheritance tests --- + + +def test_child_inherits_parent_hooks(): + class ParentPlugin(MultiAgentPlugin): + name = "parent-plugin" + + @hook + def parent_hook(self, event: BeforeNodeCallEvent): + pass + + class ChildPlugin(ParentPlugin): + name = "child-plugin" + + @hook + def child_hook(self, event: AfterNodeCallEvent): + pass + + plugin = ChildPlugin() + assert len(plugin.hooks) == 2 + assert {h.__name__ for h in plugin.hooks} == {"parent_hook", "child_hook"} + + +def test_child_can_override_parent_hook(): + class ParentPlugin(MultiAgentPlugin): + name = "parent-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + class ChildPlugin(ParentPlugin): + name = "child-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + assert len(ChildPlugin().hooks) == 1 + + +# --- Dual plugin tests --- + + +def test_dual_plugin_isinstance_checks(): + class DualPlugin(Plugin, MultiAgentPlugin): + name = "dual-plugin" + + plugin = DualPlugin() + assert isinstance(plugin, Plugin) + assert isinstance(plugin, MultiAgentPlugin) + + +def test_dual_plugin_discovers_hooks_once(): + class DualPlugin(Plugin, MultiAgentPlugin): + name = "dual-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + assert len(DualPlugin().hooks) == 1 + + +def test_dual_plugin_has_both_init_methods(mock_agent, mock_orchestrator): + agent_init_called = False + multi_agent_init_called = False + + class DualPlugin(Plugin, MultiAgentPlugin): + name = "dual-plugin" + + def init_agent(self, agent): + nonlocal agent_init_called + agent_init_called = True + + def init_multi_agent(self, orchestrator): + nonlocal multi_agent_init_called + multi_agent_init_called = True + + _PluginRegistry(mock_agent).add_and_init(DualPlugin()) + assert agent_init_called + + _MultiAgentPluginRegistry(mock_orchestrator).add_and_init(DualPlugin()) + assert multi_agent_init_called + + +def test_dual_plugin_registers_hooks_in_both_contexts(mock_agent, mock_orchestrator): + from strands.hooks import BeforeModelCallEvent + + class DualPlugin(Plugin, MultiAgentPlugin): + name = "dual-plugin" + + @hook + def on_model_call(self, event: BeforeModelCallEvent): + pass + + @hook + def on_node_call(self, event: BeforeNodeCallEvent): + pass + + _PluginRegistry(mock_agent).add_and_init(DualPlugin()) + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + assert len(mock_agent.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + + _MultiAgentPluginRegistry(mock_orchestrator).add_and_init(DualPlugin()) + assert len(mock_orchestrator.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + assert len(mock_orchestrator.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + + +def test_dual_plugin_shared_state(mock_agent, mock_orchestrator): + class DualPlugin(Plugin, MultiAgentPlugin): + name = "dual-plugin" + + def __init__(self): + super().__init__() + self.call_count = 0 + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + self.call_count += 1 + + def init_agent(self, agent): + self.call_count += 10 + + def init_multi_agent(self, orchestrator): + self.call_count += 100 + + plugin = DualPlugin() + _PluginRegistry(mock_agent).add_and_init(plugin) + assert plugin.call_count == 10 + + _MultiAgentPluginRegistry(mock_orchestrator).add_and_init(plugin) + assert plugin.call_count == 110 + + +def test_dual_plugin_tools_only_for_agent(mock_agent, mock_orchestrator): + from strands.tools.decorator import tool + + class DualPlugin(Plugin, MultiAgentPlugin): + name = "dual-plugin" + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + return param + + _PluginRegistry(mock_agent).add_and_init(DualPlugin()) + mock_agent.tool_registry.process_tools.assert_called_once() + + # Orchestrator has no tool registration + _MultiAgentPluginRegistry(mock_orchestrator).add_and_init(DualPlugin()) From 7aae080976e6825bb0605651db340955d6aca7a6 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Mon, 11 May 2026 10:42:47 -0400 Subject: [PATCH 2/9] Address self-review --- AGENTS.md | 5 ++++- src/strands/__init__.py | 3 ++- src/strands/plugins/multiagent_registry.py | 19 +++++++++++++++---- 3 files changed, 21 insertions(+), 6 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 0b877ea98..daddbbb2d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -130,8 +130,11 @@ strands-agents/ │ │ │ ├── plugins/ # Plugin system │ │ ├── plugin.py # Plugin base class +│ │ ├── multiagent_plugin.py # MultiAgentPlugin base class │ │ ├── decorator.py # @hook decorator -│ │ └── registry.py # PluginRegistry for tracking plugins +│ │ ├── registry.py # PluginRegistry for tracking agent plugins +│ │ ├── multiagent_registry.py # Registry for tracking orchestrator plugins +│ │ └── _discovery.py # Shared hook/tool discovery utilities │ │ │ ├── handlers/ # Event handlers │ │ └── callback_handler.py # Callback handling diff --git a/src/strands/__init__.py b/src/strands/__init__.py index 6625ac41f..00e32ead3 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -4,7 +4,7 @@ from .agent.agent import Agent from .agent.base import AgentBase from .event_loop._retry import ModelRetryStrategy -from .plugins import Plugin +from .plugins import MultiAgentPlugin, Plugin from .tools.decorator import tool from .types._snapshot import Snapshot from .types.tools import ToolContext @@ -17,6 +17,7 @@ "agent", "models", "ModelRetryStrategy", + "MultiAgentPlugin", "Plugin", "Skill", "Snapshot", diff --git a/src/strands/plugins/multiagent_registry.py b/src/strands/plugins/multiagent_registry.py index d2b00ceef..247bfd3a6 100644 --- a/src/strands/plugins/multiagent_registry.py +++ b/src/strands/plugins/multiagent_registry.py @@ -6,12 +6,15 @@ import logging import weakref -from typing import Any +from typing import TYPE_CHECKING from ..hooks.registry import HookRegistry from ._discovery import call_init_method, register_hooks from .multiagent_plugin import MultiAgentPlugin +if TYPE_CHECKING: + from ..multiagent.base import MultiAgentBase + logger = logging.getLogger(__name__) @@ -45,18 +48,26 @@ def init_multi_agent(self, orchestrator: MultiAgentBase) -> None: ``` """ - def __init__(self, orchestrator: Any) -> None: + def __init__(self, orchestrator: "MultiAgentBase") -> None: """Initialize a plugin registry with an orchestrator reference. Args: orchestrator: The orchestrator instance that plugins will be initialized with. Must have a ``hooks`` attribute of type ``HookRegistry``. + + Raises: + TypeError: If the orchestrator does not have a ``hooks`` attribute. """ + if not hasattr(orchestrator, "hooks"): + raise TypeError( + f"{type(orchestrator).__name__} does not have a 'hooks' attribute; " + "plugins require an orchestrator with a HookRegistry" + ) self._orchestrator_ref = weakref.ref(orchestrator) self._plugins: dict[str, MultiAgentPlugin] = {} @property - def _orchestrator(self) -> Any: + def _orchestrator(self) -> "MultiAgentBase": """Return the orchestrator, raising ReferenceError if it has been garbage collected.""" orchestrator = self._orchestrator_ref() if orchestrator is None: @@ -66,7 +77,7 @@ def _orchestrator(self) -> Any: @property def _hook_registry(self) -> HookRegistry: """Return the orchestrator's hook registry.""" - return self._orchestrator.hooks # type: ignore[no-any-return] + return self._orchestrator.hooks # type: ignore[attr-defined, no-any-return] def add_and_init(self, plugin: MultiAgentPlugin) -> None: """Add and initialize a plugin with the orchestrator. From 57d89d1683fc57a43ffd5a0af55a98801aa61f9b Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Mon, 11 May 2026 11:13:01 -0400 Subject: [PATCH 3/9] Move to kwarg --- src/strands/multiagent/graph.py | 4 ++-- src/strands/multiagent/swarm.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index c34554f78..8f30fe21a 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -440,9 +440,9 @@ def __init__( reset_on_revisit: bool = False, session_manager: SessionManager | None = None, hooks: list[HookProvider] | None = None, - plugins: list[MultiAgentPlugin] | None = None, id: str = _DEFAULT_GRAPH_ID, trace_attributes: Mapping[str, AttributeValue] | None = None, + plugins: list[MultiAgentPlugin] | None = None, ) -> None: """Initialize Graph with execution limits and reset behavior. @@ -456,9 +456,9 @@ def __init__( reset_on_revisit: Whether to reset node state when revisited (default: False) session_manager: Session manager for persisting graph state and execution history (default: None) hooks: List of hook providers for monitoring and extending graph execution behavior (default: None) - plugins: List of multi-agent plugins for extending graph behavior (default: None) id: Unique graph id (default: None) trace_attributes: Custom trace attributes to apply to the agent's trace span (default: None) + plugins: List of multi-agent plugins for extending graph behavior (default: None) """ super().__init__() diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 67fdefb70..508779836 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -249,9 +249,9 @@ def __init__( repetitive_handoff_min_unique_agents: int = 0, session_manager: SessionManager | None = None, hooks: list[HookProvider] | None = None, - plugins: list[MultiAgentPlugin] | None = None, id: str = _DEFAULT_SWARM_ID, trace_attributes: Mapping[str, AttributeValue] | None = None, + plugins: list[MultiAgentPlugin] | None = None, ) -> None: """Initialize Swarm with agents and configuration. @@ -269,8 +269,8 @@ def __init__( Disabled by default (default: 0) session_manager: Session manager for persisting graph state and execution history (default: None) hooks: List of hook providers for monitoring and extending graph execution behavior (default: None) - plugins: List of multi-agent plugins for extending swarm behavior (default: None) trace_attributes: Custom trace attributes to apply to the agent's trace span (default: None) + plugins: List of multi-agent plugins for extending swarm behavior (default: None) """ super().__init__() self.id = id From d413d5248461bc793076f6611e9b98de0ace3977 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Mon, 11 May 2026 11:44:13 -0400 Subject: [PATCH 4/9] fix: address review feedback - coverage gaps and param ordering --- src/strands/multiagent/graph.py | 2 +- .../strands/plugins/test_multiagent_plugin.py | 50 +++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 8f30fe21a..0ecfb8904 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -409,8 +409,8 @@ def build(self) -> "Graph": reset_on_revisit=self._reset_on_revisit, session_manager=self._session_manager, hooks=self._hooks, - plugins=self._plugins, id=self._id, + plugins=self._plugins, ) def _validate_graph(self) -> None: diff --git a/tests/strands/plugins/test_multiagent_plugin.py b/tests/strands/plugins/test_multiagent_plugin.py index d8bf06005..faa26798c 100644 --- a/tests/strands/plugins/test_multiagent_plugin.py +++ b/tests/strands/plugins/test_multiagent_plugin.py @@ -496,3 +496,53 @@ def my_tool(self, param: str) -> str: # Orchestrator has no tool registration _MultiAgentPluginRegistry(mock_orchestrator).add_and_init(DualPlugin()) + + +# --- TypeError guard tests --- + + +def test_registry_raises_type_error_if_orchestrator_missing_hooks(): + """Test that _MultiAgentPluginRegistry raises TypeError if orchestrator has no hooks attribute.""" + orchestrator_without_hooks = unittest.mock.MagicMock(spec=[]) # spec=[] means no attributes + + with pytest.raises(TypeError, match="does not have a 'hooks' attribute"): + _MultiAgentPluginRegistry(orchestrator_without_hooks) + + +# --- Double-discovery guard tests --- + + +def test_dual_plugin_hasattr_guard_prevents_double_discovery(): + """Test that the hasattr guard in __init__ prevents hooks from being discovered twice.""" + + class DualPlugin(Plugin, MultiAgentPlugin): + name = "dual-plugin" + + @hook + def shared_hook(self, event: BeforeNodeCallEvent): + pass + + plugin = DualPlugin() + # If double-discovery occurred, we'd see 2 hooks instead of 1 + assert len(plugin.hooks) == 1 + assert plugin.hooks[0].__name__ == "shared_hook" + + +def test_multiagent_plugin_hasattr_guard_with_pre_set_hooks(): + """Test that MultiAgentPlugin.__init__ skips discovery if _hooks already set.""" + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + def __init__(self): + # Pre-set _hooks before super().__init__ + self._hooks = [] + super().__init__() + + @hook + def should_not_be_discovered(self, event: BeforeNodeCallEvent): + pass + + plugin = MyPlugin() + # The guard should have skipped discovery since _hooks was already set + assert len(plugin.hooks) == 0 From 9cc15af4175583b2281465b3c4e9e968af51b188 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Mon, 11 May 2026 12:24:18 -0400 Subject: [PATCH 5/9] feat: add add_hook to Graph, Swarm, and MultiAgentBase --- src/strands/multiagent/base.py | 15 ++++++ src/strands/multiagent/graph.py | 13 ++++- src/strands/multiagent/swarm.py | 13 ++++- src/strands/plugins/multiagent_registry.py | 32 +++++++---- .../multiagent/test_multiagent_plugins.py | 53 +++++++++++++++++++ 5 files changed, 115 insertions(+), 11 deletions(-) diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index dc3258f68..3e38e1cd4 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -13,6 +13,7 @@ from .._async import run_async from ..agent import AgentResult +from ..hooks.registry import HookCallback from ..interrupt import Interrupt from ..types.event_loop import Metrics, Usage from ..types.multiagent import MultiAgentInput @@ -254,6 +255,20 @@ def deserialize_state(self, payload: dict[str, Any]) -> None: """Restore orchestrator state from a session dict.""" raise NotImplementedError + def add_hook(self, callback: HookCallback, event_type: type | list[type] | None = None) -> None: + """Register a hook callback with the orchestrator. + + Subclasses that support hooks should override this method to register + the callback with their hook registry. + + Args: + callback: The callback function to invoke when events of this type occur. + event_type: The class type(s) of events this callback should handle. + Can be a single type, a list of types, or None to infer from + the callback's first parameter type hint. + """ + raise NotImplementedError + def _parse_trace_attributes( self, attributes: Mapping[str, AttributeValue] | None = None ) -> dict[str, AttributeValue]: diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 0ecfb8904..146a31563 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -35,7 +35,7 @@ BeforeNodeCallEvent, MultiAgentInitializedEvent, ) -from ..hooks.registry import HookProvider, HookRegistry +from ..hooks.registry import HookCallback, HookProvider, HookRegistry from ..interrupt import Interrupt, _InterruptState from ..plugins.multiagent_plugin import MultiAgentPlugin from ..plugins.multiagent_registry import _MultiAgentPluginRegistry @@ -495,6 +495,17 @@ def __init__( run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self))) + def add_hook(self, callback: HookCallback, event_type: type | list[type] | None = None) -> None: + """Register a hook callback with the graph. + + Args: + callback: The callback function to invoke when events of this type occur. + event_type: The class type(s) of events this callback should handle. + Can be a single type, a list of types, or None to infer from + the callback's first parameter type hint. + """ + self.hooks.add_callback(event_type, callback) + def __call__( self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any ) -> GraphResult: diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 508779836..3193a810a 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -34,7 +34,7 @@ BeforeNodeCallEvent, MultiAgentInitializedEvent, ) -from ..hooks.registry import HookProvider, HookRegistry +from ..hooks.registry import HookCallback, HookProvider, HookRegistry from ..interrupt import Interrupt, _InterruptState from ..plugins.multiagent_plugin import MultiAgentPlugin from ..plugins.multiagent_registry import _MultiAgentPluginRegistry @@ -314,6 +314,17 @@ def __init__( self._inject_swarm_tools() run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self))) + def add_hook(self, callback: HookCallback, event_type: type | list[type] | None = None) -> None: + """Register a hook callback with the swarm. + + Args: + callback: The callback function to invoke when events of this type occur. + event_type: The class type(s) of events this callback should handle. + Can be a single type, a list of types, or None to infer from + the callback's first parameter type hint. + """ + self.hooks.add_callback(event_type, callback) + def __call__( self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any ) -> SwarmResult: diff --git a/src/strands/plugins/multiagent_registry.py b/src/strands/plugins/multiagent_registry.py index 247bfd3a6..69de3d3d8 100644 --- a/src/strands/plugins/multiagent_registry.py +++ b/src/strands/plugins/multiagent_registry.py @@ -8,8 +8,7 @@ import weakref from typing import TYPE_CHECKING -from ..hooks.registry import HookRegistry -from ._discovery import call_init_method, register_hooks +from ._discovery import call_init_method from .multiagent_plugin import MultiAgentPlugin if TYPE_CHECKING: @@ -74,11 +73,6 @@ def _orchestrator(self) -> "MultiAgentBase": raise ReferenceError("Orchestrator has been garbage collected") return orchestrator - @property - def _hook_registry(self) -> HookRegistry: - """Return the orchestrator's hook registry.""" - return self._orchestrator.hooks # type: ignore[attr-defined, no-any-return] - def add_and_init(self, plugin: MultiAgentPlugin) -> None: """Add and initialize a plugin with the orchestrator. @@ -104,5 +98,25 @@ def add_and_init(self, plugin: MultiAgentPlugin) -> None: # Call user's init_multi_agent for custom initialization call_init_method(plugin.init_multi_agent, self._orchestrator) - # Auto-register discovered hooks with the orchestrator's hook registry - register_hooks(plugin.name, plugin.hooks, self._hook_registry) + # Auto-register discovered hooks with the orchestrator + self._register_hooks(plugin) + + def _register_hooks(self, plugin: MultiAgentPlugin) -> None: + """Register all discovered hooks from the plugin with the orchestrator. + + Uses orchestrator.add_hook() so that the orchestrator can track + registrations through its public API. + + Args: + plugin: The plugin whose hooks should be registered. + """ + for hook_callback in plugin.hooks: + event_types = getattr(hook_callback, "_hook_event_types", []) + for event_type in event_types: + self._orchestrator.add_hook(hook_callback, event_type) + logger.debug( + "plugin=<%s>, hook=<%s>, event_type=<%s> | registered hook", + plugin.name, + getattr(hook_callback, "__name__", repr(hook_callback)), + event_type.__name__, + ) diff --git a/tests/strands/multiagent/test_multiagent_plugins.py b/tests/strands/multiagent/test_multiagent_plugins.py index 2052c471e..85cc8d817 100644 --- a/tests/strands/multiagent/test_multiagent_plugins.py +++ b/tests/strands/multiagent/test_multiagent_plugins.py @@ -228,3 +228,56 @@ def init_multi_agent(self, orchestrator): assert init_called assert graph._plugin_registry is not None + + +# --- add_hook method tests --- + + +def test_swarm_add_hook_registers_callback(mock_swarm_agent): + """Test that Swarm.add_hook registers a callback directly.""" + events_received = [] + + def on_before_node(event: BeforeNodeCallEvent): + events_received.append(event) + + swarm = _make_swarm(mock_swarm_agent) + swarm.add_hook(on_before_node, BeforeNodeCallEvent) + + assert len(swarm.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + + +def test_graph_add_hook_registers_callback(mock_graph_agent): + """Test that Graph.add_hook registers a callback directly.""" + events_received = [] + + def on_before_node(event: BeforeNodeCallEvent): + events_received.append(event) + + graph = _make_graph(mock_graph_agent) + graph.add_hook(on_before_node, BeforeNodeCallEvent) + + assert len(graph.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + + +def test_swarm_add_hook_infers_event_type(mock_swarm_agent): + """Test that Swarm.add_hook can infer event type from type hint.""" + + def on_before_node(event: BeforeNodeCallEvent): + pass + + swarm = _make_swarm(mock_swarm_agent) + swarm.add_hook(on_before_node) + + assert len(swarm.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + + +def test_graph_add_hook_infers_event_type(mock_graph_agent): + """Test that Graph.add_hook can infer event type from type hint.""" + + def on_before_node(event: BeforeNodeCallEvent): + pass + + graph = _make_graph(mock_graph_agent) + graph.add_hook(on_before_node) + + assert len(graph.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 From 14ff42a1e3e9ee9fdbc8fe9148185f7007097390 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Mon, 11 May 2026 13:41:49 -0400 Subject: [PATCH 6/9] refactor: remove redundant hasattr check from plugin registry --- src/strands/plugins/multiagent_registry.py | 9 --------- tests/strands/plugins/test_multiagent_plugin.py | 11 ----------- 2 files changed, 20 deletions(-) diff --git a/src/strands/plugins/multiagent_registry.py b/src/strands/plugins/multiagent_registry.py index 69de3d3d8..365c8f9c5 100644 --- a/src/strands/plugins/multiagent_registry.py +++ b/src/strands/plugins/multiagent_registry.py @@ -52,16 +52,7 @@ def __init__(self, orchestrator: "MultiAgentBase") -> None: Args: orchestrator: The orchestrator instance that plugins will be initialized with. - Must have a ``hooks`` attribute of type ``HookRegistry``. - - Raises: - TypeError: If the orchestrator does not have a ``hooks`` attribute. """ - if not hasattr(orchestrator, "hooks"): - raise TypeError( - f"{type(orchestrator).__name__} does not have a 'hooks' attribute; " - "plugins require an orchestrator with a HookRegistry" - ) self._orchestrator_ref = weakref.ref(orchestrator) self._plugins: dict[str, MultiAgentPlugin] = {} diff --git a/tests/strands/plugins/test_multiagent_plugin.py b/tests/strands/plugins/test_multiagent_plugin.py index faa26798c..a507c85cf 100644 --- a/tests/strands/plugins/test_multiagent_plugin.py +++ b/tests/strands/plugins/test_multiagent_plugin.py @@ -498,17 +498,6 @@ def my_tool(self, param: str) -> str: _MultiAgentPluginRegistry(mock_orchestrator).add_and_init(DualPlugin()) -# --- TypeError guard tests --- - - -def test_registry_raises_type_error_if_orchestrator_missing_hooks(): - """Test that _MultiAgentPluginRegistry raises TypeError if orchestrator has no hooks attribute.""" - orchestrator_without_hooks = unittest.mock.MagicMock(spec=[]) # spec=[] means no attributes - - with pytest.raises(TypeError, match="does not have a 'hooks' attribute"): - _MultiAgentPluginRegistry(orchestrator_without_hooks) - - # --- Double-discovery guard tests --- From 51613f3bd5bd3b6bfbfcf8f59e6bc360fd611adc Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Tue, 12 May 2026 09:19:55 -0400 Subject: [PATCH 7/9] refactor: remove dead register_hooks helper and improve add_hook error message --- src/strands/multiagent/base.py | 2 +- src/strands/plugins/_discovery.py | 20 +------------------- 2 files changed, 2 insertions(+), 20 deletions(-) diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 3e38e1cd4..14c4d0d14 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -267,7 +267,7 @@ def add_hook(self, callback: HookCallback, event_type: type | list[type] | None Can be a single type, a list of types, or None to infer from the callback's first parameter type hint. """ - raise NotImplementedError + raise NotImplementedError(f"{type(self).__name__} must implement add_hook() to support plugins") def _parse_trace_attributes( self, attributes: Mapping[str, AttributeValue] | None = None diff --git a/src/strands/plugins/_discovery.py b/src/strands/plugins/_discovery.py index 021176fe2..e6d21c8bd 100644 --- a/src/strands/plugins/_discovery.py +++ b/src/strands/plugins/_discovery.py @@ -11,7 +11,7 @@ from typing import Any, cast from .._async import run_async -from ..hooks.registry import HookCallback, HookRegistry +from ..hooks.registry import HookCallback from ..tools.decorator import DecoratedFunctionTool logger = logging.getLogger(__name__) @@ -99,21 +99,3 @@ def call_init_method(init_method: Callable[..., Any], target: Any) -> None: init_method(target) -def register_hooks(plugin_name: str, hooks: list[HookCallback], registry: HookRegistry) -> None: - """Register discovered hook callbacks with a hook registry. - - Args: - plugin_name: The plugin name (used for debug logging). - hooks: List of hook callbacks to register. - registry: The HookRegistry to register callbacks with. - """ - for hook_callback in hooks: - event_types = getattr(hook_callback, "_hook_event_types", []) - for event_type in event_types: - registry.add_callback(event_type, hook_callback) - logger.debug( - "plugin=<%s>, hook=<%s>, event_type=<%s> | registered hook", - plugin_name, - getattr(hook_callback, "__name__", repr(hook_callback)), - event_type.__name__, - ) From 9cff053a3450ae7edc0d3fa5ab2f46f1678a7ac3 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Wed, 13 May 2026 12:05:18 -0400 Subject: [PATCH 8/9] refactor: DRY up discovery helpers and add guard test - Extract _discover_methods generic helper in _discovery.py to eliminate duplication between discover_hooks and discover_tools - Add test_dual_plugin_discover_hooks_called_once to verify the hasattr guard prevents double-discovery in dual inheritance --- src/strands/plugins/_discovery.py | 70 ++++++++++--------- .../strands/plugins/test_multiagent_plugin.py | 26 +++++++ 2 files changed, 63 insertions(+), 33 deletions(-) diff --git a/src/strands/plugins/_discovery.py b/src/strands/plugins/_discovery.py index e6d21c8bd..646cd1c1a 100644 --- a/src/strands/plugins/_discovery.py +++ b/src/strands/plugins/_discovery.py @@ -8,7 +8,7 @@ import inspect import logging from collections.abc import Awaitable, Callable -from typing import Any, cast +from typing import Any, TypeVar, cast from .._async import run_async from ..hooks.registry import HookCallback @@ -16,21 +16,25 @@ logger = logging.getLogger(__name__) +T = TypeVar("T") -def discover_hooks(instance: object, plugin_name: str) -> list[HookCallback]: - """Scan an instance's class hierarchy for @hook decorated methods. - Walks the MRO in reverse so parent class hooks come first, but child +def _discover_methods(instance: object, plugin_name: str, predicate: Callable[[object], bool], label: str) -> list[Any]: + """Scan an instance's class hierarchy for methods matching a predicate. + + Walks the MRO in reverse so parent class methods come first, but child overrides win (only the child's version is included). Args: instance: The plugin instance to scan. plugin_name: The plugin name (used for debug logging). + predicate: Function that returns True for attributes to collect. + label: Label for debug logging (e.g., "hook", "tool"). Returns: - List of bound hook callback methods in declaration order. + List of matching bound methods/descriptors in declaration order. """ - hooks: list[HookCallback] = [] + results: list[Any] = [] seen: set[str] = set() for cls in reversed(type(instance).__mro__): @@ -44,45 +48,47 @@ def discover_hooks(instance: object, plugin_name: str) -> list[HookCallback]: except Exception: continue - if hasattr(bound, "_hook_event_types") and callable(bound): - hooks.append(bound) - logger.debug("plugin=<%s>, hook=<%s> | discovered hook method", plugin_name, attr_name) + if predicate(bound): + results.append(bound) + logger.debug("plugin=<%s>, %s=<%s> | discovered", plugin_name, label, attr_name) - return hooks + return results -def discover_tools(instance: object, plugin_name: str) -> list[DecoratedFunctionTool]: - """Scan an instance's class hierarchy for @tool decorated methods. - - Walks the MRO in reverse so parent class tools come first, but child - overrides win (only the child's version is included). +def discover_hooks(instance: object, plugin_name: str) -> list[HookCallback]: + """Scan an instance's class hierarchy for @hook decorated methods. Args: instance: The plugin instance to scan. plugin_name: The plugin name (used for debug logging). Returns: - List of DecoratedFunctionTool instances in declaration order. + List of bound hook callback methods in declaration order. """ - tools: list[DecoratedFunctionTool] = [] - seen: set[str] = set() + return _discover_methods( + instance, + plugin_name, + predicate=lambda bound: hasattr(bound, "_hook_event_types") and callable(bound), + label="hook", + ) - for cls in reversed(type(instance).__mro__): - for attr_name in cls.__dict__: - if attr_name in seen: - continue - seen.add(attr_name) - try: - bound = getattr(instance, attr_name) - except Exception: - continue +def discover_tools(instance: object, plugin_name: str) -> list[DecoratedFunctionTool]: + """Scan an instance's class hierarchy for @tool decorated methods. - if isinstance(bound, DecoratedFunctionTool): - tools.append(bound) - logger.debug("plugin=<%s>, tool=<%s> | discovered tool method", plugin_name, attr_name) + Args: + instance: The plugin instance to scan. + plugin_name: The plugin name (used for debug logging). - return tools + Returns: + List of DecoratedFunctionTool instances in declaration order. + """ + return _discover_methods( + instance, + plugin_name, + predicate=lambda bound: isinstance(bound, DecoratedFunctionTool), + label="tool", + ) def call_init_method(init_method: Callable[..., Any], target: Any) -> None: @@ -97,5 +103,3 @@ def call_init_method(init_method: Callable[..., Any], target: Any) -> None: run_async(lambda: async_init(target)) else: init_method(target) - - diff --git a/tests/strands/plugins/test_multiagent_plugin.py b/tests/strands/plugins/test_multiagent_plugin.py index a507c85cf..b7e16c9eb 100644 --- a/tests/strands/plugins/test_multiagent_plugin.py +++ b/tests/strands/plugins/test_multiagent_plugin.py @@ -409,6 +409,32 @@ def on_before_node(self, event: BeforeNodeCallEvent): assert len(DualPlugin().hooks) == 1 +def test_dual_plugin_discover_hooks_called_once(monkeypatch): + """Verify the hasattr guard prevents discover_hooks from running twice in dual inheritance.""" + import strands.plugins.plugin as plugin_mod + + call_count = 0 + original = plugin_mod.discover_hooks + + def counting_discover_hooks(instance, plugin_name): + nonlocal call_count + call_count += 1 + return original(instance, plugin_name) + + monkeypatch.setattr(plugin_mod, "discover_hooks", counting_discover_hooks) + + class DualPlugin(Plugin, MultiAgentPlugin): + name = "dual-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + DualPlugin() + # Plugin.__init__ calls discover_hooks once; MultiAgentPlugin.__init__ skips due to hasattr guard + assert call_count == 1 + + def test_dual_plugin_has_both_init_methods(mock_agent, mock_orchestrator): agent_init_called = False multi_agent_init_called = False From 1d79ed061e5481d5c88ebe49cb1806bd3703c6dc Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Wed, 13 May 2026 14:01:36 -0400 Subject: [PATCH 9/9] fix: remove unused TypeVar import and variable --- src/strands/plugins/_discovery.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/strands/plugins/_discovery.py b/src/strands/plugins/_discovery.py index 646cd1c1a..eda955030 100644 --- a/src/strands/plugins/_discovery.py +++ b/src/strands/plugins/_discovery.py @@ -8,7 +8,7 @@ import inspect import logging from collections.abc import Awaitable, Callable -from typing import Any, TypeVar, cast +from typing import Any, cast from .._async import run_async from ..hooks.registry import HookCallback @@ -16,8 +16,6 @@ logger = logging.getLogger(__name__) -T = TypeVar("T") - def _discover_methods(instance: object, plugin_name: str, predicate: Callable[[object], bool], label: str) -> list[Any]: """Scan an instance's class hierarchy for methods matching a predicate.