Skip to content

Commit

Permalink
Fix CassandraBackend error in threads or gevent pool (#6147)
Browse files Browse the repository at this point in the history
* Fix CassandraBackend error in threads or gevent pool
        * remove CassandraBackend.process_cleanup

* Add test case

* Add test case

* Add comments test_as_uri

Co-authored-by: baixue <baixue@wecash.net>
  • Loading branch information
baixuexue123 and baixue committed Jun 21, 2020
1 parent 877f4bc commit c5843a5
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 46 deletions.
35 changes: 18 additions & 17 deletions celery/backends/cassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import absolute_import, unicode_literals

import sys
import threading

from celery import states
from celery.exceptions import ImproperlyConfigured
Expand All @@ -14,6 +15,7 @@
import cassandra
import cassandra.auth
import cassandra.cluster
import cassandra.query
except ImportError: # pragma: no cover
cassandra = None # noqa

Expand Down Expand Up @@ -123,32 +125,29 @@ def __init__(self, servers=None, keyspace=None, table=None, entry_ttl=None,
raise ImproperlyConfigured(E_NO_SUCH_CASSANDRA_AUTH_PROVIDER)
self.auth_provider = auth_provider_class(**auth_kwargs)

self._connection = None
self._cluster = None
self._session = None
self._write_stmt = None
self._read_stmt = None
self._make_stmt = None

def process_cleanup(self):
if self._connection is not None:
self._connection.shutdown() # also shuts down _session
self._connection = None
self._session = None
self._lock = threading.RLock()

def _get_connection(self, write=False):
"""Prepare the connection for action.
Arguments:
write (bool): are we a writer?
"""
if self._connection is not None:
if self._session is not None:
return
self._lock.acquire()
try:
self._connection = cassandra.cluster.Cluster(
if self._session is not None:
return
self._cluster = cassandra.cluster.Cluster(
self.servers, port=self.port,
auth_provider=self.auth_provider,
**self.cassandra_options)
self._session = self._connection.connect(self.keyspace)
self._session = self._cluster.connect(self.keyspace)

# We're forced to do concatenation below, as formatting would
# blow up on superficial %s that'll be processed by Cassandra
Expand All @@ -172,25 +171,27 @@ def _get_connection(self, write=False):
# Anyway; if you're doing anything critical, you should
# have created this table in advance, in which case
# this query will be a no-op (AlreadyExists)
self._make_stmt = cassandra.query.SimpleStatement(
make_stmt = cassandra.query.SimpleStatement(
Q_CREATE_RESULT_TABLE.format(table=self.table),
)
self._make_stmt.consistency_level = self.write_consistency
make_stmt.consistency_level = self.write_consistency

try:
self._session.execute(self._make_stmt)
self._session.execute(make_stmt)
except cassandra.AlreadyExists:
pass

except cassandra.OperationTimedOut:
# a heavily loaded or gone Cassandra cluster failed to respond.
# leave this class in a consistent state
if self._connection is not None:
self._connection.shutdown() # also shuts down _session
if self._cluster is not None:
self._cluster.shutdown() # also shuts down _session

self._connection = None
self._cluster = None
self._session = None
raise # we did fail after all - reraise
finally:
self._lock.release()

def _store_result(self, task_id, result, state,
traceback=None, request=None, **kwargs):
Expand Down
87 changes: 58 additions & 29 deletions t/unit/backends/test_cassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
from celery.exceptions import ImproperlyConfigured
from celery.utils.objects import Bunch

CASSANDRA_MODULES = ['cassandra', 'cassandra.auth', 'cassandra.cluster']
CASSANDRA_MODULES = [
'cassandra',
'cassandra.auth',
'cassandra.cluster',
'cassandra.query',
]


@mock.module(*CASSANDRA_MODULES)
Expand Down Expand Up @@ -66,7 +71,6 @@ def test_get_task_meta_for(self, *modules):
mod.cassandra = Mock()

x = mod.CassandraBackend(app=self.app)
x._connection = True
session = x._session = Mock()
execute = session.execute = Mock()
result_set = Mock()
Expand All @@ -83,24 +87,24 @@ def test_get_task_meta_for(self, *modules):
meta = x._get_task_meta_for('task_id')
assert meta['status'] == states.PENDING

def test_as_uri(self):
# Just ensure as_uri works properly
from celery.backends import cassandra as mod
mod.cassandra = Mock()

x = mod.CassandraBackend(app=self.app)
x.as_uri()
x.as_uri(include_password=False)

def test_store_result(self, *modules):
from celery.backends import cassandra as mod
mod.cassandra = Mock()

x = mod.CassandraBackend(app=self.app)
x._connection = True
session = x._session = Mock()
session.execute = Mock()
x._store_result('task_id', 'result', states.SUCCESS)

def test_process_cleanup(self, *modules):
from celery.backends import cassandra as mod
x = mod.CassandraBackend(app=self.app)
x.process_cleanup()

assert x._connection is None
assert x._session is None

def test_timeouting_cluster(self):
# Tests behavior when Cluster.connect raises
# cassandra.OperationTimedOut.
Expand Down Expand Up @@ -128,40 +132,65 @@ def shutdown(self):

with pytest.raises(OTOExc):
x._store_result('task_id', 'result', states.SUCCESS)
assert x._connection is None
assert x._cluster is None
assert x._session is None

x.process_cleanup() # shouldn't raise

def test_please_free_memory(self):
# Ensure that Cluster object IS shut down.
def test_create_result_table(self):
# Tests behavior when session.execute raises
# cassandra.AlreadyExists.
from celery.backends import cassandra as mod

class RAMHoggingCluster(object):
class OTOExc(Exception):
pass

objects_alive = 0
class FaultySession(object):
def __init__(self, *args, **kwargs):
pass

def execute(self, *args, **kwargs):
raise OTOExc()

class DummyCluster(object):

def __init__(self, *args, **kwargs):
pass

def connect(self, *args, **kwargs):
RAMHoggingCluster.objects_alive += 1
return Mock()

def shutdown(self):
RAMHoggingCluster.objects_alive -= 1
return FaultySession()

mod.cassandra = Mock()
mod.cassandra.cluster = Mock()
mod.cassandra.cluster.Cluster = DummyCluster
mod.cassandra.AlreadyExists = OTOExc

x = mod.CassandraBackend(app=self.app)
x._get_connection(write=True)
assert x._session is not None

def test_init_session(self):
# Tests behavior when Cluster.connect works properly
from celery.backends import cassandra as mod

class DummyCluster(object):

def __init__(self, *args, **kwargs):
pass

def connect(self, *args, **kwargs):
return Mock()

mod.cassandra = Mock()
mod.cassandra.cluster = Mock()
mod.cassandra.cluster.Cluster = RAMHoggingCluster
mod.cassandra.cluster.Cluster = DummyCluster

for x in range(0, 10):
x = mod.CassandraBackend(app=self.app)
x._store_result('task_id', 'result', states.SUCCESS)
x.process_cleanup()
x = mod.CassandraBackend(app=self.app)
assert x._session is None
x._get_connection(write=True)
assert x._session is not None

assert RAMHoggingCluster.objects_alive == 0
s = x._session
x._get_connection()
assert s is x._session

def test_auth_provider(self):
# Ensure valid auth_provider works properly, and invalid one raises
Expand Down

0 comments on commit c5843a5

Please sign in to comment.