diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 8241fa9b..5441e0a9 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -275,13 +275,13 @@ def _save( ) -def _clip_to_end(tprev, tnext, t1, t1_clip_floor, keep_step): +def _clip_to_end(tnext, t1, t1_clip_floor, keep_step): # The tolerance of ~100 ULP's means that we don't end up with too-small intervals # for dense output, which then gives numerically unstable answers due to floating # point errors. - clip = tnext > t1_clip_floor - tclip = jnp.where(keep_step, t1, tprev + 0.5 * (t1 - tprev)) - return jnp.where(clip, tclip, tnext) + # Only clip on accepted steps: on a rejection the controller has just shrunk its + # proposal, so overriding it can cause an infinite rejection loop (see #756). + return jnp.where(keep_step & (tnext > t1_clip_floor), t1, tnext) def _maybe_static(static_x: ArrayLike | None, x: ArrayLike) -> ArrayLike: @@ -410,7 +410,7 @@ def body_fun_aux(state): # tprev = jnp.minimum(tprev, t1) - tnext = _clip_to_end(tprev, tnext, t1, t1_clip_floor, keep_step) + tnext = _clip_to_end(tnext, t1, t1_clip_floor, keep_step) progress_meter_state = progress_meter.step( state.progress_meter_state, linear_rescale(t0, tprev, t1) diff --git a/test/test_adaptive_stepsize_controller.py b/test/test_adaptive_stepsize_controller.py index d785d16e..4a51336f 100644 --- a/test/test_adaptive_stepsize_controller.py +++ b/test/test_adaptive_stepsize_controller.py @@ -364,6 +364,64 @@ def test_jump_at_t1_with_large_t1_in_float32(): assert sol.ts == jnp.array([t1]) +# https://github.com/patrick-kidger/diffrax/issues/756 +# In float32 with a large t1, the 100-ULP clip window is wide enough to contain +# the PID controller's desired step. If tprev sits just outside the clip window +# but the controller's proposed tnext is inside it, the rejected step must not +# be overridden with a fixed midpoint between tprev and t1: tprev does not move +# on rejection, so a fixed midpoint would be retried forever. 
# https://github.com/patrick-kidger/diffrax/issues/756
# With float32 and a large t1, the 100-ULP clip window can be wide enough to hold
# the step the controller actually wants. A rejected step whose proposal falls
# inside that window must be handed back untouched: tprev does not move on a
# rejection, so replacing the proposal with a fixed point between tprev and t1
# would retry the very same step forever.
def test_clip_to_end_does_not_override_rejected_step():
    from diffrax._integrate import _clip_to_end

    end_time = jnp.array(600.0, dtype=jnp.float32)

    # Apply `eqxi.prevbefore` 100 times, starting from t1, to obtain the lower
    # edge of the clip window.
    window_floor = end_time
    remaining = 100
    while remaining > 0:
        window_floor = eqxi.prevbefore(window_floor)
        remaining -= 1

    # A proposal inside the clip window.
    proposal = jnp.array(599.996, dtype=jnp.float32)
    assert proposal > window_floor

    # Rejected step (keep_step is False): the controller's proposal passes through.
    rejected = jnp.array(False)
    result = _clip_to_end(proposal, end_time, window_floor, rejected)
    assert result == proposal

    # Accepted step: the proposal is still snapped onto t1.
    accepted = jnp.array(True)
    result = _clip_to_end(proposal, end_time, window_floor, accepted)
    assert result == end_time

    # Outside the clip window the proposal is never altered, accepted or not.
    far_proposal = jnp.array(599.0, dtype=jnp.float32)
    for flag in (True, False):
        untouched = _clip_to_end(
            far_proposal, end_time, window_floor, jnp.array(flag)
        )
        assert untouched == far_proposal
# https://github.com/patrick-kidger/diffrax/issues/756
# End-to-end check: a float32 harmonic oscillator integrated up to t1=600 with a
# tolerance that pushes the PID controller's desired step inside the 100-ULP clip
# window. Before the #756 fix this ran into max_steps with tprev pinned to the
# clip floor; with the fix the solve completes in ~30 steps.
def test_no_infinite_reject_loop_at_t1_in_float32():
    t_end = jnp.float32(600.0)
    omega = jnp.float32(1000.0)

    def vf(t, y, args):
        # y holds (position, velocity) of the oscillator.
        position, velocity = y
        return jnp.stack([velocity, -omega * omega * position])

    term = diffrax.ODETerm(vf)
    solver = diffrax.Tsit5()
    initial_state = jnp.array([1.0, 0.0], dtype=jnp.float32)
    controller = diffrax.PIDController(rtol=1e-4, atol=1e-4, safety=0.9)
    save_final_only = diffrax.SaveAt(t1=True)

    sol = diffrax.diffeqsolve(
        term,
        solver,
        t0=jnp.float32(599.99),
        t1=t_end,
        dt0=None,
        y0=initial_state,
        stepsize_controller=controller,
        saveat=save_final_only,
        max_steps=1000,
    )

    # The solve must finish successfully and land exactly on t1.
    assert sol.result == diffrax.RESULTS.successful
    saved_times = sol.ts
    assert saved_times is not None
    assert saved_times[-1] == t_end