Context manager cookbook.

A non-exhaustive list of examples using the `with` statement.

GitHub doesn't render large Jupyter Notebooks, so just in case, here is an nbviewer link to the notebook.

In [1]:
from contextlib import contextmanager

Change print color

In [2]:
@contextmanager
def print_blue():
    """Print text written inside the ``with`` block in blue (ANSI terminals).

    Emits the ANSI blue-foreground escape (``ESC[34m``) on entry and the
    default-foreground escape (``ESC[39m``) on exit, without newlines.
    """
    print('\033[34m', end='')
    try:
        yield
    finally:
        # Reset even if the block raises; otherwise the terminal stays blue.
        print('\033[39m', end='')

# Text printed inside the block follows the blue escape code; the
# default-color escape code is printed when the block exits.
with print_blue():
    print('Changes color in context')

print('Outside the context with default color')
Changes color in context
Outside the context with default color

Silently ignore error

Intentionally suppress an expected error. This approach reduces the visual noise of try/except.

In [3]:
from contextlib import suppress
import os
# Remove file.txt if present; FileNotFoundError is silently ignored, while
# any other error still propagates.
with suppress(FileNotFoundError):
    os.remove('file.txt')

Compare this with the equivalent `try: os.remove('file.txt')` / `except FileNotFoundError: pass`.

Logging level change only in context

In [4]:
import logging

@contextmanager
def debug_logging(logger_name: str, level: int):
    """Temporarily set the level of the logger called *logger_name*.

    Yields the logger. On exit, whether normal or via an exception, the
    logger's own level is restored exactly, including ``logging.NOTSET``
    (meaning "inherit from ancestors").
    """
    logger = logging.getLogger(logger_name)
    # Save the logger's *own* level, not getEffectiveLevel(). Restoring the
    # effective level would pin an explicit level on a logger that was NOTSET,
    # so it would stop following later changes to its ancestors' levels.
    old_level = logger.level
    logger.setLevel(level)
    try:
        yield logger
    finally:
        logger.setLevel(old_level)

# Unless logging was configured earlier, no handler exists yet. Records would
# then go to logging's last-resort handler, which only emits WARNING and
# above, so the DEBUG message below would be dropped. Give the demo logger
# its own handler so the message really is printed.
_demo_logger = logging.getLogger('my-logger')
if not _demo_logger.handlers:
    _demo_logger.addHandler(logging.StreamHandler())

with debug_logging('my-logger', logging.DEBUG) as logger:
    logger.debug('This will be printed')

# Outside the context the logger is back at its previous level, so the INFO
# record is filtered out.
logging.getLogger('my-logger').info(
    'This wont be logged because default level is WARNING'
)

Tag maker

In [5]:
@contextmanager
def tag(name):
    """Surround everything printed inside the block with ``<name>``/``</name>``.

    Neither tag is followed by a newline. The closing tag is printed only if
    the block finishes without raising.
    """
    print('<{}>'.format(name), end='')
    yield
    closing = '</{}>'.format(name)
    print(closing, end='')

# Prints <header>Tag body</header> with no newline anywhere.
with tag('header'):
    print('Tag body', end='')
<header>Tag body</header>

Indenter

In [6]:
class Indenter:
    """Print text indented by four spaces per level of ``with`` nesting.

    Entering the same instance again (``with indenter:``) deepens the
    indentation for the duration of that block. Inside the outermost block
    (level 1) text is printed without indentation.
    """

    def __init__(self):
        self.level = 0  # current nesting depth

    def __enter__(self):
        self.level = self.level + 1
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returns None, so exceptions from the block propagate.
        self.level = self.level - 1

    def print(self, text):
        # `print` below is the builtin: class attributes are not in scope here.
        depth = self.level - 1
        prefix = '    ' * depth
        print(prefix + text)

# Each nested `with indenter:` adds one level (four spaces) of indentation.
with Indenter() as indenter:
    indenter.print('def mimic_python_syntax():')
    with indenter:
        indenter.print('s = "Hello World"')
        indenter.print('print(s)')

    indenter.print('\nmimic_python_syntax()')
def mimic_python_syntax():
    s = "Hello World"
    print(s)

mimic_python_syntax()

List operations with a transaction

Make a copy of the input list's items and work on the copy. If no error occurs, replace the input list's items with the items from the list used in the context.

In [7]:
@contextmanager
def list_transaction(list_: list):
    """Let the block edit a copy of *list_*; write it back only on success.

    Yields a new list holding the current items of *list_*. If the block
    completes without raising, the contents of *list_* are replaced in place
    by the yielded list's contents. If the block raises, *list_* is left
    untouched and the exception propagates.
    """
    scratch = [*list_]
    yield scratch
    # Only reached when the block did not raise.
    list_[:] = scratch
In [8]:
items = [1,2,3]

# The error is raised on purpose: the appended 4 exists only in the working
# copy, and because the block fails it is never written back to `items`.
with list_transaction(items) as working:
    working.append(4)
    raise RuntimeError()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-8-e92364a7c13f> in <module>
      3 with list_transaction(items) as working:
      4     working.append(4)
----> 5     raise RuntimeError()

RuntimeError: 
In [ ]:
# Still [1, 2, 3]: the failed transaction never copied its changes back.
items
In [ ]:
# No error this time, so the working copy is written back into `items`.
with list_transaction(items) as working:
    working.append(4)

print(items)

Database transaction

Roll back changes if an error occurs in the context

In [9]:
class Transaction:
    """Commit the connection if the block succeeds; roll back if it raises.

    ``__enter__`` returns the wrapped connection itself. ``__exit__`` returns
    None, so any exception from the block still propagates after rollback.
    """

    def __init__(self, connection):
        self.connection = connection

    def __enter__(self):
        return self.connection

    def __exit__(self, err_type, err_value, err_traceback):
        conn = self.connection
        finish = conn.commit if err_type is None else conn.rollback
        finish()
In [10]:
import sqlite3

# NOTE(review): an empty filename (not ':memory:') is documented by SQLite as
# a private, temporary database that is discarded on close. Confirm if
# persistence matters.
connection = sqlite3.connect('')
In [11]:
# Transaction.__enter__ returns the connection itself, so `t` is the sqlite3
# connection. The changes are committed when the block exits without error.
with Transaction(connection) as t:
    t.execute("""
        CREATE TABLE Users
        (
        id INTEGER PRIMARY KEY,
        name TEXT NOT NULL
        )
        """)
    t.execute("""
        INSERT INTO Users
        (id, name) VALUES (1, 'Name 1')
        """)

    print(t.execute("""SELECT * FROM Users""").fetchall())
[(1, 'Name 1')]

Lazy connection

Example from Python Cookbook, 3rd Edition

In [12]:
from socket import socket, AF_INET, SOCK_STREAM

class LazyConnection:
    """Socket connection that is opened on ``with`` entry and closed on exit.

    Args:
        address: address passed to ``socket.connect``, e.g. ``(host, port)``.
        family: socket address family (defaults to ``AF_INET``).
        type: socket type (defaults to ``SOCK_STREAM``).

    Entering returns the connected socket. Entering an instance that is
    already connected raises ``RuntimeError('Already connected')``.
    """

    def __init__(self, address, family=AF_INET, type=SOCK_STREAM):
        self.address = address
        # Honour the caller's choices instead of always using TCP/IPv4.
        self.family = family
        self.type = type
        self.sock = None

    def __enter__(self):
        if self.sock is not None:
            raise RuntimeError('Already connected')

        sock = socket(self.family, self.type)
        try:
            sock.connect(self.address)
        except BaseException:
            # Don't leak the descriptor, and don't leave this object looking
            # connected: a later `with` would wrongly fail 'Already connected'.
            sock.close()
            raise
        self.sock = sock
        return sock

    def __exit__(self, exc_ty, exc_val, tb):
        self.sock.close()
        self.sock = None
In [13]:
from functools import partial
# The socket is created and connected only when the with block is entered,
# and it is closed when the block exits.
connection = LazyConnection(('www.python.org', 80))

with connection as s:
    # sendall() keeps writing until every byte is sent; send() may write only
    # part of the buffer and report the count, which would be ignored here.
    s.sendall(b'GET /index.html HTTP/1.0\r\n')
    s.sendall(b'Host: www.python.org\r\n')
    s.sendall(b'\r\n')
    # Call s.recv(8192) until it returns b'' (the server closed the
    # connection), then join the received chunks.
    resp = b''.join(iter(partial(s.recv, 8192), b''))
    print(resp)
b'HTTP/1.1 301 Moved Permanently\r\nServer: Varnish\r\nRetry-After: 0\r\nLocation: https://www.python.org/index.html\r\nContent-Length: 0\r\nAccept-Ranges: bytes\r\nDate: Wed, 14 Apr 2021 09:31:05 GMT\r\nVia: 1.1 varnish\r\nConnection: close\r\nX-Served-By: cache-ams21030-AMS\r\nX-Cache: HIT\r\nX-Cache-Hits: 0\r\nX-Timer: S1618392666.941079,VS0,VE0\r\nStrict-Transport-Security: max-age=63072000; includeSubDomains\r\n\r\n'

Stopwatch

In [14]:
import time
@contextmanager
def stopwatch(label: str):
    """Print how long the ``with`` block took, as ``'<label>: <seconds>'``.

    The elapsed time is printed even if the block raises; the exception then
    propagates.
    """
    # perf_counter() is monotonic and high-resolution. time.time() follows
    # the wall clock, which can jump when the system clock is adjusted.
    start = time.perf_counter()
    try:
        yield
    finally:
        end = time.perf_counter()
        print(f'{label}: {end - start}')


# Prints "Sleeping: <elapsed seconds>" when the block finishes.
with stopwatch('Sleeping'):
    time.sleep(1)
Sleeping: 1.0047638416290283

Stopwatch with output callable

In [15]:
import time

class Stopwatch:
    """Time a ``with`` block and hand the elapsed seconds to a callable.

    Args:
        output_callable: called once on exit with the elapsed time in seconds
            (a float), e.g. ``print`` or ``logger.info``.
    """

    def __init__(self, output_callable):
        self.output_callable = output_callable

    def __enter__(self):
        # perf_counter() is monotonic and high-resolution; time.time() follows
        # the adjustable wall clock.
        self.start = time.perf_counter()

    def __exit__(self, err_type, err_value, err_traceback):
        end = time.perf_counter()

        # Reported even if the block raised. Returning None lets any exception
        # propagate.
        self.output_callable(end - self.start)
In [16]:
import logging

# Configure the root logger so the INFO record below is actually emitted.
logging.basicConfig()

logger = logging.getLogger('stopwatch')
logger.setLevel(logging.DEBUG)

# logger.info receives the elapsed seconds as its message.
with Stopwatch(logger.info):
    time.sleep(1)
INFO:stopwatch:1.00174880027771

Enter the same target multiple times

In [17]:
class MultilevelStopwatch:
    """Stopwatch that can be re-entered while it is already running.

    Each ``with`` on the same instance pushes a start time. On exit the most
    recent one is popped and ``Level <depth> took: <seconds>`` is printed,
    where depth 1 is the outermost block.
    """

    def __init__(self):
        self.levels = []  # stack of start times, outermost first

    def __enter__(self):
        # perf_counter() is monotonic; time.time() can jump with clock changes.
        self.levels.append(time.perf_counter())
        return self

    def __exit__(self, err_type, err_value, err_traceback):
        latest = self.levels.pop()
        end = time.perf_counter()

        # After the pop, len(self.levels) counts only the enclosing levels.
        print(f'Level {len(self.levels)+1} took: {end - latest}')
In [18]:
# Re-entering the same instance pushes a second start time. The inner level
# is reported first ("Level 2"), then the outer one ("Level 1").
with MultilevelStopwatch() as ms:
    time.sleep(.5)
    with ms:
        time.sleep(.5)
Level 2 took: 0.5035400390625
Level 1 took: 1.0045082569122314

HTTP Session

Reuse the TCP connection to improve performance

In [19]:
# install requests if necessary
# !pip3 install requests
In [20]:
import requests

n = 20

# Time n identical requests two ways. The with statement only guarantees the
# Session is closed afterwards; any speed-up comes from the Session reusing
# its connection between requests.
with stopwatch('Using context manager'):
    with requests.Session() as session:
        for _ in range(n):
            session.get("http://httpbin.org/cookies/set/sessioncookie/123456789")

# Module-level requests.get() calls, with no shared Session object.
with stopwatch('Establishing HTTP connection for every request'):
    for _ in range(n):
        requests.get("http://httpbin.org/cookies/set/sessioncookie/123456789")
Using context manager: 7.110113143920898
Establishing HTTP connection for every request: 10.991943836212158

Nested context managers

In [21]:
from contextlib import contextmanager


@contextmanager
def get_state(name):
    """Yield *name*, announcing entry and exit on stdout.

    The exit message is printed even when the block raises.
    """
    on_enter, on_exit = "entering:", "exiting:"
    print(on_enter, name)
    try:
        yield name
    finally:
        print(on_exit, name)
In [22]:
# Managers are entered left to right and exited in reverse order.
with get_state("A") as A, get_state('B') as B, get_state("C") as C:
    print("inside with statement:", A, B, C)
entering: A
entering: B
entering: C
inside with statement: A B C
exiting: C
exiting: B
exiting: A

Nesting with ExitStack

The example above, rewritten using ExitStack

In [23]:
from contextlib import ExitStack

with ExitStack() as es:
    # enter_context() enters each manager immediately and registers its exit;
    # the stack exits them in reverse order when this with block ends.
    es.enter_context(get_state('A'))
    es.enter_context(get_state('B'))
    es.enter_context(get_state('C'))
    print('Inside')
entering: A
entering: B
entering: C
Inside
exiting: C
exiting: B
exiting: A

Nested context with errors.

Contexts that were already entered still exit normally

In [24]:
@contextmanager
def raise_error(name, err):
    """Context manager whose entry always fails with ``err()``.

    Prints ``entering: <name>``, raises ``err()``, and prints
    ``exiting: <name>`` on the way out. The body of the ``with`` never runs.

    Args:
        name: label used in the printed messages.
        err: exception class (called with no arguments) to raise on entry.
    """
    print("entering:", name)
    try:
        raise err()
    finally:
        print('exiting:', name)
    # Never reached. Its presence makes this a generator function, which
    # @contextmanager requires. Calling raise_error(...) then returns a context
    # manager, and the error is raised from its __enter__.
    yield  # pragma: no cover
In [25]:
# raise_error('B', ...) fails before anything after it runs, so C is never
# entered and 'Inside' is never printed. A, already entered, is still exited
# before the RuntimeError reaches the except clause.
try:
    with get_state("A") as A, raise_error('B', RuntimeError) as B, get_state("C") as C:
        print('Inside')
except RuntimeError as e:
    print('Caught error', e)
entering: A
entering: B
exiting: B
exiting: A
Caught error 

Extract logging and error handling

In [26]:
import logging
from contextlib import contextmanager
import traceback
import sys

# Module-level logger, kept so it can actually be used.
logger = logging.getLogger(__name__)

# Every placeholder needs its leading "%"; a bare "(asctime)s" would be
# printed literally instead of the timestamp.
LOG_FORMAT = "\n%(asctime)s [%(levelname)s] %(message)s"

# NOTE: per the logging docs, basicConfig() does nothing when the root logger
# already has handlers, e.g. after an earlier logging.basicConfig() call.
logging.basicConfig(
    level=logging.INFO,
    format=LOG_FORMAT,
)


class Divider:
    """Callable that divides two numbers and handles ZeroDivisionError itself."""

    @contextmanager
    def errorhandler(self):
        """Report and swallow a ZeroDivisionError raised inside the block.

        Prints a notice to stdout and logs an ERROR record on the root logger
        with a traceback limited to two entries. Any other exception
        propagates unchanged.
        """
        try:
            yield
        except ZeroDivisionError:
            print(
                "Custom handling of Zero Division Error! Printing "
                "only 2 levels of traceback.."
            )
            # Enforce the promised limit; logging.exception() would log the
            # full traceback.
            logging.error(
                "ZeroDivisionError\n%s",
                traceback.format_exc(limit=2).rstrip("\n"),
            )

    def __call__(self, a, b):
        """Return ``a / b``; on division by zero, report it and return None."""
        # If errorhandler swallows the error, the with block ends without
        # reaching `return`, so the call implicitly returns None.
        with self.errorhandler():
            return a / b


divide = Divider()
# The ZeroDivisionError is handled inside Divider, so this call does not
# raise; it returns None.
divide(2, 0)
ERROR:root:ZeroDivisionError
Traceback (most recent call last):
  File "<ipython-input-26-fc2598d2335e>", line 19, in errorhandler
    yield
  File "<ipython-input-26-fc2598d2335e>", line 31, in __call__
    return a / b
ZeroDivisionError: division by zero
Custom handling of Zero Division Error! Printing only 2 levels of traceback..