Skip to content

Commit

Permalink
max reconnect attempts
Browse files Browse the repository at this point in the history
  • Loading branch information
Nikita Kharlov committed Sep 12, 2019
1 parent 20fcfc9 commit af6aead
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 7 deletions.
2 changes: 1 addition & 1 deletion aio_pika/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ async def main():
query=kw
)

connection = connection_class(url, loop=loop)
connection = connection_class(url, loop=loop, **kwargs)
await connection.connect(timeout=timeout)
return connection

Expand Down
5 changes: 5 additions & 0 deletions aio_pika/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ class QueueEmpty(AMQPError, asyncio.QueueEmpty):
pass


class MaxReconnectAttemptsReached(Exception):
pass


__all__ = (
'AMQPChannelError',
'AMQPConnectionError',
Expand All @@ -51,6 +55,7 @@ class QueueEmpty(AMQPError, asyncio.QueueEmpty):
'DuplicateConsumerTag',
'IncompatibleProtocolError',
'InvalidFrameError',
'MaxReconnectAttemptsReached',
'MessageProcessError',
'MethodNotImplemented',
'ProbableAuthenticationError',
Expand Down
28 changes: 27 additions & 1 deletion aio_pika/robust_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Callable, Type

from aiormq.connection import parse_bool, parse_int
from .exceptions import CONNECTION_EXCEPTIONS
from .exceptions import CONNECTION_EXCEPTIONS, MaxReconnectAttemptsReached
from .connection import Connection, connect, ConnectionType
from .tools import CallbackCollection
from .types import TimeoutType
Expand All @@ -29,6 +29,7 @@ class RobustConnection(Connection):

CHANNEL_CLASS = RobustChannel
KWARGS_TYPES = (
('max_reconnect_attempts', parse_int, '0'),
('reconnect_interval', parse_int, '5'),
('fail_fast', parse_bool, '1'),
)
Expand All @@ -41,8 +42,13 @@ def __init__(self, url, loop=None, **kwargs):
self.reconnect_interval = self.kwargs['reconnect_interval']
self.fail_fast = self.kwargs['fail_fast']

self._stop_future = self.loop.create_future()
self._stop_future.add_done_callback(self._on_stop)

self.__channels = set()
self._reconnect_attempt = None
self._on_reconnect_callbacks = CallbackCollection()
self._on_stop_callbacks = CallbackCollection()
self._closed = False

@property
Expand All @@ -63,11 +69,17 @@ def _on_connection_close(self, connection, closing, *args, **kwargs):

super()._on_connection_close(connection, closing)

if isinstance(closing.exception(), MaxReconnectAttemptsReached):
return

self.loop.call_later(
self.reconnect_interval,
lambda: self.loop.create_task(self.reconnect())
)

def _on_stop(self, future):
self._on_stop_callbacks(future.exception())

def add_reconnect_callback(self, callback: Callable[[], None]):
""" Add callback which will be called after reconnect.
Expand All @@ -76,6 +88,9 @@ def add_reconnect_callback(self, callback: Callable[[], None]):

self._on_reconnect_callbacks.add(callback)

def add_stop_callback(self, callback: Callable[[Exception], None]):
self._on_stop_callbacks.add(callback)

async def connect(self, timeout: TimeoutType = None):
while True:
try:
Expand All @@ -97,6 +112,16 @@ async def reconnect(self):
if self.is_closed:
return

if self.kwargs['max_reconnect_attempts'] > 0:
if self._reconnect_attempt is None:
self._reconnect_attempt = 1
else:
self._reconnect_attempt += 1

if self._reconnect_attempt > self.kwargs['max_reconnect_attempts']:
self._stop_future.set_exception(MaxReconnectAttemptsReached())
return

try:
await super().connect()
except CONNECTION_EXCEPTIONS:
Expand Down Expand Up @@ -124,6 +149,7 @@ def channel(self, channel_number: int = None,
return channel

async def _on_reconnect(self):
self._reconnect_attempt = None
for number, channel in self._channels.items():
try:
await channel.on_reconnect(self, number)
Expand Down
43 changes: 38 additions & 5 deletions tests/test_amqp_robust.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from aiormq import ChannelLockedResource

from aio_pika import connect_robust, Message
from aio_pika.exceptions import MaxReconnectAttemptsReached
from aio_pika.robust_channel import RobustChannel
from aio_pika.robust_connection import RobustConnection
from aio_pika.robust_queue import RobustQueue
Expand All @@ -27,6 +28,7 @@ def __init__(self, *, loop, shost='127.0.0.1', sport,
self.src_port = sport
self.dst_host = dhost
self.dst_port = dport
self._run_task = None
self.connections = set()

async def _pipe(self, reader: asyncio.StreamReader,
Expand Down Expand Up @@ -54,14 +56,18 @@ async def handle_client(self, creader: asyncio.StreamReader,
])

async def start(self):
result = await asyncio.start_server(
self._run_task = await asyncio.start_server(
self.handle_client,
host=self.src_host,
port=self.src_port,
loop=self.loop,
)

return result
async def stop(self):
assert self._run_task is not None
self._run_task.close()
await self.disconnect()
self._run_task = None

async def disconnect(self):
tasks = list()
Expand All @@ -74,7 +80,8 @@ async def close(writer):
writer = self.connections.pop() # type: asyncio.StreamWriter
tasks.append(self.loop.create_task(close(writer)))

await asyncio.wait(tasks)
if tasks:
await asyncio.wait(tasks)


class TestCase(AMQPTestCase):
Expand All @@ -86,7 +93,7 @@ def get_unused_port() -> int:
sock.close()
return port

async def create_connection(self, cleanup=True):
async def create_connection(self, cleanup=True, max_reconnect_attempts=0):
self.proxy = Proxy(
dhost=AMQP_URL.host,
dport=AMQP_URL.port,
Expand All @@ -100,7 +107,11 @@ async def create_connection(self, cleanup=True):
self.proxy.src_host
).with_port(
self.proxy.src_port
).update_query(reconnect_interval=1)
).update_query(
reconnect_interval=1
).update_query(
max_reconnect_attempts=max_reconnect_attempts
)

client = await connect_robust(str(url), loop=self.loop)

Expand Down Expand Up @@ -212,6 +223,28 @@ async def reader():

assert len(shared) == 10

async def test_robust_reconnect_max_attempts(self):
client = await self.create_connection(max_reconnect_attempts=2)
self.assertIsInstance(client, RobustConnection)

first_close = asyncio.Future()
stopped = asyncio.Future()

def stop_callback(exc):
assert isinstance(exc, MaxReconnectAttemptsReached)
stopped.set_result(True)

def close_callback(f):
first_close.set_result(True)

client.add_stop_callback(stop_callback)
client.connection.closing.add_done_callback(close_callback)
await self.proxy.stop()
await first_close
# 1 interval before first try and 2 after attempts
await asyncio.wait_for(stopped,
timeout=client.reconnect_interval * 3 + 0.1)

async def test_channel_locked_resource2(self):
ch1 = await self.create_channel()
ch2 = await self.create_channel()
Expand Down

0 comments on commit af6aead

Please sign in to comment.