Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions sdks/java/src/main/java/org/byteveda/taskito/scaler/Scaler.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package org.byteveda.taskito.scaler;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.sun.net.httpserver.HttpExchange;
import com.sun.net.httpserver.HttpServer;
import java.io.IOException;
import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.Map;
import org.byteveda.taskito.Taskito;
import org.byteveda.taskito.model.QueueStats;

/**
* Serves queue depth over HTTP for an external autoscaler. Depth is
* {@code pending + running} (outstanding work); the autoscaler divides it by
* {@code targetQueueDepth} to pick a replica count. Start with
* {@link #start(Taskito, ScalerOptions)} and {@link #close()} to stop.
*/
public final class Scaler implements AutoCloseable {
private static final ObjectMapper JSON = new ObjectMapper();

private final HttpServer server;

private Scaler(HttpServer server) {
this.server = server;
}

/** Start the endpoint; binds immediately and serves on a background selector. */
public static Scaler start(Taskito queue, ScalerOptions options) {
HttpServer server;
try {
server = HttpServer.create(new InetSocketAddress(options.host(), options.port()), 0);
} catch (IOException e) {
throw new UncheckedIOException("failed to start the scaler endpoint", e);
}
server.createContext("/api/scaler", exchange -> handleScaler(exchange, queue, options));
server.createContext("/health", Scaler::handleHealth);
server.start();
return new Scaler(server);
}

/** The bound port (useful when {@code port} was 0). */
public int port() {
return server.getAddress().getPort();
}

@Override
public void close() {
server.stop(0);
}

private static void handleScaler(HttpExchange exchange, Taskito queue, ScalerOptions options) throws IOException {
if (!"GET".equals(exchange.getRequestMethod())) {
send(exchange, 405, Map.of("error", "method not allowed"));
return;
}
String queueName = queryParam(exchange, "queue");
if (queueName == null) {
queueName = options.queue();
}
try {
QueueStats stats = queueName == null ? queue.stats() : queue.statsByQueue(queueName);
long depth = stats.pending + stats.running;
Map<String, Object> body = new LinkedHashMap<>();
body.put("metricValue", depth);
body.put("targetValue", options.targetQueueDepth());
body.put("queueName", queueName == null ? "all" : queueName);
send(exchange, 200, body);
} catch (RuntimeException e) {
// Never leak backend internals to the scaler caller.
send(exchange, 500, Map.of("error", "failed to read queue stats"));
}
}

private static void handleHealth(HttpExchange exchange) throws IOException {
send(exchange, 200, Map.of("status", "ok"));
}

private static String queryParam(HttpExchange exchange, String key) {
String query = exchange.getRequestURI().getQuery();
if (query == null) {
return null;
}
for (String pair : query.split("&")) {
int eq = pair.indexOf('=');
if (eq > 0 && pair.substring(0, eq).equals(key)) {
return pair.substring(eq + 1);
}
}
return null;
}

private static void send(HttpExchange exchange, int status, Map<String, Object> body) throws IOException {
byte[] bytes;
try {
bytes = JSON.writeValueAsBytes(body);
} catch (Exception e) {
bytes = "{}".getBytes(StandardCharsets.UTF_8);
status = 500;
}
exchange.getResponseHeaders().set("Content-Type", "application/json");
exchange.sendResponseHeaders(status, bytes.length);
try (OutputStream out = exchange.getResponseBody()) {
out.write(bytes);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package org.byteveda.taskito.scaler;

/**
* Configures the {@link Scaler} HTTP endpoint.
*
* @param port the port to bind ({@code 0} picks an ephemeral port)
* @param host the bind address (defaults to {@code 0.0.0.0})
* @param targetQueueDepth the depth the autoscaler targets per replica (must be &gt; 0)
* @param queue the queue to report on, or {@code null} for all queues
*/
public record ScalerOptions(int port, String host, int targetQueueDepth, String queue) {
public ScalerOptions {
if (targetQueueDepth <= 0) {
throw new IllegalArgumentException("targetQueueDepth must be > 0");
}
if (host == null || host.isBlank()) {
host = "0.0.0.0";
}
}

/** Defaults: port 9090, all queues, target depth 10. */
public static ScalerOptions defaults() {
return new ScalerOptions(9090, "0.0.0.0", 10, null);
}

/** Bind to {@code port} with the other defaults. */
public static ScalerOptions onPort(int port) {
return new ScalerOptions(port, "0.0.0.0", 10, null);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
/**
* A tiny HTTP endpoint that exposes queue depth for an external autoscaler
* (e.g. KEDA's metrics-api scaler). {@link org.byteveda.taskito.scaler.Scaler}
* serves {@code GET /api/scaler} ({@code metricValue}/{@code targetValue}) and
* {@code GET /health}. Observability (metrics export) is left to the contrib
* middleware; this only reports depth for scaling decisions.
*/
package org.byteveda.taskito.scaler;
97 changes: 97 additions & 0 deletions sdks/java/src/test/java/org/byteveda/taskito/ScalerTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package org.byteveda.taskito;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.file.Path;
import org.byteveda.taskito.scaler.Scaler;
import org.byteveda.taskito.scaler.ScalerOptions;
import org.byteveda.taskito.task.EnqueueOptions;
import org.byteveda.taskito.task.Task;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.io.TempDir;

class ScalerTest {

private static final ObjectMapper JSON = new ObjectMapper();
private static final Task<Integer> TASK = Task.of("s.task", Integer.class);

@Test
@Timeout(30)
void reportsQueueDepthAndHealth(@TempDir Path dir) throws Exception {
try (Taskito queue =
Taskito.builder().url(dir.resolve("s.db").toString()).open()) {
queue.enqueue(TASK, 1);
queue.enqueue(TASK, 2);
try (Scaler scaler = Scaler.start(queue, new ScalerOptions(0, "127.0.0.1", 5, null))) {
HttpClient client = HttpClient.newHttpClient();
String base = "http://127.0.0.1:" + scaler.port();

HttpResponse<String> scale = get(client, base + "/api/scaler");
assertEquals(200, scale.statusCode());
JsonNode body = JSON.readTree(scale.body());
assertEquals(2, body.get("metricValue").asLong());
assertEquals(5, body.get("targetValue").asInt());
assertEquals("all", body.get("queueName").asText());

HttpResponse<String> health = get(client, base + "/health");
assertEquals(200, health.statusCode());
assertEquals("ok", JSON.readTree(health.body()).get("status").asText());
}
}
}

@Test
@Timeout(30)
void filtersDepthByQueue(@TempDir Path dir) throws Exception {
try (Taskito queue =
Taskito.builder().url(dir.resolve("sq.db").toString()).open()) {
queue.enqueue(TASK, 1); // default queue
queue.enqueue(TASK, 2, EnqueueOptions.builder().queue("high").build());
queue.enqueue(TASK, 3, EnqueueOptions.builder().queue("high").build());
try (Scaler scaler = Scaler.start(queue, new ScalerOptions(0, "127.0.0.1", 5, null))) {
HttpClient client = HttpClient.newHttpClient();
String base = "http://127.0.0.1:" + scaler.port();

HttpResponse<String> scoped = get(client, base + "/api/scaler?queue=high");
assertEquals(200, scoped.statusCode());
JsonNode body = JSON.readTree(scoped.body());
assertEquals(2, body.get("metricValue").asLong(), "only the 'high' queue counts");
assertEquals("high", body.get("queueName").asText());
}
}
}

@Test
@Timeout(30)
void rejectsNonGetRequests(@TempDir Path dir) throws Exception {
try (Taskito queue =
Taskito.builder().url(dir.resolve("sm.db").toString()).open()) {
try (Scaler scaler = Scaler.start(queue, new ScalerOptions(0, "127.0.0.1", 5, null))) {
HttpClient client = HttpClient.newHttpClient();
HttpRequest post = HttpRequest.newBuilder(
URI.create("http://127.0.0.1:" + scaler.port() + "/api/scaler"))
.POST(HttpRequest.BodyPublishers.noBody())
.build();
HttpResponse<String> response = client.send(post, HttpResponse.BodyHandlers.ofString());
assertEquals(405, response.statusCode());
}
}
}

@Test
void rejectsNonPositiveTarget() {
assertThrows(IllegalArgumentException.class, () -> new ScalerOptions(0, "127.0.0.1", 0, null));
}

private static HttpResponse<String> get(HttpClient client, String url) throws Exception {
return client.send(HttpRequest.newBuilder(URI.create(url)).GET().build(), HttpResponse.BodyHandlers.ofString());
}
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.