Skip to content

Lock contention inside _Py_Specialize_LoadGlobal under free threading #152075

Description

@hawkinsp

Bug report

Bug description:

While working on improving the run time of the JAX test suite with high thread concurrency under free threading, a lock contention profile pointed me to high lock contention in _Py_Specialize_LoadGlobal.

Here's a synthetic benchmark that I generated with the assistance of AI.

#!/usr/bin/env python3
"""Benchmark to demonstrate LOAD_GLOBAL specialization contention in free-threaded Python.

This benchmark spawns multiple threads that frequently read global variables and
builtins (triggering bytecode specialization). An optional background thread
periodically modifies the globals dictionary.
"""

import argparse
import sys
import threading
import time

# Module-level names that the reader threads load on every iteration.
G1, G2, G3, G4, G5, G6, G7, G8 = range(1, 9)


def reader_worker(num_iters: int, stop_event: threading.Event) -> None:
    """Read module globals and builtins in a loop to exercise LOAD_GLOBAL.

    Args:
        num_iters: Upper bound on the number of loop iterations.
        stop_event: Checked at the top of every iteration; the loop exits
            as soon as it is set.
    """
    # Access globals in a loop to trigger LOAD_GLOBAL
    for _ in range(num_iters):
        if stop_event.is_set():
            break
        a = G1
        b = G2
        c = G3
        d = G4
        e = G5
        f = G6
        g = G7
        h = G8

        # Access builtins
        # (the names are bound but never used; the load itself is the point)
        i = len
        j = sum
        k = abs

        # Trivial use of the loaded values
        _ = a + b + c + d + e + f + g + h


def invalidator_worker(stop_event, interval):
    """Keep mutating the module globals until ``stop_event`` is set.

    Args:
        stop_event: Polled before every mutation; the function returns once
            it reports being set.
        interval: Seconds to sleep after each mutation. No sleep happens
            unless the value is greater than zero.
    """
    while True:
        if stop_event.is_set():
            return

        # Insert and then remove a throwaway key so the globals dict changes
        # and cached LOAD_GLOBAL specializations are invalidated.
        namespace = globals()
        namespace["_temp_key"] = 1
        del namespace["_temp_key"]

        if interval > 0:
            time.sleep(interval)


def run_benchmark(num_threads, iters_per_thread, invalidate_interval):
    """Time a group of reader threads, optionally alongside an invalidator.

    Args:
        num_threads: Number of reader threads to launch.
        iters_per_thread: Iteration count handed to each reader.
        invalidate_interval: Sleep interval passed to the invalidator thread,
            or ``None`` to run without an invalidator.

    Returns:
        The ``time.perf_counter()`` difference, in seconds, between just
        before the first reader is created and just after the last reader
        has been joined.
    """
    halt = threading.Event()

    # The invalidator (when requested) is started before the timed region.
    background = None
    if invalidate_interval is not None:
        background = threading.Thread(
            target=invalidator_worker,
            args=(halt, invalidate_interval),
            daemon=True,
        )
        background.start()

    def launch_reader():
        # Create and immediately start one reader so earlier readers are
        # already running while later ones are still being created.
        worker = threading.Thread(
            target=reader_worker, args=(iters_per_thread, halt)
        )
        worker.start()
        return worker

    began = time.perf_counter()
    readers = [launch_reader() for _ in range(num_threads)]
    for worker in readers:
        worker.join()
    elapsed = time.perf_counter() - began

    # Readers are done; tell the invalidator to exit and wait for it.
    halt.set()
    if background is not None:
        background.join()

    return elapsed


def main():
    """Parse the command line, print the run configuration, and time a run."""
    cli = argparse.ArgumentParser(description=__doc__)

    # (flag, add_argument keyword arguments) for every supported option.
    option_table = (
        (
            "--threads",
            {
                "type": int,
                "default": 64,
                "help": "Number of reader threads (default: 64)",
            },
        ),
        (
            "--iters",
            {
                "type": int,
                "default": 10_000_000,
                "help": "Iterations per reader thread (default: 10,000,000)",
            },
        ),
        (
            "--interval",
            {
                "type": float,
                "default": 0.0001,
                "help": "Invalidation interval in seconds. Use negative value to disable invalidation. (default: 0.0001)",
            },
        ),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    opts = cli.parse_args()

    # A negative interval means "run without the invalidator thread".
    if opts.interval >= 0:
        invalidate_interval = opts.interval
    else:
        invalidate_interval = None

    # sys._is_gil_enabled is looked up defensively; fall back to a
    # placeholder when the attribute is absent.
    gil_probe = getattr(sys, "_is_gil_enabled", None)
    gil_state = "Unknown" if gil_probe is None else gil_probe()

    print(f"Python Executable: {sys.executable}")
    print(f"Python Version: {sys.version}")
    print(f"GIL Enabled: {gil_state}")
    print(
        f"Configuration: threads={opts.threads}, iters={opts.iters:,}, "
        f"invalidate_interval={invalidate_interval}"
    )
    print("Running benchmark...")

    elapsed = run_benchmark(opts.threads, opts.iters, invalidate_interval)

    print(f"Finished in {elapsed:.4f} seconds")


# Run the benchmark only when this file is executed as a script.
if __name__ == "__main__":
    main()

and I'm running this benchmark on a cloud VM with these characteristics:

Architecture:                x86_64
...
CPU(s):                      128
  On-line CPU(s) list:       0-127
Vendor ID:                   AuthenticAMD
  Model name:                AMD EPYC 7B13
# With no invalidation:
$ python benchmark_contention.py  --interval -1
...
Python Version: 3.15.0b2+dev free-threading build (heads/3.15:ba0cae13cea, Jun 24 2026, 02:50:32) [GCC 15.2.0]
GIL Enabled: False
Configuration: threads=64, iters=10,000,000, invalidate_interval=None
Running benchmark...
Finished in 2.1970 seconds

# With invalidation:
$ python benchmark_contention.py
Python Version: 3.15.0b2+dev free-threading build (heads/3.15:ba0cae13cea, Jun 24 2026, 02:50:32) [GCC 15.2.0]
GIL Enabled: False
Configuration: threads=64, iters=10,000,000, invalidate_interval=0.0001
Running benchmark...
Finished in 23.2520 seconds

However, I thought of trying this patch, which immediately abandons the attempt to specialize if acquiring the globals/builtins critical section would block:

diff --git a/Include/critical_section.h b/Include/critical_section.h
index 732bfab7ecf..b90bc7fda2a 100644
--- a/Include/critical_section.h
+++ b/Include/critical_section.h
@@ -68,6 +68,9 @@ PyCriticalSection2_Begin(PyCriticalSection2 *c, PyObject *a, PyObject *b);
 PyAPI_FUNC(void)
 PyCriticalSection2_End(PyCriticalSection2 *c);
 
+PyAPI_FUNC(int)
+PyCriticalSection2_TryBegin(PyCriticalSection2 *c, PyObject *a, PyObject *b);
+
 // These are definitions for the stable ABI. For GIL-ful builds they're
 // conditionally redefined as no-ops in cpython/critical_section.h.
 
diff --git a/Include/internal/pycore_critical_section.h b/Include/internal/pycore_critical_section.h
index 51d99d74ca1..d32fc2fdf67 100644
--- a/Include/internal/pycore_critical_section.h
+++ b/Include/internal/pycore_critical_section.h
@@ -193,6 +193,56 @@ _PyCriticalSection2_Begin(PyThreadState *tstate, PyCriticalSection2 *c, PyObject
     _PyCriticalSection2_BeginMutex(tstate, c, &a->ob_mutex, &b->ob_mutex);
 }
 
+static inline int
+_PyCriticalSection_TryBeginMutex(PyThreadState *tstate, PyCriticalSection *c, PyMutex *m)
+{
+    if (PyMutex_LockFast(m)) {
+        c->_cs_mutex = m;
+        c->_cs_prev = tstate->critical_section;
+        tstate->critical_section = (uintptr_t)c;
+        return 1;
+    }
+    return 0;
+}
+
+static inline int
+_PyCriticalSection2_TryBeginMutex(PyThreadState *tstate, PyCriticalSection2 *c, PyMutex *m1, PyMutex *m2)
+{
+    if (m1 == m2) {
+        c->_cs_mutex2 = NULL;
+        return _PyCriticalSection_TryBeginMutex(tstate, &c->_cs_base, m1);
+    }
+
+    if ((uintptr_t)m2 < (uintptr_t)m1) {
+        PyMutex *tmp = m1;
+        m1 = m2;
+        m2 = tmp;
+    }
+
+    if (PyMutex_LockFast(m1)) {
+        if (PyMutex_LockFast(m2)) {
+            c->_cs_base._cs_mutex = m1;
+            c->_cs_mutex2 = m2;
+            c->_cs_base._cs_prev = tstate->critical_section;
+
+            uintptr_t p = (uintptr_t)c | _Py_CRITICAL_SECTION_TWO_MUTEXES;
+            tstate->critical_section = p;
+            return 1;
+        }
+        else {
+            PyMutex_Unlock(m1);
+            return 0;
+        }
+    }
+    return 0;
+}
+
+static inline int
+_PyCriticalSection2_TryBegin(PyThreadState *tstate, PyCriticalSection2 *c, PyObject *a, PyObject *b)
+{
+    return _PyCriticalSection2_TryBeginMutex(tstate, c, &a->ob_mutex, &b->ob_mutex);
+}
+
 static inline void
 _PyCriticalSection2_End(PyThreadState *tstate, PyCriticalSection2 *c)
 {
diff --git a/Python/critical_section.c b/Python/critical_section.c
index dbee6f236a7..2d81425cee2 100644
--- a/Python/critical_section.c
+++ b/Python/critical_section.c
@@ -217,3 +217,14 @@ PyCriticalSection2_End(PyCriticalSection2 *c)
     _PyCriticalSection2_End(_PyThreadState_GET(), c);
 #endif
 }
+
+#undef PyCriticalSection2_TryBegin
+int
+PyCriticalSection2_TryBegin(PyCriticalSection2 *c, PyObject *a, PyObject *b)
+{
+#ifdef Py_GIL_DISABLED
+    return _PyCriticalSection2_TryBegin(_PyThreadState_GET(), c, a, b);
+#else
+    return 1;
+#endif
+}
diff --git a/Python/specialize.c b/Python/specialize.c
index 459e69de570..53a126d9f48 100644
--- a/Python/specialize.c
+++ b/Python/specialize.c
@@ -1447,9 +1447,18 @@ _Py_Specialize_LoadGlobal(
     PyObject *globals, PyObject *builtins,
     _Py_CODEUNIT *instr, PyObject *name)
 {
-    Py_BEGIN_CRITICAL_SECTION2(globals, builtins);
+#ifdef Py_GIL_DISABLED
+    PyCriticalSection2 cs;
+    if (PyCriticalSection2_TryBegin(&cs, globals, builtins)) {
+        specialize_load_global_lock_held(globals, builtins, instr, name);
+        PyCriticalSection2_End(&cs);
+    }
+    else {
+        unspecialize(instr);
+    }
+#else
     specialize_load_global_lock_held(globals, builtins, instr, name);
-    Py_END_CRITICAL_SECTION2();
+#endif
 }
 
 static int

and with that change I get these timings:

$ python benchmark_contention.py  --interval -1
Python Version: 3.15.0b2+dev free-threading build (heads/3.15-dirty:ba0cae13cea, Jun 24 2026, 01:28:11) [GCC 15.2.0]
GIL Enabled: False
Configuration: threads=64, iters=10,000,000, invalidate_interval=None
Running benchmark...
Finished in 1.9692 seconds

$ python benchmark_contention.py
Python Version: 3.15.0b2+dev free-threading build (heads/3.15-dirty:ba0cae13cea, Jun 24 2026, 01:28:11) [GCC 15.2.0]
GIL Enabled: False
Configuration: threads=64, iters=10,000,000, invalidate_interval=0.0001
Running benchmark...
Finished in 11.0028 seconds

Might we land something like that?

CPython versions tested on:

3.15

Operating systems tested on:

Linux

Metadata

Metadata

Assignees

No one assigned
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions