How it began
Disclaimer: Despite this being the outcome of a work issue, all opinions are my own. My employer is not involved in, nor in any way endorses, this publication, aside from an implicit authorization to write publicly about technical issues and solutions.
In a daily chat with a few of my colleagues at ShipHero, someone mentioned that an issue we were having (not relevant to this post) might be caused by SQLAlchemy thread-local sessions being passed to another thread.
The lazy path
We could have audited our entire code base around connections. As a matter of fact, we did, but only in the core; a more in-depth check would have taken more time than was realistically available.
I suggested a compromise: we could add a little hack that was by no means efficient, but which, deployed on a single node, might help us determine whether this was the case.
The code
I have not worked in Python as my main language in years and am only now returning to it, so this might be rather rusty code (skip to the end for the full listing).
The outer wrapper
The general idea is to wrap the scoped_session(session_factory) object so that, each time it instantiates a new thread-bound session, we instead get a wrapped object carrying metadata about the thread it is bound to.
class SessionBoundaries:
    """A wrapper class for thread-local sessions.

    This should help in following each session during its lifetime and log a message if the
    session is used in a thread that is not the one it originated in.
    """

    def __init__(self, thread_session):
        self._session = thread_session

    def __call__(self, *args, **kwargs):
        """A replacement for session_factory().

        A new thread-local session is created by invoking scoped_session(session_factory)();
        this implements __call__ on the wrapper to maintain that behavior."""
        new_sess = self._session(*args, **kwargs)
        current_thread_id = threading.get_ident()
        log = logging.getLogger("SessionBoundaries.wrapper")
        return Wrapper(new_sess, current_thread_id, log)

    def __getattr__(self, name):
        """Intercept every attribute access.

        Except for __call__, all accesses are relayed to the wrapped scoped_session."""
        return getattr(self._session, name)
The inner wrapper
This inner object has very little logic: it passes through every attribute access, but logs a message if the session is used outside of its initial thread.
class Wrapper:
    """The actual inner wrapper class for the session object."""

    def __init__(self, sub_s, t_id, log):
        """
        :param sub_s: the underlying session object
        :param t_id: the id of the thread the session was created in
        :param log: logger used to report boundary crossings
        """
        self._sub_s = sub_s
        self._t_id = t_id
        self._log = log

    def __getattr__(self, name):
        """Intercept every attribute access."""
        # get the current thread information
        t_ident = threading.get_ident()
        # log the crossing if we are not in the original thread
        if self._t_id != t_ident:
            import traceback
            stack = "".join(traceback.format_stack())
            self._log.debug(f"boundaries crossed in {stack}")
        # finally relay the attribute access to the underlying SQLAlchemy session
        return getattr(self._sub_s, name)
Usage
Usage is fairly simple:
# Create a session factory and a scoped_session.
session_factory = sessionmaker(bind=engine)
scoped_session_manager = SessionBoundaries(scoped_session(session_factory))
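As a quick check that the detection fires, here is a minimal sketch (it assumes the engine and the wrapper classes above are already in scope; misbehaving_worker is just an illustrative name). Touching the wrapped session from a second thread routes the attribute lookup through Wrapper.__getattr__, which logs the offending stack trace at debug level:
import threading
from sqlalchemy import text

# Create a session through the wrapper in the main thread.
session = scoped_session_manager()

def misbehaving_worker():
    # Any attribute access from another thread goes through Wrapper.__getattr__,
    # which notices the thread id mismatch and logs the current stack at debug level.
    session.execute(text("SELECT 1"))

t = threading.Thread(target=misbehaving_worker)
t.start()
t.join()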
A sample test
This is a simple test checking that it works; it requires a MySQL connection.
# Pull the latest MySQL image (using MySQL 8.0 in this example)
docker pull mysql:8.0
# Run a new MySQL container:
# - Container name: mysql-test
# - Root password: my-secret-pw
# - Create a database named test_db
# - Map port 3306 on the container to 3306 on the host
docker run --name mysql-test \
-e MYSQL_ROOT_PASSWORD=my-secret-pw \
-e MYSQL_DATABASE=test_db \
-p 3306:3306 \
-d mysql:8.0
# Wait until MySQL is ready to accept connections.
echo "Waiting for MySQL to start..."
sleep 30 # This might not be enough on every system
echo "MySQL should now be running on localhost:3306"
The test itself runs as a regular unittest. The requirements are:
SQLAlchemy==2.0.38
PyMySQL
import logging
import sys
import unittest
import threading
from sqlalchemy import create_engine, Column, Integer, String, select
from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
# Define the base and model.
Base = declarative_base()
class TheMostGenericModelNameEver(Base):
"""A very generic model to exemplify sqlalchemy usage"""
__tablename__ = 'the_most_generic_model_name_ever'
id = Column(Integer, primary_key=True)
name = Column(String(256), nullable=False)
class Wrapper:
    """The actual inner wrapper class for the session object."""

    def __init__(self, sub_s, t_id, log):
        """
        :param sub_s: the underlying session object
        :param t_id: the id of the thread the session was created in
        :param log: logger used to report boundary crossings
        """
        self._sub_s = sub_s
        self._t_id = t_id
        self._log = log

    def __getattr__(self, name):
        """Intercept every attribute access."""
        # get the current thread information
        t_ident = threading.get_ident()
        # log the crossing if we are not in the original thread
        if self._t_id != t_ident:
            import traceback
            stack = "".join(traceback.format_stack())
            self._log.debug(f"boundaries crossed in {stack}")
        # finally relay the attribute access to the underlying SQLAlchemy session
        return getattr(self._sub_s, name)
class SessionBoundaries:
    """A wrapper class for thread-local sessions.

    This should help in following each session during its lifetime and log a message if the
    session is used in a thread that is not the one it originated in.
    """

    def __init__(self, thread_session):
        self._session = thread_session

    def __call__(self, *args, **kwargs):
        """A replacement for session_factory().

        A new thread-local session is created by invoking scoped_session(session_factory)();
        this implements __call__ on the wrapper to maintain that behavior."""
        new_sess = self._session(*args, **kwargs)
        current_thread_id = threading.get_ident()
        log = logging.getLogger("SessionBoundaries.wrapper")
        return Wrapper(new_sess, current_thread_id, log)

    def __getattr__(self, name):
        """Intercept every attribute access.

        Except for __call__, all accesses are relayed to the wrapped scoped_session."""
        return getattr(self._session, name)
class TestScopedSessionThreading(unittest.TestCase):
def test_scoped_session_with_threads_and_model(self):
        log = logging.getLogger("TestScopedSessionThreading.test_scoped_session_with_threads_and_model")
# This test was trying to be accurate to the problematic project.
engine = create_engine(
"mysql+pymysql://root:my-secret-pw@localhost:3306/test_db",
pool_pre_ping=True, # Helps to ensure connections are still valid.
echo=True # Optional: echoes SQL for debugging.
)
# Create the table.
Base.metadata.create_all(engine)
# Create a session factory and a scoped_session.
session_factory = sessionmaker(bind=engine)
scoped_session_manager = SessionBoundaries(scoped_session(session_factory))
# Insert some records before starting threads.
main_session = scoped_session_manager()
try:
models = main_session.query(TheMostGenericModelNameEver).all()
for model in models:
main_session.delete(model)
except Exception as e:
log.debug(e)
main_session.add_all([TheMostGenericModelNameEver(name=f"Record {i}") for i in range(10)])
main_session.commit()
scoped_session_manager.remove() # Remove main thread session.
errors = [] # To collect errors from threads.
session = scoped_session_manager()
def worker(thread_id):
try:
# Use session.scalars() to directly obtain a list of MyModel instances.
stmt = select(TheMostGenericModelNameEver)
models = session.scalars(stmt).all()
if len(models) != 10:
errors.append(f"Thread {thread_id}: Expected 10 records, got {len(models)} q: {str(stmt)}")
# Optionally verify that the names are as expected.
expected_names = {f"Record {i}" for i in range(10)}
record_names = {model.name for model in models}
if record_names != expected_names:
errors.append(f"Thread {thread_id}: Record names mismatch. q: {str(stmt)}")
except Exception as e:
errors.append(f"Thread {thread_id}: Exception: {e} q: {str(stmt)}")
finally:
pass
threads = []
for i in range(3):
t = threading.Thread(target=worker, args=(i,))
threads.append(t)
t.start()
for t in threads:
t.join()
if errors:
self.fail("Errors occurred in threads: " + "\n * ".join(errors))
else:
log.debug("All threads executed successfully.")
scoped_session_manager.remove()
if __name__ == '__main__':
    logging.basicConfig(stream=sys.stderr)
    logging.getLogger("TestScopedSessionThreading.test_scoped_session_with_threads_and_model").setLevel(logging.DEBUG)
    logging.getLogger("SessionBoundaries.wrapper").setLevel(logging.DEBUG)
    unittest.main()