Part 8: A2A (Agent-to-Agent) Protocol Implementation

8 June 2025 · netologist · 8 min, 1645 words ·

Why A2A Communication?

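Agent-to-agent communication enables agents to delegate tasks to peers with the right capabilities, share knowledge, and collaborate on work that no single agent handles alone.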

MQTT-Based A2A Protocol

# a2a/protocol.py
import asyncio
import json
import logging
from typing import Dict, Any, Optional, Callable, List
from datetime import datetime, timezone
import uuid
from dataclasses import dataclass, asdict
from enum import Enum
import paho.mqtt.client as mqtt
from config import config

class MessageType(Enum):
    """Kind of an A2A message.

    The value is serialized into the JSON envelope and is also used as the
    last segment of the MQTT topic the message is published on.
    """
    REQUEST = "request"      # expects a RESPONSE carrying the same correlation_id
    RESPONSE = "response"    # reply to a REQUEST, matched via correlation_id
    BROADCAST = "broadcast"  # sent with recipient_id=None to the broadcast topic
    HEARTBEAT = "heartbeat"  # periodic liveness message, also sent as a broadcast
    ERROR = "error"          # not produced in this module; routed to message_handlers

class MessagePriority(Enum):
    """Priority attached to an A2A message, serialized by its integer value.

    Responses mirror the priority of the request they answer. Note: publishing
    in this module does not currently vary by priority.
    """
    LOW = 1
    NORMAL = 2
    HIGH = 3
    URGENT = 4

@dataclass
class A2AMessage:
    """A single A2A envelope exchanged over MQTT.

    Datetimes travel as ISO-8601 strings and enums by their ``value`` so that
    ``from_json(msg.to_json())`` reconstructs an equivalent message.
    """
    id: str
    sender_id: str
    recipient_id: Optional[str]  # None for broadcast
    message_type: MessageType
    priority: MessagePriority
    timestamp: datetime
    payload: Dict[str, Any]
    correlation_id: Optional[str] = None  # For request-response pairs
    expires_at: Optional[datetime] = None

    def to_json(self) -> str:
        """Serialize this message to a JSON string."""
        # Overriding keys that already exist keeps asdict()'s field order.
        encoded = {
            **asdict(self),
            'timestamp': self.timestamp.isoformat(),
            'message_type': self.message_type.value,
            'priority': self.priority.value,
        }
        if self.expires_at:
            encoded['expires_at'] = self.expires_at.isoformat()
        return json.dumps(encoded)

    @classmethod
    def from_json(cls, json_str: str) -> 'A2AMessage':
        """Rebuild a message from a JSON string produced by ``to_json``."""
        decoded = json.loads(json_str)
        decoded['timestamp'] = datetime.fromisoformat(decoded['timestamp'])
        decoded['message_type'] = MessageType(decoded['message_type'])
        decoded['priority'] = MessagePriority(decoded['priority'])
        raw_expiry = decoded.get('expires_at')
        if raw_expiry:
            decoded['expires_at'] = datetime.fromisoformat(raw_expiry)
        return cls(**decoded)

class A2AProtocol:
    """MQTT transport for agent-to-agent (A2A) messaging.

    Topic layout:
      * direct messages:    ``a2a/{recipient_id}/{message_type}``
      * broadcast messages: ``a2a/broadcast/{message_type}``

    paho-mqtt runs its callbacks on the network thread started by
    ``loop_start()``, not on the asyncio event loop. Any coroutine triggered
    from a callback is therefore handed to the loop captured in
    :meth:`connect` via ``asyncio.run_coroutine_threadsafe``.
    """

    def __init__(self, agent_id: str):
        self.agent_id = agent_id
        self.client = self._create_client()
        self.is_connected = False

        # Message handlers
        self.message_handlers: Dict[MessageType, Callable] = {}
        self.request_handlers: Dict[str, Callable] = {}

        # Pending requests: correlation_id -> future resolved by the response
        self.pending_requests: Dict[str, asyncio.Future] = {}

        # Known agents: agent_id -> announced info plus "last_seen"
        self.known_agents: Dict[str, Dict[str, Any]] = {}

        # Event loop that owns this protocol (set in connect()).
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        # Non-zero rc reported by on_connect when the broker refused us.
        self._connect_rc: Optional[int] = None
        # Running heartbeat task, tracked so reconnects don't stack loops.
        self._heartbeat_task: Optional[asyncio.Task] = None

        # Topic structure: a2a/{recipient_id}/{message_type}
        # Broadcast topic: a2a/broadcast/{message_type}
        self.base_topic = "a2a"

        # Setup MQTT client
        self._setup_mqtt_client()

    @staticmethod
    def _create_client() -> "mqtt.Client":
        """Create a paho client that uses the classic (v1) callback signatures.

        paho-mqtt 2.x requires an explicit callback API version as the first
        argument, while 1.x has no such parameter.
        """
        callback_api = getattr(mqtt, "CallbackAPIVersion", None)
        if callback_api is not None:
            return mqtt.Client(callback_api.VERSION1)
        return mqtt.Client()

    def _setup_mqtt_client(self):
        """Attach paho callbacks (invoked on paho's network thread)."""

        def on_connect(client, userdata, flags, rc):
            if rc != 0:
                # Remembered so connect() can fail instead of waiting forever.
                self._connect_rc = rc
                logging.error(f"Failed to connect to MQTT broker: {rc}")
                return

            self._connect_rc = None
            self.is_connected = True
            logging.info(f"Agent {self.agent_id} connected to MQTT broker")

            # Subscribe to our topics (repeated on every reconnect)
            topics = [
                f"{self.base_topic}/{self.agent_id}/+",  # Direct messages
                f"{self.base_topic}/broadcast/+",        # Broadcast messages
            ]
            for topic in topics:
                client.subscribe(topic)
                logging.info(f"Subscribed to {topic}")

            # (Re)start the heartbeat on the asyncio loop
            self._schedule(self._start_heartbeat())

        def on_message(client, userdata, msg):
            try:
                message = A2AMessage.from_json(msg.payload.decode())
            except Exception as e:
                logging.error(f"Failed to process A2A message: {e}")
                return
            self._schedule(self._handle_message(message))

        def on_disconnect(client, userdata, rc):
            self.is_connected = False
            logging.info(f"Agent {self.agent_id} disconnected from MQTT broker")

        self.client.on_connect = on_connect
        self.client.on_message = on_message
        self.client.on_disconnect = on_disconnect

    def _schedule(self, coro) -> None:
        """Run *coro* on this protocol's event loop; safe to call from any thread."""
        loop = self._loop
        if loop is None or loop.is_closed():
            coro.close()  # avoid "coroutine was never awaited" warnings
            logging.warning("Dropping A2A work: no running event loop")
            return
        try:
            future = asyncio.run_coroutine_threadsafe(coro, loop)
        except RuntimeError:
            # The loop was closed between the check above and the call.
            coro.close()
            logging.warning("Dropping A2A work: event loop is closed")
            return
        future.add_done_callback(self._log_background_failure)

    @staticmethod
    def _log_background_failure(future) -> None:
        """Log exceptions from scheduled work that nobody awaits."""
        if future.cancelled():
            return
        exc = future.exception()
        if exc is not None:
            logging.error(f"A2A background task failed: {exc!r}")

    async def _start_heartbeat(self):
        """Start the heartbeat loop, cancelling a previous one if still running."""
        previous = self._heartbeat_task
        if previous is not None and not previous.done():
            previous.cancel()
        self._heartbeat_task = asyncio.create_task(self._send_heartbeat())

    async def connect(self, timeout: Optional[float] = 30.0):
        """Connect to the MQTT broker and wait until the session is established.

        Args:
            timeout: Seconds to wait for the broker to accept the connection.
                ``None`` waits indefinitely.

        Raises:
            RuntimeError: If the broker refuses the connection or *timeout*
                expires.
            Exception: Errors from ``client.connect`` are logged and re-raised.
        """
        # Must be captured before any paho callback can fire.
        self._loop = asyncio.get_running_loop()
        self._connect_rc = None
        try:
            self.client.connect(config.MQTT_BROKER, config.MQTT_PORT, 60)
            self.client.loop_start()

            deadline = None if timeout is None else self._loop.time() + timeout
            # Wait for connection
            while not self.is_connected:
                if self._connect_rc is not None:
                    raise RuntimeError(
                        f"MQTT broker refused connection: rc={self._connect_rc}"
                    )
                if deadline is not None and self._loop.time() >= deadline:
                    raise RuntimeError("Timed out waiting for MQTT broker connection")
                await asyncio.sleep(0.1)

        except Exception as e:
            logging.error(f"Failed to connect to MQTT broker: {e}")
            # Stop paho's retrying network thread; harmless if never started.
            self.client.loop_stop()
            raise

    async def disconnect(self):
        """Disconnect from the broker and stop background activity.

        ``disconnect()`` is issued before ``loop_stop()`` so paho's network
        thread can still send the DISCONNECT packet and run ``on_disconnect``.
        Requests still waiting for a response fail with ``RuntimeError``.
        """
        heartbeat = self._heartbeat_task
        self._heartbeat_task = None
        if heartbeat is not None and not heartbeat.done():
            heartbeat.cancel()

        self.client.disconnect()
        self.client.loop_stop()
        # on_disconnect does not fire if the network loop never ran.
        self.is_connected = False

        pending = list(self.pending_requests.values())
        self.pending_requests.clear()
        for future in pending:
            if not future.done():
                future.set_exception(RuntimeError("Disconnected from MQTT broker"))

    def register_message_handler(self, message_type: MessageType, handler: Callable):
        """Register an async handler for a message type not handled internally."""
        self.message_handlers[message_type] = handler

    def register_request_handler(self, request_type: str, handler: Callable):
        """Register an async handler for a request type.

        Keys of the form ``"broadcast_<type>"`` are used for broadcasts.
        """
        self.request_handlers[request_type] = handler

    async def send_message(self, message: A2AMessage):
        """Publish *message* on its direct or broadcast topic.

        Raises:
            RuntimeError: If not connected or if paho rejects the publish.
        """
        if not self.is_connected:
            raise RuntimeError("Not connected to MQTT broker")

        # Determine topic
        if message.recipient_id:
            topic = f"{self.base_topic}/{message.recipient_id}/{message.message_type.value}"
        else:
            topic = f"{self.base_topic}/broadcast/{message.message_type.value}"

        # Publish message
        result = self.client.publish(topic, message.to_json())

        if result.rc != mqtt.MQTT_ERR_SUCCESS:
            raise RuntimeError(f"Failed to send message: {result.rc}")

        logging.debug(f"Sent message {message.id} to {topic}")

    async def send_request(
        self,
        recipient_id: str,
        request_type: str,
        payload: Dict[str, Any],
        timeout: float = 30.0
    ) -> Dict[str, Any]:
        """Send a request to *recipient_id* and wait for its response payload.

        Raises:
            RuntimeError: On timeout, send failure, or disconnect while waiting.
        """
        # Create request message
        message = A2AMessage(
            id=str(uuid.uuid4()),
            sender_id=self.agent_id,
            recipient_id=recipient_id,
            message_type=MessageType.REQUEST,
            priority=MessagePriority.NORMAL,
            timestamp=datetime.now(timezone.utc),
            payload={
                "request_type": request_type,
                **payload
            },
            correlation_id=str(uuid.uuid4())
        )

        # Future resolved by _handle_response when the matching reply arrives
        response_future = asyncio.get_running_loop().create_future()
        self.pending_requests[message.correlation_id] = response_future

        try:
            await self.send_message(message)
            return await asyncio.wait_for(response_future, timeout=timeout)
        except asyncio.TimeoutError:
            raise RuntimeError(f"Request timeout: {request_type}")
        finally:
            # Always drop the pending entry, whatever the outcome.
            self.pending_requests.pop(message.correlation_id, None)

    async def send_response(
        self,
        request_message: A2AMessage,
        response_payload: Dict[str, Any]
    ):
        """Reply to *request_message*, reusing its correlation_id and priority."""
        response = A2AMessage(
            id=str(uuid.uuid4()),
            sender_id=self.agent_id,
            recipient_id=request_message.sender_id,
            message_type=MessageType.RESPONSE,
            priority=request_message.priority,
            timestamp=datetime.now(timezone.utc),
            payload=response_payload,
            correlation_id=request_message.correlation_id
        )

        await self.send_message(response)

    async def broadcast_message(
        self,
        message_type: str,
        payload: Dict[str, Any],
        priority: MessagePriority = MessagePriority.NORMAL
    ):
        """Broadcast *payload* to all agents, tagged with ``broadcast_type``."""
        message = A2AMessage(
            id=str(uuid.uuid4()),
            sender_id=self.agent_id,
            recipient_id=None,
            message_type=MessageType.BROADCAST,
            priority=priority,
            timestamp=datetime.now(timezone.utc),
            payload={
                "broadcast_type": message_type,
                **payload
            }
        )

        await self.send_message(message)

    async def _handle_message(self, message: A2AMessage):
        """Route an incoming message to the matching handler.

        RESPONSE, HEARTBEAT, REQUEST and BROADCAST are handled internally.
        Any other type (currently only ERROR) goes to ``message_handlers``.
        """
        # Check if message is expired
        if message.expires_at and datetime.now(timezone.utc) > message.expires_at:
            logging.warning(f"Received expired message {message.id}")
            return

        # We subscribe to the broadcast topic, so the broker echoes our own
        # broadcasts and heartbeats back; ignore them so we never "discover"
        # ourselves.
        if message.recipient_id is None and message.sender_id == self.agent_id:
            return

        builtin_handlers = {
            MessageType.RESPONSE: self._handle_response,
            MessageType.HEARTBEAT: self._handle_heartbeat,
            MessageType.REQUEST: self._handle_request,
            MessageType.BROADCAST: self._handle_broadcast,
        }
        builtin = builtin_handlers.get(message.message_type)
        if builtin is not None:
            await builtin(message)
            return

        # Handle other message types
        custom = self.message_handlers.get(message.message_type)
        if custom is not None:
            try:
                await custom(message)
            except Exception as e:
                logging.error(f"Message handler failed: {e}")

    async def _handle_response(self, message: A2AMessage):
        """Resolve the pending request future matching the response's correlation_id."""
        future = self.pending_requests.pop(message.correlation_id, None)
        if future is not None and not future.done():
            future.set_result(message.payload)

    async def _handle_request(self, message: A2AMessage):
        """Run the registered request handler and reply with its result or an error."""
        request_type = message.payload.get("request_type")
        handler = self.request_handlers.get(request_type)

        if handler is None:
            logging.warning(f"No handler for request type: {request_type}")
            # Reply anyway so the requester fails fast instead of timing out.
            await self.send_response(message, {
                "error": f"No handler for request type: {request_type}",
                "request_type": request_type
            })
            return

        try:
            # Call request handler
            response_payload = await handler(message.payload)
        except Exception as e:
            logging.error(f"Request handler failed: {e}")

            # Send error response
            error_payload = {
                "error": str(e),
                "request_type": request_type
            }
            await self.send_response(message, error_payload)
            return

        await self.send_response(message, response_payload)

    async def _handle_broadcast(self, message: A2AMessage):
        """Record agent announcements and dispatch ``broadcast_<type>`` handlers."""
        broadcast_type = message.payload.get("broadcast_type")

        # Update known agents list
        if broadcast_type == "agent_announcement":
            agent_info = message.payload.get("agent_info", {})
            self.known_agents[message.sender_id] = {
                **agent_info,
                "last_seen": datetime.now(timezone.utc)
            }
            logging.info(f"Discovered agent: {message.sender_id}")

        # Handle other broadcast types with registered handlers
        handler_key = f"broadcast_{broadcast_type}"
        handler = self.request_handlers.get(handler_key)
        if handler is not None:
            try:
                await handler(message.payload)
            except Exception as e:
                logging.error(f"Broadcast handler failed: {e}")

    async def _handle_heartbeat(self, message: A2AMessage):
        """Refresh ``last_seen`` for the sender, registering it if unknown.

        Agents that announced themselves before we joined are otherwise never
        learned, because announcements are only sent once at startup.
        """
        now = datetime.now(timezone.utc)
        agent = self.known_agents.get(message.sender_id)
        if agent is None:
            self.known_agents[message.sender_id] = {
                "capabilities": message.payload.get("capabilities", []),
                "last_seen": now
            }
            logging.info(f"Discovered agent via heartbeat: {message.sender_id}")
        else:
            agent["last_seen"] = now

    async def _send_heartbeat(self):
        """Broadcast a heartbeat every 30 seconds while connected."""
        while self.is_connected:
            try:
                heartbeat = A2AMessage(
                    id=str(uuid.uuid4()),
                    sender_id=self.agent_id,
                    recipient_id=None,
                    message_type=MessageType.HEARTBEAT,
                    priority=MessagePriority.LOW,
                    timestamp=datetime.now(timezone.utc),
                    payload={
                        "status": "active",
                        "capabilities": await self._get_capabilities()
                    }
                )

                await self.send_message(heartbeat)
                await asyncio.sleep(30)  # Send heartbeat every 30 seconds

            except Exception as e:
                logging.error(f"Heartbeat failed: {e}")
                break

    async def _get_capabilities(self) -> List[str]:
        """Return the capability names this agent advertises."""
        return [
            "chat",
            "document_processing",
            "voice_interaction",
            "knowledge_base_search",
            "memory_management"
        ]

    async def announce_agent(self):
        """Broadcast an ``agent_announcement`` describing this agent."""
        await self.broadcast_message(
            "agent_announcement",
            {
                "agent_info": {
                    "name": "AlexAI",
                    "version": "1.0.0",
                    "capabilities": await self._get_capabilities(),
                    "description": "AI Personal Assistant"
                }
            },
            priority=MessagePriority.HIGH
        )

# Agent coordination system
class AgentCoordinator:
    """Coordinates multiple AI agents over an A2A protocol connection."""

    # Task types this coordinator can execute, mapped to the capability name
    # agents advertise for them (see A2AProtocol._get_capabilities). Without
    # this, "knowledge_search" never matches "knowledge_base_search".
    _TASK_CAPABILITIES: Dict[str, str] = {
        "document_analysis": "document_processing",
        "knowledge_search": "knowledge_base_search",
        "voice_processing": "voice_interaction",
    }

    # Seconds between agent-discovery broadcasts.
    DISCOVERY_INTERVAL = 60

    def __init__(self, agent_id: str):
        self.agent_id = agent_id
        self.protocol = A2AProtocol(agent_id)
        self.specialized_agents = {}
        # Background discovery task; a reference is kept so it is not
        # garbage-collected and so stop() can cancel it.
        self._discovery_task: Optional[asyncio.Task] = None

        # Register coordination handlers
        self._register_handlers()

    def _register_handlers(self):
        """Register coordination request handlers on the protocol."""

        # Task delegation
        self.protocol.register_request_handler(
            "delegate_task",
            self._handle_task_delegation
        )

        # Capability query
        self.protocol.register_request_handler(
            "query_capabilities",
            self._handle_capability_query
        )

        # Knowledge sharing
        self.protocol.register_request_handler(
            "share_knowledge",
            self._handle_knowledge_sharing
        )

        # Collaborative processing
        self.protocol.register_request_handler(
            "collaborative_processing",
            self._handle_collaborative_processing
        )

    async def start(self):
        """Connect, announce this agent, and start periodic discovery."""
        await self.protocol.connect()
        await self.protocol.announce_agent()

        # Start periodic agent discovery
        self._discovery_task = asyncio.create_task(self._discover_agents())

    async def stop(self):
        """Cancel periodic discovery and disconnect from the broker."""
        task = self._discovery_task
        self._discovery_task = None
        if task is not None and not task.done():
            task.cancel()
        await self.protocol.disconnect()

    async def _discover_agents(self):
        """Broadcast an ``agent_discovery`` message periodically while connected."""
        while self.protocol.is_connected:
            try:
                # Query for active agents
                await self.protocol.broadcast_message(
                    "agent_discovery",
                    {"requesting_capabilities": True}
                )
            except Exception as e:
                logging.error(f"Agent discovery failed: {e}")

            # Sleep even after a failure. broadcast_message may complete
            # without ever yielding, so retrying immediately would spin and
            # starve the event loop.
            await asyncio.sleep(self.DISCOVERY_INTERVAL)

    async def delegate_task(
        self,
        task_type: str,
        task_data: Dict[str, Any],
        preferred_agent: Optional[str] = None
    ) -> Dict[str, Any]:
        """Delegate a task to *preferred_agent* or the best known agent.

        Raises:
            RuntimeError: If no suitable agent is known, or if the request
                fails or times out.
        """

        # Find suitable agent
        target_agent = preferred_agent or await self._find_best_agent(task_type)

        if not target_agent:
            raise RuntimeError(f"No agent available for task type: {task_type}")

        # Send delegation request
        response = await self.protocol.send_request(
            recipient_id=target_agent,
            request_type="delegate_task",
            payload={
                "task_type": task_type,
                "task_data": task_data,
                "delegating_agent": self.agent_id
            }
        )

        return response

    async def _find_best_agent(self, task_type: str) -> Optional[str]:
        """Return the first other known agent able to handle *task_type*, or None.

        An agent qualifies if it advertises either *task_type* itself or the
        capability mapped to it in ``_TASK_CAPABILITIES``.
        """
        wanted = {task_type}
        mapped = self._TASK_CAPABILITIES.get(task_type)
        if mapped:
            wanted.add(mapped)

        # Simple capability matching
        for agent_id, agent_info in self.protocol.known_agents.items():
            if agent_id == self.agent_id:
                continue  # never delegate to ourselves over the network
            capabilities = agent_info.get("capabilities") or []
            if wanted.intersection(capabilities):
                return agent_id

        return None

    async def _handle_task_delegation(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Execute a delegated task.

        Raises:
            ValueError: If the task type is unsupported.
        """
        task_type = payload["task_type"]
        task_data = payload["task_data"]

        # Execute task based on type
        if task_type == "document_analysis":
            return await self._execute_document_analysis(task_data)
        elif task_type == "knowledge_search":
            return await self._execute_knowledge_search(task_data)
        elif task_type == "voice_processing":
            return await self._execute_voice_processing(task_data)
        else:
            raise ValueError(f"Unsupported task type: {task_type}")

    async def _execute_document_analysis(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
        """Run the MCP ``process_document`` tool on *task_data*."""
        # This would integrate with the MCP server
        from mcp.mcp_server import MCPServer

        mcp_server = MCPServer()
        result = await mcp_server._execute_tool("process_document", task_data)

        return {
            "status": "completed",
            "result": result,
            "agent_id": self.agent_id
        }

    async def _execute_knowledge_search(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
        """Search the knowledge base via the MCP ``knowledge_base`` tool."""
        from mcp.mcp_server import MCPServer

        mcp_server = MCPServer()
        result = await mcp_server._execute_tool("knowledge_base", {
            "action": "search",
            "query": task_data["query"]
        })

        return {
            "status": "completed",
            "result": result,
            "agent_id": self.agent_id
        }

    async def _execute_voice_processing(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
        """Run the MCP ``voice_chat`` tool on *task_data*."""
        from mcp.mcp_server import MCPServer

        mcp_server = MCPServer()
        result = await mcp_server._execute_tool("voice_chat", task_data)

        return {
            "status": "completed",
            "result": result,
            "agent_id": self.agent_id
        }

    async def _handle_capability_query(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Report this agent's capabilities and status."""
        return {
            "agent_id": self.agent_id,
            "capabilities": await self.protocol._get_capabilities(),
            "status": "active",
            "load": "normal"  # Could be dynamically calculated
        }

    async def _handle_knowledge_sharing(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Answer a knowledge-sharing request with a local knowledge-base search."""
        knowledge_type = payload.get("knowledge_type")
        query = payload.get("query")

        # Share relevant knowledge from our knowledge base
        from mcp.mcp_server import MCPServer

        mcp_server = MCPServer()
        search_result = await mcp_server._execute_tool("knowledge_base", {
            "action": "search",
            "query": query
        })

        return {
            "shared_knowledge": search_result,
            "source_agent": self.agent_id,
            "knowledge_type": knowledge_type
        }

    async def _handle_collaborative_processing(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Process this agent's share of a collaborative task (placeholder)."""
        task_id = payload.get("task_id")
        subtask = payload.get("subtask")

        # Process our part of the collaborative task
        # This is a simplified example

        return {
            "task_id": task_id,
            "subtask_result": f"Processed subtask: {subtask}",
            "processing_agent": self.agent_id,
            "status": "completed"
        }

# Usage example and integration
async def demo_a2a_protocol():
    """Run two coordinators side by side: delegate a task, then broadcast.

    Needs a reachable MQTT broker (the one configured in ``config``).
    """
    main_agent = AgentCoordinator("alexai-main")
    specialist = AgentCoordinator("alexai-specialist")
    coordinators = (main_agent, specialist)

    try:
        for coordinator in coordinators:
            await coordinator.start()

        # Give the agents a moment to learn about each other.
        await asyncio.sleep(2)

        search_request = {"query": "privacy-focused AI assistants"}
        result = await main_agent.delegate_task(
            task_type="knowledge_search",
            task_data=search_request,
        )
        print("Task delegation result:", result)

        update_notice = {
            "update_type": "capability_enhancement",
            "details": "Added new document processing capabilities",
        }
        await main_agent.protocol.broadcast_message("system_update", update_notice)

        await asyncio.sleep(1)
    finally:
        # Shut the agents down in the order they were started.
        for coordinator in coordinators:
            await coordinator.stop()


if __name__ == "__main__":
    asyncio.run(demo_a2a_protocol())

Why this A2A architecture?