Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,13 +275,13 @@ def _save(
)


def _clip_to_end(tprev, tnext, t1, t1_clip_floor, keep_step):
def _clip_to_end(tnext, t1, t1_clip_floor, keep_step):
# The tolerance of ~100 ULP's means that we don't end up with too-small intervals
# for dense output, which then gives numerically unstable answers due to floating
# point errors.
clip = tnext > t1_clip_floor
tclip = jnp.where(keep_step, t1, tprev + 0.5 * (t1 - tprev))
return jnp.where(clip, tclip, tnext)
# Only clip on accepted steps: on a rejection the controller has just shrunk its
# proposal, so overriding it can cause an infinite rejection loop (see #756).
return jnp.where(keep_step & (tnext > t1_clip_floor), t1, tnext)


def _maybe_static(static_x: ArrayLike | None, x: ArrayLike) -> ArrayLike:
Expand Down Expand Up @@ -410,7 +410,7 @@ def body_fun_aux(state):
#

tprev = jnp.minimum(tprev, t1)
tnext = _clip_to_end(tprev, tnext, t1, t1_clip_floor, keep_step)
tnext = _clip_to_end(tnext, t1, t1_clip_floor, keep_step)

progress_meter_state = progress_meter.step(
state.progress_meter_state, linear_rescale(t0, tprev, t1)
Expand Down
58 changes: 58 additions & 0 deletions test/test_adaptive_stepsize_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,64 @@ def test_jump_at_t1_with_large_t1_in_float32():
assert sol.ts == jnp.array([t1])


# https://github.com/patrick-kidger/diffrax/issues/756
# In float32 with a large t1, the 100-ULP clip window is wide enough to contain
# the PID controller's desired step. If tprev sits just outside the clip window
# but the controller's proposed tnext is inside it, the rejected step must not
# be overridden with a fixed midpoint between tprev and t1: tprev does not move
# on rejection, so a fixed midpoint would be retried forever.
def test_clip_to_end_does_not_override_rejected_step():
from diffrax._integrate import _clip_to_end

t1 = jnp.array(600.0, dtype=jnp.float32)
t1_clip_floor = t1
for _ in range(100):
t1_clip_floor = eqxi.prevbefore(t1_clip_floor)
# tnext lies inside the clip window; the controller has just rejected the
# previous step, so keep_step is False. Expect the controller's tnext to
# pass through unchanged (so the next iteration sees a strictly shrunken
# proposal rather than the same midpoint).
tnext = jnp.array(599.996, dtype=jnp.float32)
assert tnext > t1_clip_floor
out = _clip_to_end(tnext, t1, t1_clip_floor, jnp.array(False))
assert out == tnext
# On an accepted step we still snap to t1.
out = _clip_to_end(tnext, t1, t1_clip_floor, jnp.array(True))
assert out == t1
# And outside the clip window we never touch tnext.
tnext_far = jnp.array(599.0, dtype=jnp.float32)
for keep in (jnp.array(True), jnp.array(False)):
assert _clip_to_end(tnext_far, t1, t1_clip_floor, keep) == tnext_far


# End-to-end version of the above: a harmonic oscillator in float32 with t1=600
# and a tolerance that drives the PID controller's desired step into the 100-ULP
# clip window. Before the #756 fix this hung at max_steps with tprev pinned to
# the clip floor; after the fix it completes in ~30 steps.
def test_no_infinite_reject_loop_at_t1_in_float32():
t1 = jnp.float32(600.0)
omega = jnp.float32(1000.0)

def vf(t, y, args):
y0, y1 = y
return jnp.stack([y1, -omega * omega * y0])

sol = diffrax.diffeqsolve(
diffrax.ODETerm(vf),
diffrax.Tsit5(),
t0=jnp.float32(599.99),
t1=t1,
dt0=None,
y0=jnp.array([1.0, 0.0], dtype=jnp.float32),
stepsize_controller=diffrax.PIDController(rtol=1e-4, atol=1e-4, safety=0.9),
saveat=diffrax.SaveAt(t1=True),
max_steps=1000,
)
assert sol.result == diffrax.RESULTS.successful
assert sol.ts is not None
assert sol.ts[-1] == t1


# https://github.com/patrick-kidger/diffrax/issues/713
def test_t0_at_jump_time():
jump_time = 0.98
Expand Down
Loading