Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions sdks/java/src/main/java/org/byteveda/taskito/batch/Batcher.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package org.byteveda.taskito.batch;

import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import org.byteveda.taskito.Taskito;
import org.byteveda.taskito.task.Task;

/**
* Buffers payloads for one task and enqueues them in a single
* {@code enqueueMany} call when the buffer reaches {@code maxBatch} or
* {@code maxDelay} elapses since the first buffered item. Thread-safe;
* {@link #close()} flushes what remains. Use with try-with-resources.
*
* @param <T> the task's payload type
*/
public final class Batcher<T> implements AutoCloseable {
private final Taskito queue;
private final Task<T> task;
private final int maxBatch;
private final long maxDelayNanos;
private final ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor(Batcher::daemon);
private final Object lock = new Object();
private final List<T> buffer = new ArrayList<>();
private ScheduledFuture<?> pendingFlush;
private boolean closed; // guarded by lock

public Batcher(Taskito queue, Task<T> task, int maxBatch, Duration maxDelay) {
if (maxBatch <= 0) {
throw new IllegalArgumentException("maxBatch must be > 0");
}
if (maxDelay == null || maxDelay.isNegative() || maxDelay.isZero()) {
throw new IllegalArgumentException("maxDelay must be positive");
}
this.queue = queue;
this.task = task;
this.maxBatch = maxBatch;
// Nanoseconds, not millis: toMillis() would truncate a sub-millisecond
// delay to 0 and flush eagerly instead of honoring the requested delay.
this.maxDelayNanos = maxDelay.toNanos();
}

public static <T> Batcher<T> of(Taskito queue, Task<T> task, int maxBatch, Duration maxDelay) {
return new Batcher<>(queue, task, maxBatch, maxDelay);
}

/**
* Buffer {@code payload}. Returns the job ids if this call triggered a flush
* (the buffer reached {@code maxBatch}), otherwise an empty list.
*/
public List<String> add(T payload) {
synchronized (lock) {
if (closed) {
throw new IllegalStateException("batcher is closed");
}
buffer.add(payload);
if (buffer.size() >= maxBatch) {
return flushLocked();
}
scheduleFlush();
return List.of();
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}
}

/** Enqueue any buffered payloads now; returns their job ids (empty if none). */
public List<String> flush() {
synchronized (lock) {
return flushLocked();
}
}

@Override
public void close() {
synchronized (lock) {
if (closed) {
return;
}
closed = true;
// Flush and mark closed atomically so no add() can slip in and
// schedule a delayed flush that shutdownNow() would then cancel.
flushLocked();
}
scheduler.shutdownNow();
}

private List<String> flushLocked() {
if (pendingFlush != null) {
pendingFlush.cancel(false);
pendingFlush = null;
}
if (buffer.isEmpty()) {
return List.of();
}
List<T> batch = new ArrayList<>(buffer);
// Enqueue before clearing: if enqueueMany throws, the buffer keeps the
// payloads so a delayed-flush failure doesn't silently drop them.
List<String> ids = queue.enqueueMany(task, batch);
buffer.clear();
return ids;
}

private void scheduleFlush() {
if (pendingFlush == null) {
pendingFlush = scheduler.schedule(this::flush, maxDelayNanos, TimeUnit.NANOSECONDS);
}
}

private static Thread daemon(Runnable runnable) {
Thread thread = new Thread(runnable, "taskito-batcher");
thread.setDaemon(true);
return thread;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
/**
* Producer-side batching: {@link org.byteveda.taskito.batch.Batcher} buffers
* payloads and flushes them in one {@code enqueueMany} when the batch fills or a
* delay elapses. The worker side already batches via the worker's
* {@code batchSize} option (which drives the core batch dequeue).
*/
package org.byteveda.taskito.batch;
70 changes: 70 additions & 0 deletions sdks/java/src/test/java/org/byteveda/taskito/BatcherTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package org.byteveda.taskito;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.nio.file.Path;
import java.time.Duration;
import java.util.List;
import org.byteveda.taskito.batch.Batcher;
import org.byteveda.taskito.task.Task;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.io.TempDir;

class BatcherTest {

    private static final Task<Integer> TASK = Task.of("b.task", Integer.class);

    /** Opens a queue backed by {@code b.db} inside the given temporary directory. */
    private static Taskito openQueue(Path dir) {
        String url = dir.resolve("b.db").toString();
        return Taskito.builder().url(url).open();
    }

    /** The add() call that fills the batch returns the ids and enqueues all payloads. */
    @Test
    void flushesWhenBatchFull(@TempDir Path dir) {
        try (Taskito q = openQueue(dir);
                Batcher<Integer> b = Batcher.of(q, TASK, 3, Duration.ofSeconds(60))) {
            List<String> afterFirst = b.add(1);
            assertTrue(afterFirst.isEmpty());
            List<String> afterSecond = b.add(2);
            assertTrue(afterSecond.isEmpty());

            // Third payload reaches maxBatch and triggers the flush.
            List<String> flushedIds = b.add(3);
            assertEquals(3, flushedIds.size());
            assertEquals(3, q.stats().pending);
        }
    }

    /** Closing the batcher enqueues whatever is still buffered. */
    @Test
    void flushesRemainderOnClose(@TempDir Path dir) {
        try (Taskito q = openQueue(dir)) {
            try (Batcher<Integer> b = Batcher.of(q, TASK, 100, Duration.ofSeconds(60))) {
                b.add(1);
                b.add(2);
            }
            assertEquals(2, q.stats().pending);
        }
    }

    /** add() on a closed batcher is rejected with IllegalStateException. */
    @Test
    void addAfterCloseThrows(@TempDir Path dir) {
        try (Taskito q = openQueue(dir)) {
            Batcher<Integer> closedBatcher = Batcher.of(q, TASK, 100, Duration.ofSeconds(60));
            closedBatcher.close();
            assertThrows(IllegalStateException.class, () -> closedBatcher.add(1));
        }
    }

    /** A batch below maxBatch is still enqueued once the delay has elapsed. */
    @Test
    @Timeout(30)
    void flushesAfterDelay(@TempDir Path dir) throws Exception {
        try (Taskito q = openQueue(dir);
                Batcher<Integer> b = Batcher.of(q, TASK, 100, Duration.ofMillis(200))) {
            b.add(1);
            b.add(2);

            // Poll for up to 10 seconds for the delayed flush to land.
            long deadline = System.nanoTime() + Duration.ofSeconds(10).toNanos();
            while (true) {
                if (q.stats().pending >= 2 || System.nanoTime() >= deadline) {
                    break;
                }
                Thread.sleep(50);
            }
            assertEquals(2, q.stats().pending);
        }
    }
}