
(tutorial-basic)=

# Array API Tutorial

In this tutorial we're going to show the migration of a simple graph algorithm to the Array API, from the array consumer's point of view.

The example presented here comes from the graphblas-algorithms library, where we can find the HITS algorithm, used in link analysis to estimate the prominence of nodes in sparse networks.

The inlined and slightly simplified implementation (without the "authority" feature) looks like this:

```python
from graphblas import Vector, binary, monoid, unary


def hits(G, max_iter=100, tol=1.0e-8, normalized=True):
    N = len(G)
    h = Vector(float, N, name="h")
    a = Vector(float, N, name="a")
    h << 1.0 / N
    # Power iteration: make up to max_iter iterations
    A = G._A
    hprev = Vector(float, N, name="h_prev")
    for _i in range(max_iter):
        hprev, h = h, hprev
        a << hprev @ A
        h << A @ a
        h *= 1.0 / h.reduce(monoid.max).get(0)
        if is_converged(hprev, h, tol):
            break
    else:
        # ConvergenceFailure is the library-specific exception, kept as in the source
        raise ConvergenceFailure(max_iter)
    if normalized:
        h *= 1.0 / h.reduce().get(0)
        a *= 1.0 / a.reduce().get(0)
    return h, a


def is_converged(xprev, x, tol):
    xprev << binary.minus(xprev | x)
    xprev << unary.abs(xprev)
    return xprev.reduce().get(0) < xprev.size * tol
```

We can see that the API is specific to the GraphBLAS array objects: there is the `Vector` constructor, the overloaded `<<` operator for assigning new values, and `reduce`/`get` for reductions. We need to replace them, and, by convention, we will use the `xp` namespace for calling the respective functions.
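One way to obtain such a namespace is to derive it from the input arrays themselves; the sketch below uses the array-api-compat helper for this, which is our assumption for illustration, not something this tutorial prescribes:

```python
import numpy as np
from array_api_compat import array_namespace

x = np.asarray([1.0, 2.0, 3.0])
xp = array_namespace(x)  # resolves to the namespace matching the input array
print(xp.sum(x))         # 6.0, dispatched through the Array API
```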

First we want to make sure we construct arrays in an agnostic way:

```python
h = xp.full(N, 1.0 / N)
A = xp.asarray(G.A)
```

Then, instead of the `reduce` calls, we use the appropriate reduction functions from the Array API:

```python
h = h / xp.max(h)
# ...
h = h / xp.sum(xp.abs(h))
a = a / xp.sum(xp.abs(a))
# ...
err = xp.sum(xp.abs(...))
```
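To see these reductions in action, here is a small self-contained check, with NumPy standing in for `xp` purely for illustration (any Array API compliant library would do):

```python
import numpy as xp  # stand-in namespace for this demonstration

h = xp.asarray([0.5, 1.0, 2.0])
h = h / xp.max(h)            # max-normalization, as in the loop body
print(h)                     # [0.25 0.5  1.  ]
h = h / xp.sum(xp.abs(h))    # L1 normalization, as in the final step
print(float(xp.sum(h)))      # 1.0
```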

We replace the custom binary operation with its Array API counterpart, plugging it into the reduction above:

```python
...(x - xprev)
```

And last but not least, let's ensure that the result of the convergence check is a scalar coming from our API:

```python
err < xp.asarray(N * tol)
```

The rewrite is now complete, and we can assemble all the constituent parts into a full implementation:

```python
def hits(G, max_iter=100, tol=1.0e-8, normalized=True):
    N = len(G)
    h = xp.full(N, 1.0 / N)
    A = xp.asarray(G.A)
    # Power iteration: make up to max_iter iterations
    for _i in range(max_iter):
        hprev = h
        a = hprev @ A
        h = A @ a
        h = h / xp.max(h)
        if is_converged(hprev, h, N, tol):
            break
    else:
        raise Exception("Didn't converge")
    if normalized:
        h = h / xp.sum(xp.abs(h))
        a = a / xp.sum(xp.abs(a))
    return h, a


def is_converged(xprev, x, N, tol):
    err = xp.sum(xp.abs(x - xprev))
    return err < xp.asarray(N * tol)
```
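To try the assembled function end to end, a minimal driver could look like the following. Both the `SimpleGraph` wrapper and picking NumPy as the backend are illustrative assumptions; graphblas-algorithms provides its own graph class:

```python
import numpy as xp  # illustrative backend choice


class SimpleGraph:
    """Hypothetical stand-in exposing the two things hits() needs:
    len(G) for the node count and G.A for the adjacency matrix."""

    def __init__(self, adjacency):
        self.A = adjacency

    def __len__(self):
        return self.A.shape[0]


# A tiny 3-node directed graph: 0 -> 1, 0 -> 2, 1 -> 2, 2 -> 0.
G = SimpleGraph(xp.asarray([[0.0, 1.0, 1.0],
                            [0.0, 0.0, 1.0],
                            [1.0, 0.0, 0.0]]))
h, a = hits(G)
print(h, a)  # hub and authority scores for the three nodes
```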

At this point the actual execution depends only on the `xp` namespace, and replacing that one variable allows us to switch from, e.g., NumPy arrays to JAX execution on a GPU. This makes the code more flexible: for example, we can use lazy evaluation and JIT-compile the loop body with JAX.
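A rough sketch of that last point follows; using `jax.numpy` as the namespace and applying `jax.jit` to a single power-iteration step are our assumptions about a typical setup, not something the tutorial prescribes:

```python
import jax
import jax.numpy as xp  # swap the namespace: same algorithm, JAX arrays


@jax.jit
def power_step(h, A):
    # One power-iteration step; traced once, then reused in compiled form.
    a = h @ A
    h = A @ a
    return h / xp.max(h), a
```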