Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions examples/benchmark_multithread.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Benchmark serial vs parallel Stretch.process() to measure GIL-release speedup.

No audio files required. Run with:
python examples/benchmark_multithread.py [--threads N] [--duration S] [--repeats N]

Expected results with the GIL-release patch (PR #4):
- Serial (8x sequential) : ~400 ms
- Parallel (8 threads) : ~55 ms
- Speedup : ~7x

Without the patch, parallel speedup will be ~1x (threads serialize on the GIL).
"""

import argparse
import timeit
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import python_stretch as m


def parse_args():
p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
p.add_argument("--threads", type=int, default=8, help="Number of parallel threads (default: 8)")
p.add_argument("--duration", type=float, default=4.0, help="Audio duration in seconds (default: 4.0)")
p.add_argument("--repeats", type=int, default=5, help="Number of timeit repeats (default: 5)")
p.add_argument("--semitones", type=float, default=3.0, help="Pitch shift in semitones (default: 3)")
p.add_argument("--time-factor", type=float, default=1.25, help="Time stretch factor (default: 1.25)")
return p.parse_args()


def make_stretch(semitones, time_factor):
ps = m.Signalsmith.Stretch()
ps.setTransposeSemitones(semitones)
ps.setTimeFactor(time_factor)
return ps


def main():
args = parse_args()

sample_rate = 44100
n_samples = int(sample_rate * args.duration)
rng = np.random.default_rng(0)
audio = rng.standard_normal((2, n_samples)).astype(np.float32)

print(f"python-stretch benchmark — GIL release speedup")
print(f" Audio : {args.duration:.1f}s stereo @ {sample_rate} Hz")
print(f" Config : +{args.semitones} semitones, {args.time_factor}x time")
print(f" Threads : {args.threads}")
print(f" Repeats : {args.repeats}")
print()

def run_serial():
for _ in range(args.threads):
ps = make_stretch(args.semitones, args.time_factor)
ps.process(audio)

def worker(_):
ps = make_stretch(args.semitones, args.time_factor)
ps.process(audio)

def run_parallel():
with ThreadPoolExecutor(max_workers=args.threads) as executor:
list(executor.map(worker, range(args.threads)))

t_serial = timeit.timeit(run_serial, number=args.repeats) / args.repeats * 1000
t_parallel = timeit.timeit(run_parallel, number=args.repeats) / args.repeats * 1000
speedup = t_serial / t_parallel

col = 30
print(f"{'':>{col}} {'ms / run':>10}")
print(f"{'Serial ({} x sequential)'.format(args.threads):>{col}} {t_serial:>10.1f}")
print(f"{'Parallel ({} threads)'.format(args.threads):>{col}} {t_parallel:>10.1f}")
print(f"{'Speedup':>{col}} {speedup:>10.2f}x")
print()

if speedup >= args.threads * 0.5:
print(f"✓ GIL released during Stretch.process() — {speedup:.1f}x parallel scaling confirmed.")
elif speedup >= 2:
print(f"~ Partial speedup ({speedup:.1f}x). GIL may be released but other bottlenecks present.")
else:
print(f"✗ No meaningful speedup ({speedup:.1f}x). This build likely does not include the GIL release patch.")


if __name__ == "__main__":
main()
48 changes: 28 additions & 20 deletions src/signalsmith-bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,31 +181,39 @@ struct Stretch{
Buffer<float> inBuffer(inputChannels, paddedInputLength);
Buffer<float> outBuffer(outputChannels, outputLength);

// Seek to the beginning of the input buffer
stretch_.seek(inBuffer, stretch_.inputLatency(), timeFactor_);
// Prepare output data (raw heap allocation — does not touch Python objects)
size_t outShape[2] = {numChannels, outputLength };
float* outData = new float[numChannels * outShape[1]];

// Set offset of inBuffer
inBuffer.setOffset(stretch_.inputLatency());
// === GIL-free section ===========================================
// The stretch computation is pure C++ on raw float buffers; no
// Python objects are touched, so the GIL can be released to give
// ThreadPoolExecutor parallelism a real speedup.
{
nb::gil_scoped_release release;

// PROCESSING
stretch_.process(inBuffer, inputLength, outBuffer, outputLength);
// Seek to the beginning of the input buffer
stretch_.seek(inBuffer, stretch_.inputLatency(), timeFactor_);

// Read the last bit of output without providing any further input
outBuffer.setOffset(outputLength);
stretch_.flush(outBuffer, tailSamples);
// outBuffer.setOffset(tailSamples);
// Set offset of inBuffer
inBuffer.setOffset(stretch_.inputLatency());

// Prepare output data
size_t outShape[2] = {numChannels, outputLength };
float* outData = new float[numChannels * outShape[1]];
// PROCESSING
stretch_.process(inBuffer, inputLength, outBuffer, outputLength);

// Copy from outputChannels to outData
for (size_t i = 0; i < numChannels; ++i) {
std::copy(outputChannels[i] + tailSamples, outputChannels[i] + paddedOutputLength , outData + i * outputLength );
}
// Read the last bit of output without providing any further input
outBuffer.setOffset(outputLength);
stretch_.flush(outBuffer, tailSamples);

// REMEMBER: Reset the stretch processor or we will get an error: free() invalid pointer
stretch_.reset();
// Copy from outputChannels to outData (raw memcpy, no Python)
for (size_t i = 0; i < numChannels; ++i) {
std::copy(outputChannels[i] + tailSamples, outputChannels[i] + paddedOutputLength , outData + i * outputLength );
}

// REMEMBER: Reset the stretch processor or we will get an error: free() invalid pointer
stretch_.reset();
}
// === GIL re-acquired ============================================

// Clean up
for (size_t i = 0; i < numChannels; ++i) {
Expand Down Expand Up @@ -291,7 +299,7 @@ NB_MODULE(Signalsmith, m) {
"----------\n"
"- timeFactor (float): Factor by which time is stretched or compressed (e.g., 0.5 slows down by half, 2.0 doubles speed).")

// PROCESSING
// PROCESSING
.def("process", &Stretch<Sample>::process,
"audio_input"_a,
"Process an input audio buffer and return the stretched or pitch-shifted output.\n\n"
Expand Down
65 changes: 65 additions & 0 deletions tests/test_multithread.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import numpy as np
import python_stretch as m
from concurrent.futures import ThreadPoolExecutor

NUM_THREADS = 8
SAMPLE_RATE = 44100
SEMITONES = 3
TIME_FACTOR = 1.25


def _make_stretch():
ps = m.Signalsmith.Stretch()
ps.setTransposeSemitones(SEMITONES)
ps.setTimeFactor(TIME_FACTOR)
return ps


def _process_one(audio):
ps = _make_stretch()
return ps.process(audio)


def test_single_thread_determinism():
"""Same input → bit-identical output on repeated serial calls."""
rng = np.random.default_rng(42)
x = rng.standard_normal((2, SAMPLE_RATE * 4)).astype(np.float32)

out_a = _process_one(x)
out_b = _process_one(x)

assert np.array_equal(out_a, out_b), "Serial outputs differ across repeated calls"


def test_parallel_consistency():
"""N independent Stretch instances in a thread pool → match serial reference outputs."""
rng = np.random.default_rng(42)
inputs = [rng.standard_normal((2, SAMPLE_RATE * 4)).astype(np.float32) for _ in range(NUM_THREADS)]

# Serial reference: one Stretch per input, all on main thread
serial_outputs = [_process_one(x) for x in inputs]

# Parallel: one Stretch per thread (each thread owns its instance — the safe pattern)
with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
parallel_outputs = list(executor.map(_process_one, inputs))

for i, (serial, parallel) in enumerate(zip(serial_outputs, parallel_outputs)):
assert np.array_equal(serial, parallel), (
f"Thread {i}: parallel output differs from serial reference"
)


def test_cross_run_stability():
"""Same parallel batch repeated twice → identical results across runs."""
rng = np.random.default_rng(42)
inputs = [rng.standard_normal((2, SAMPLE_RATE * 4)).astype(np.float32) for _ in range(NUM_THREADS)]

def run_parallel():
with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
return list(executor.map(_process_one, inputs))

outputs_a = run_parallel()
outputs_b = run_parallel()

for i, (a, b) in enumerate(zip(outputs_a, outputs_b)):
assert np.array_equal(a, b), f"Thread {i}: outputs differ across parallel runs"