Skip to content

Commit d304d0a

Browse files
committed
[AIT-258] feat: add Realtime mutable message support
- Updated `ConnectionManager` and `MessageQueue` to process `PublishResult` during acknowledgments (ACK/NACK). - Extended `send_protocol_message` to return `PublishResult` for publish tracking. - Bumped default `protocol_version` to 5. - Added tests for message update, delete, append operations, and PublishResult handling.
1 parent 393693c commit d304d0a

6 files changed

Lines changed: 535 additions & 20 deletions

File tree

ably/realtime/connectionmanager.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55
from collections import deque
66
from datetime import datetime
7+
from itertools import zip_longest
78
from typing import TYPE_CHECKING
89

910
import httpx
@@ -13,6 +14,7 @@
1314
from ably.types.connectiondetails import ConnectionDetails
1415
from ably.types.connectionerrors import ConnectionErrors
1516
from ably.types.connectionstate import ConnectionEvent, ConnectionState, ConnectionStateChange
17+
from ably.types.operations import PublishResult
1618
from ably.types.tokendetails import TokenDetails
1719
from ably.util.eventemitter import EventEmitter
1820
from ably.util.exceptions import AblyException, IncompatibleClientIdException
@@ -29,7 +31,7 @@ class PendingMessage:
2931

3032
def __init__(self, message: dict):
3133
self.message = message
32-
self.future: asyncio.Future | None = None
34+
self.future: asyncio.Future[PublishResult] | None = None
3335
action = message.get('action')
3436

3537
# Messages that require acknowledgment: MESSAGE, PRESENCE, ANNOTATION, OBJECT
@@ -58,15 +60,22 @@ def count(self) -> int:
5860
"""Return the number of pending messages"""
5961
return len(self.messages)
6062

61-
def complete_messages(self, serial: int, count: int, err: AblyException | None = None) -> None:
63+
def complete_messages(
64+
self,
65+
serial: int,
66+
count: int,
67+
res: list[PublishResult] | None,
68+
err: AblyException | None = None
69+
) -> None:
6270
"""Complete messages based on serial and count from ACK/NACK
6371
6472
Args:
6573
serial: The msgSerial of the first message being acknowledged
6674
count: The number of messages being acknowledged
75+
res: List of PublishResult objects for each message acknowledged, or None if not available
6776
err: Error from NACK, or None for successful ACK
6877
"""
69-
log.debug(f'MessageQueue.complete_messages(): serial={serial}, count={count}, err={err}')
78+
log.debug(f'MessageQueue.complete_messages(): serial={serial}, count={count}, res={res}, err={err}')
7079

7180
if not self.messages:
7281
log.warning('MessageQueue.complete_messages(): called on empty queue')
@@ -87,12 +96,17 @@ def complete_messages(self, serial: int, count: int, err: AblyException | None =
8796
completed_messages = self.messages[:num_to_complete]
8897
self.messages = self.messages[num_to_complete:]
8998

90-
for msg in completed_messages:
99+
# Default res to empty list if None
100+
res_list = res if res is not None else []
101+
for (msg, publish_result) in zip_longest(completed_messages, res_list):
91102
if msg.future and not msg.future.done():
92103
if err:
93104
msg.future.set_exception(err)
94105
else:
95-
msg.future.set_result(None)
106+
# If publish_result is None, return empty PublishResult
107+
if publish_result is None:
108+
publish_result = PublishResult()
109+
msg.future.set_result(publish_result)
96110

97111
def complete_all_messages(self, err: AblyException) -> None:
98112
"""Complete all pending messages with an error"""
@@ -199,7 +213,7 @@ async def close_impl(self) -> None:
199213

200214
self.notify_state(ConnectionState.CLOSED)
201215

202-
async def send_protocol_message(self, protocol_message: dict) -> None:
216+
async def send_protocol_message(self, protocol_message: dict) -> PublishResult | None:
203217
"""Send a protocol message and optionally track it for acknowledgment
204218
205219
Args:
@@ -233,12 +247,14 @@ async def send_protocol_message(self, protocol_message: dict) -> None:
233247
if state_should_queue:
234248
self.queued_messages.appendleft(pending_message)
235249
if pending_message.ack_required:
236-
await pending_message.future
250+
return await pending_message.future
237251
return None
238252

239253
return await self._send_protocol_message_on_connected_state(pending_message)
240254

241-
async def _send_protocol_message_on_connected_state(self, pending_message: PendingMessage) -> None:
255+
async def _send_protocol_message_on_connected_state(
256+
self, pending_message: PendingMessage
257+
) -> PublishResult | None:
242258
if self.state == ConnectionState.CONNECTED and self.transport:
243259
# Add to pending queue before sending (for messages being resent from queue)
244260
if pending_message.ack_required and pending_message not in self.pending_message_queue.messages:
@@ -253,7 +269,7 @@ async def _send_protocol_message_on_connected_state(self, pending_message: Pendi
253269
AblyException("No active transport", 500, 50000)
254270
)
255271
if pending_message.ack_required:
256-
await pending_message.future
272+
return await pending_message.future
257273
return None
258274

259275
def send_queued_messages(self) -> None:
@@ -449,15 +465,18 @@ def on_heartbeat(self, id: str | None) -> None:
449465
self.__ping_future.set_result(None)
450466
self.__ping_future = None
451467

452-
def on_ack(self, serial: int, count: int) -> None:
468+
def on_ack(
469+
self, serial: int, count: int, res: list[PublishResult] | None
470+
) -> None:
453471
"""Handle ACK protocol message from server
454472
455473
Args:
456474
serial: The msgSerial of the first message being acknowledged
457475
count: The number of messages being acknowledged
476+
res: List of PublishResult objects for each message acknowledged, or None if not available
458477
"""
459-
log.debug(f'ConnectionManager.on_ack(): serial={serial}, count={count}')
460-
self.pending_message_queue.complete_messages(serial, count)
478+
log.debug(f'ConnectionManager.on_ack(): serial={serial}, count={count}, res={res}')
479+
self.pending_message_queue.complete_messages(serial, count, res)
461480

462481
def on_nack(self, serial: int, count: int, err: AblyException | None) -> None:
463482
"""Handle NACK protocol message from server
@@ -471,7 +490,7 @@ def on_nack(self, serial: int, count: int, err: AblyException | None) -> None:
471490
err = AblyException('Unable to send message; channel not responding', 50001, 500)
472491

473492
log.error(f'ConnectionManager.on_nack(): serial={serial}, count={count}, err={err}')
474-
self.pending_message_queue.complete_messages(serial, count, err)
493+
self.pending_message_queue.complete_messages(serial, count, None, err)
475494

476495
def deactivate_transport(self, reason: AblyException | None = None):
477496
# RTN19a: Before disconnecting, requeue any pending messages

ably/realtime/realtime_channel.py

Lines changed: 207 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
from ably.transport.websockettransport import ProtocolMessageAction
1111
from ably.types.channelstate import ChannelState, ChannelStateChange
1212
from ably.types.flags import Flag, has_flag
13-
from ably.types.message import Message
13+
from ably.types.message import Message, MessageAction, MessageVersion
1414
from ably.types.mixins import DecodingContext
15+
from ably.types.operations import MessageOperation, PublishResult, UpdateDeleteResult
1516
from ably.types.presence import PresenceMessage
1617
from ably.util.eventemitter import EventEmitter
1718
from ably.util.exceptions import AblyException, IncompatibleClientIdException
@@ -390,7 +391,7 @@ def unsubscribe(self, *args) -> None:
390391
self.__message_emitter.off(listener)
391392

392393
# RTL6
393-
async def publish(self, *args, **kwargs) -> None:
394+
async def publish(self, *args, **kwargs) -> PublishResult:
394395
"""Publish a message or messages on this channel
395396
396397
Publishes a single message or an array of messages to the channel.
@@ -490,7 +491,7 @@ async def publish(self, *args, **kwargs) -> None:
490491
}
491492

492493
# RTL6b: Await acknowledgment from server
493-
await self.__realtime.connection.connection_manager.send_protocol_message(protocol_message)
494+
return await self.__realtime.connection.connection_manager.send_protocol_message(protocol_message)
494495

495496
def _throw_if_unpublishable_state(self) -> None:
496497
"""Check if the channel and connection are in a state that allows publishing
@@ -522,6 +523,200 @@ def _throw_if_unpublishable_state(self) -> None:
522523
90001,
523524
)
524525

526+
async def _send_update(self, message: Message, action: MessageAction,
527+
operation: MessageOperation = None) -> UpdateDeleteResult:
528+
"""Internal method to send update/delete/append operations via websocket.
529+
530+
Parameters
531+
----------
532+
message : Message
533+
Message object with serial field required
534+
action : MessageAction
535+
The action type (MESSAGE_UPDATE, MESSAGE_DELETE, MESSAGE_APPEND)
536+
operation : MessageOperation, optional
537+
Operation metadata (description, metadata)
538+
539+
Returns
540+
-------
541+
UpdateDeleteResult
542+
Result containing version serial of the operation
543+
544+
Raises
545+
------
546+
AblyException
547+
If message serial is missing or connection/channel state prevents operation
548+
"""
549+
# Check message has serial
550+
if not message.serial:
551+
raise AblyException(
552+
"Message serial is required for update/delete/append operations",
553+
400,
554+
40000
555+
)
556+
557+
# Check connection and channel state
558+
self._throw_if_unpublishable_state()
559+
560+
# Create version from operation if provided
561+
if not operation:
562+
version = None
563+
else:
564+
version = MessageVersion(
565+
client_id=operation.client_id,
566+
description=operation.description,
567+
metadata=operation.metadata
568+
)
569+
570+
# Create a new message with the operation fields
571+
update_message = Message(
572+
name=message.name,
573+
data=message.data,
574+
client_id=message.client_id,
575+
serial=message.serial,
576+
action=action,
577+
version=version,
578+
)
579+
580+
# Encrypt if needed
581+
if self.cipher:
582+
update_message.encrypt(self.cipher)
583+
584+
# Convert to dict representation
585+
msg_dict = update_message.as_dict(binary=self.ably.options.use_binary_protocol)
586+
587+
log.info(
588+
f'RealtimeChannel._send_update(): sending {action.name} message; '
589+
f'channel = {self.name}, state = {self.state}, serial = {message.serial}'
590+
)
591+
592+
# Send protocol message
593+
protocol_message = {
594+
"action": ProtocolMessageAction.MESSAGE,
595+
"channel": self.name,
596+
"messages": [msg_dict],
597+
}
598+
599+
# Send and await acknowledgment
600+
result = await self.__realtime.connection.connection_manager.send_protocol_message(protocol_message)
601+
602+
# Return UpdateDeleteResult - we don't have version_serial from the result yet
603+
# The server will send ACK with the result
604+
if result and hasattr(result, 'serials') and result.serials:
605+
return UpdateDeleteResult(version_serial=result.serials[0])
606+
return UpdateDeleteResult()
607+
608+
async def update_message(self, message: Message, operation: MessageOperation = None) -> UpdateDeleteResult:
609+
"""Updates an existing message on this channel.
610+
611+
Parameters
612+
----------
613+
message : Message
614+
Message object to update. Must have a serial field.
615+
operation : MessageOperation, optional
616+
Optional MessageOperation containing description and metadata for the update.
617+
618+
Returns
619+
-------
620+
UpdateDeleteResult
621+
Result containing the version serial of the updated message.
622+
623+
Raises
624+
------
625+
AblyException
626+
If message serial is missing or connection/channel state prevents the update
627+
"""
628+
return await self._send_update(message, MessageAction.MESSAGE_UPDATE, operation)
629+
630+
async def delete_message(self, message: Message, operation: MessageOperation = None) -> UpdateDeleteResult:
631+
"""Deletes a message on this channel.
632+
633+
Parameters
634+
----------
635+
message : Message
636+
Message object to delete. Must have a serial field.
637+
operation : MessageOperation, optional
638+
Optional MessageOperation containing description and metadata for the delete.
639+
640+
Returns
641+
-------
642+
UpdateDeleteResult
643+
Result containing the version serial of the deleted message.
644+
645+
Raises
646+
------
647+
AblyException
648+
If message serial is missing or connection/channel state prevents the delete
649+
"""
650+
return await self._send_update(message, MessageAction.MESSAGE_DELETE, operation)
651+
652+
async def append_message(self, message: Message, operation: MessageOperation = None) -> UpdateDeleteResult:
653+
"""Appends data to an existing message on this channel.
654+
655+
Parameters
656+
----------
657+
message : Message
658+
Message object with data to append. Must have a serial field.
659+
operation : MessageOperation, optional
660+
Optional MessageOperation containing description and metadata for the append.
661+
662+
Returns
663+
-------
664+
UpdateDeleteResult
665+
Result containing the version serial of the appended message.
666+
667+
Raises
668+
------
669+
AblyException
670+
If message serial is missing or connection/channel state prevents the append
671+
"""
672+
return await self._send_update(message, MessageAction.MESSAGE_APPEND, operation)
673+
674+
async def get_message(self, serial_or_message, timeout=None):
675+
"""Retrieves a single message by its serial using the REST API.
676+
677+
Parameters
678+
----------
679+
serial_or_message : str or Message
680+
Either a string serial or a Message object with a serial field.
681+
timeout : float, optional
682+
Timeout for the request.
683+
684+
Returns
685+
-------
686+
Message
687+
Message object for the requested serial.
688+
689+
Raises
690+
------
691+
AblyException
692+
If the serial is missing or the message cannot be retrieved.
693+
"""
694+
# Delegate to parent Channel (REST) implementation
695+
return await Channel.get_message(self, serial_or_message, timeout=timeout)
696+
697+
async def get_message_versions(self, serial_or_message, params=None):
698+
"""Retrieves version history for a message using the REST API.
699+
700+
Parameters
701+
----------
702+
serial_or_message : str or Message
703+
Either a string serial or a Message object with a serial field.
704+
params : dict, optional
705+
Optional dict of query parameters for pagination.
706+
707+
Returns
708+
-------
709+
PaginatedResult
710+
PaginatedResult containing Message objects representing each version.
711+
712+
Raises
713+
------
714+
AblyException
715+
If the serial is missing or versions cannot be retrieved.
716+
"""
717+
# Delegate to parent Channel (REST) implementation
718+
return await Channel.get_message_versions(self, serial_or_message, params=params)
719+
525720
def _on_message(self, proto_msg: dict) -> None:
526721
action = proto_msg.get('action')
527722
# RTL4c1
@@ -766,7 +961,7 @@ class Channels(RestChannels):
766961
"""
767962

768963
# RTS3
769-
def get(self, name: str, options: ChannelOptions | None = None) -> RealtimeChannel:
964+
def get(self, name: str, options: ChannelOptions | None = None, **kwargs) -> RealtimeChannel:
770965
"""Creates a new RealtimeChannel object, or returns the existing channel object.
771966
772967
Parameters
@@ -776,7 +971,15 @@ def get(self, name: str, options: ChannelOptions | None = None) -> RealtimeChann
776971
Channel name
777972
options: ChannelOptions or dict, optional
778973
Channel options for the channel
974+
**kwargs:
975+
Additional keyword arguments to create ChannelOptions (e.g., cipher, params)
779976
"""
977+
# Convert kwargs to ChannelOptions if provided
978+
if kwargs and not options:
979+
options = ChannelOptions(**kwargs)
980+
elif options and isinstance(options, dict):
981+
options = ChannelOptions.from_dict(options)
982+
780983
if name not in self.__all:
781984
channel = self.__all[name] = RealtimeChannel(self.__ably, name, options)
782985
else:

ably/transport/defaults.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
class Defaults:
2-
protocol_version = "2"
2+
protocol_version = "5"
33
fallback_hosts = [
44
"a.ably-realtime.com",
55
"b.ably-realtime.com",

0 commit comments

Comments
 (0)