Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions accessjwt/doc.go → access/jwt/doc.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Package accessjwt issues and verifies authkit-owned access JWTs.
// Package jwt issues and verifies authkit-owned access JWTs.
//
// Access JWTs authenticate an authkit principal. They intentionally carry only
// principal identity and token metadata; authorization data stays in authkit
// storage and is evaluated at request time.
package accessjwt
package jwt
26 changes: 13 additions & 13 deletions accessjwt/issuer.go → access/jwt/issuer.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package accessjwt
package jwt

import (
"context"
Expand All @@ -11,7 +11,7 @@ import (
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jws"
"github.com/lestrrat-go/jwx/v3/jwt"
jwxjwt "github.com/lestrrat-go/jwx/v3/jwt"
)

// Issuer issues signed authkit access JWTs.
Expand All @@ -34,10 +34,10 @@ func NewIssuer(opts IssuerOptions) (*Issuer, error) {
return nil, err
}
if opts.TTL <= 0 {
return nil, errors.New("accessjwt: TTL must be positive")
return nil, errors.New("jwt: TTL must be positive")
}
if opts.SigningKey == nil {
return nil, errors.New("accessjwt: signing key is required")
return nil, errors.New("jwt: signing key is required")
}
if err := validateKeyID("signing key", opts.SigningKey); err != nil {
return nil, err
Expand Down Expand Up @@ -82,15 +82,15 @@ func (i *Issuer) IssueToken(ctx context.Context, req IssueRequest) (IssuedToken,

tokenID, tokenIDErr := i.tokenID()
if tokenIDErr != nil {
return IssuedToken{}, fmt.Errorf("accessjwt: generate token ID: %w", tokenIDErr)
return IssuedToken{}, fmt.Errorf("jwt: generate token ID: %w", tokenIDErr)
}
if validationErr := validateRequiredString("token ID", tokenID); validationErr != nil {
return IssuedToken{}, validationErr
}

issuedAt := i.clock()
expiresAt := issuedAt.Add(i.ttl)
token, err := jwt.NewBuilder().
token, err := jwxjwt.NewBuilder().
Issuer(i.issuer).
Subject(req.PrincipalID).
Audience([]string{i.audience}).
Expand All @@ -99,20 +99,20 @@ func (i *Issuer) IssueToken(ctx context.Context, req IssueRequest) (IssuedToken,
JwtID(tokenID).
Build()
if err != nil {
return IssuedToken{}, fmt.Errorf("accessjwt: build token: %w", err)
return IssuedToken{}, fmt.Errorf("jwt: build token: %w", err)
}

headers := jws.NewHeaders()
if headerErr := headers.Set(jws.TypeKey, TokenType); headerErr != nil {
return IssuedToken{}, fmt.Errorf("accessjwt: set token type: %w", headerErr)
return IssuedToken{}, fmt.Errorf("jwt: set token type: %w", headerErr)
}

signed, err := jwt.Sign(
signed, err := jwxjwt.Sign(
token,
jwt.WithKey(i.algorithm, i.signingKey, jws.WithProtectedHeaders(headers)),
jwxjwt.WithKey(i.algorithm, i.signingKey, jws.WithProtectedHeaders(headers)),
)
if err != nil {
return IssuedToken{}, fmt.Errorf("accessjwt: sign token: %w", err)
return IssuedToken{}, fmt.Errorf("jwt: sign token: %w", err)
}

return IssuedToken{
Expand All @@ -126,10 +126,10 @@ func (i *Issuer) IssueToken(ctx context.Context, req IssueRequest) (IssuedToken,

func validateRequiredString(name string, value string) error {
if strings.TrimSpace(value) == "" {
return fmt.Errorf("accessjwt: %s is required", name)
return fmt.Errorf("jwt: %s is required", name)
}
if strings.TrimSpace(value) != value {
return fmt.Errorf("accessjwt: %s must not contain surrounding whitespace", name)
return fmt.Errorf("jwt: %s must not contain surrounding whitespace", name)
}

return nil
Expand Down
56 changes: 28 additions & 28 deletions accessjwt/accessjwt_test.go → access/jwt/jwt_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package accessjwt
package jwt

import (
"context"
Expand All @@ -12,12 +12,12 @@ import (
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jws"
"github.com/lestrrat-go/jwx/v3/jwt"
jwxjwt "github.com/lestrrat-go/jwx/v3/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/meigma/authkit"
"github.com/meigma/authkit/roleauth"
"github.com/meigma/authkit/authz/role"
"github.com/meigma/authkit/store/memory"
)

Expand Down Expand Up @@ -298,7 +298,7 @@ func TestVerifiedTokenUsesStorageBackedAuthorization(t *testing.T) {
loaded, err := store.FindPrincipal(context.Background(), verified.PrincipalID)
require.NoError(t, err)

authorizer, err := roleauth.NewAuthorizer(store)
authorizer, err := role.NewAuthorizer(store)
require.NoError(t, err)
decision, err := authorizer.Can(context.Background(), authkit.AuthorizationCheck{
Principal: loaded,
Expand Down Expand Up @@ -416,13 +416,13 @@ func signToken(
key := newRSAKey(t, keyID, algorithmName)
algorithm, err := signatureAlgorithm(algorithmName)
require.NoError(t, err)
token := jwt.New()
token := jwxjwt.New()
for name, value := range claims {
require.NoError(t, token.Set(name, value))
}
headers := jws.NewHeaders()
require.NoError(t, headers.Set(jws.TypeKey, tokenType))
signed, err := jwt.Sign(token, jwt.WithKey(algorithm, key, jws.WithProtectedHeaders(headers)))
signed, err := jwxjwt.Sign(token, jwxjwt.WithKey(algorithm, key, jws.WithProtectedHeaders(headers)))
require.NoError(t, err)

return string(signed)
Expand All @@ -438,7 +438,7 @@ func signTokenWithHeaders(

algorithm, err := signatureAlgorithm(DefaultAlgorithm)
require.NoError(t, err)
token := jwt.New()
token := jwxjwt.New()
for name, value := range claims {
require.NoError(t, token.Set(name, value))
}
Expand All @@ -447,7 +447,7 @@ func signTokenWithHeaders(
if mutate != nil {
mutate(headers)
}
signed, err := jwt.Sign(token, jwt.WithKey(algorithm, key, jws.WithProtectedHeaders(headers)))
signed, err := jwxjwt.Sign(token, jwxjwt.WithKey(algorithm, key, jws.WithProtectedHeaders(headers)))
require.NoError(t, err)

return string(signed)
Expand All @@ -456,7 +456,7 @@ func signTokenWithHeaders(
func signTokenWithoutType(t *testing.T, claims map[string]any) string {
t.Helper()

token := jwt.New()
token := jwxjwt.New()
for name, value := range claims {
require.NoError(t, token.Set(name, value))
}
Expand Down Expand Up @@ -492,13 +492,13 @@ func hmacSignedToken(t *testing.T, claims map[string]any) string {
key, err := jwk.Import([]byte("secret"))
require.NoError(t, err)
require.NoError(t, key.Set(jwk.KeyIDKey, testKeyID))
token := jwt.New()
token := jwxjwt.New()
for name, value := range claims {
require.NoError(t, token.Set(name, value))
}
headers := jws.NewHeaders()
require.NoError(t, headers.Set(jws.TypeKey, TokenType))
signed, err := jwt.Sign(token, jwt.WithKey(jwa.HS256(), key, jws.WithProtectedHeaders(headers)))
signed, err := jwxjwt.Sign(token, jwxjwt.WithKey(jwa.HS256(), key, jws.WithProtectedHeaders(headers)))
require.NoError(t, err)

return string(signed)
Expand All @@ -508,12 +508,12 @@ func baseClaims() map[string]any {
now := fixedTime()

return map[string]any{
jwt.IssuerKey: testIssuer,
jwt.SubjectKey: testPrincipalID,
jwt.AudienceKey: []string{testAudience},
jwt.IssuedAtKey: now,
jwt.ExpirationKey: now.Add(time.Hour),
jwt.JwtIDKey: testTokenID,
jwxjwt.IssuerKey: testIssuer,
jwxjwt.SubjectKey: testPrincipalID,
jwxjwt.AudienceKey: []string{testAudience},
jwxjwt.IssuedAtKey: now,
jwxjwt.ExpirationKey: now.Add(time.Hour),
jwxjwt.JwtIDKey: testTokenID,
}
}

Expand All @@ -540,21 +540,21 @@ func assertProtectedHeader(t *testing.T, plaintext string, tokenType string, alg
func assertNoAuthorizationClaims(t *testing.T, plaintext string, keySet jwk.Set) {
t.Helper()

token, err := jwt.Parse(
token, err := jwxjwt.Parse(
[]byte(plaintext),
jwt.WithKeySet(keySet),
jwt.WithIssuer(testIssuer),
jwt.WithAudience(testAudience),
jwt.WithClock(jwt.ClockFunc(fixedTime)),
jwxjwt.WithKeySet(keySet),
jwxjwt.WithIssuer(testIssuer),
jwxjwt.WithAudience(testAudience),
jwxjwt.WithClock(jwxjwt.ClockFunc(fixedTime)),
)
require.NoError(t, err)
assert.ElementsMatch(t, []string{
jwt.AudienceKey,
jwt.ExpirationKey,
jwt.IssuedAtKey,
jwt.IssuerKey,
jwt.JwtIDKey,
jwt.SubjectKey,
jwxjwt.AudienceKey,
jwxjwt.ExpirationKey,
jwxjwt.IssuedAtKey,
jwxjwt.IssuerKey,
jwxjwt.JwtIDKey,
jwxjwt.SubjectKey,
}, token.Keys())
assert.False(t, token.Has("roles"))
assert.False(t, token.Has("permissions"))
Expand Down
20 changes: 10 additions & 10 deletions accessjwt/keys.go → access/jwt/keys.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package accessjwt
package jwt

import (
"errors"
Expand All @@ -18,7 +18,7 @@ func signatureAlgorithms(names []string) (map[string]jwa.SignatureAlgorithm, err
for i, name := range names {
algorithm, err := signatureAlgorithm(name)
if err != nil {
return nil, fmt.Errorf("accessjwt: allowed algorithm %d: %w", i, err)
return nil, fmt.Errorf("jwt: allowed algorithm %d: %w", i, err)
}
if _, ok := algorithmMap[algorithm.String()]; ok {
continue
Expand Down Expand Up @@ -54,22 +54,22 @@ func validateKeySet(set jwk.Set, allowed map[string]jwa.SignatureAlgorithm) erro
for index := range set.Len() {
key, ok := set.Key(index)
if !ok || key == nil {
return fmt.Errorf("accessjwt: key set entry %d is required", index)
return fmt.Errorf("jwt: key set entry %d is required", index)
}
if err := validateKeyID(fmt.Sprintf("key set entry %d", index), key); err != nil {
return err
}
algorithm, ok := key.Algorithm()
if !ok {
return fmt.Errorf("accessjwt: key set entry %d algorithm is required", index)
return fmt.Errorf("jwt: key set entry %d algorithm is required", index)
}
signatureAlgorithm, ok := algorithm.(jwa.SignatureAlgorithm)
if !ok {
return fmt.Errorf("accessjwt: key set entry %d algorithm must be a signature algorithm", index)
return fmt.Errorf("jwt: key set entry %d algorithm must be a signature algorithm", index)
}
if _, ok := allowed[signatureAlgorithm.String()]; !ok {
return fmt.Errorf(
"accessjwt: key set entry %d algorithm %q is not allowed",
"jwt: key set entry %d algorithm %q is not allowed",
index,
signatureAlgorithm.String(),
)
Expand All @@ -82,10 +82,10 @@ func validateKeySet(set jwk.Set, allowed map[string]jwa.SignatureAlgorithm) erro
func validateKeyID(name string, key jwk.Key) error {
keyID, ok := key.KeyID()
if !ok || strings.TrimSpace(keyID) == "" {
return fmt.Errorf("accessjwt: %s kid is required", name)
return fmt.Errorf("jwt: %s kid is required", name)
}
if strings.TrimSpace(keyID) != keyID {
return fmt.Errorf("accessjwt: %s kid must not contain surrounding whitespace", name)
return fmt.Errorf("jwt: %s kid must not contain surrounding whitespace", name)
}

return nil
Expand All @@ -99,11 +99,11 @@ func validateOptionalKeyAlgorithm(name string, key jwk.Key, expected jwa.Signatu

signatureAlgorithm, ok := keyAlgorithm.(jwa.SignatureAlgorithm)
if !ok {
return fmt.Errorf("accessjwt: %s algorithm must be a signature algorithm", name)
return fmt.Errorf("jwt: %s algorithm must be a signature algorithm", name)
}
if signatureAlgorithm.String() != expected.String() {
return fmt.Errorf(
"accessjwt: %s algorithm %q does not match %q",
"jwt: %s algorithm %q does not match %q",
name,
signatureAlgorithm.String(),
expected.String(),
Expand Down
2 changes: 1 addition & 1 deletion accessjwt/types.go → access/jwt/types.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package accessjwt
package jwt

import (
"time"
Expand Down
32 changes: 16 additions & 16 deletions accessjwt/verifier.go → access/jwt/verifier.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package accessjwt
package jwt

import (
"context"
Expand All @@ -9,7 +9,7 @@ import (
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jws"
"github.com/lestrrat-go/jwx/v3/jwt"
jwxjwt "github.com/lestrrat-go/jwx/v3/jwt"

"github.com/meigma/authkit"
)
Expand All @@ -33,10 +33,10 @@ func NewVerifier(opts VerifierOptions) (*Verifier, error) {
return nil, err
}
if opts.KeySet == nil || opts.KeySet.Len() == 0 {
return nil, errors.New("accessjwt: key set is required")
return nil, errors.New("jwt: key set is required")
}
if opts.AcceptableSkew < 0 {
return nil, errors.New("accessjwt: acceptable skew must not be negative")
return nil, errors.New("jwt: acceptable skew must not be negative")
}

algorithmMap, err := signatureAlgorithms(opts.AllowedAlgorithms)
Expand All @@ -49,7 +49,7 @@ func NewVerifier(opts VerifierOptions) (*Verifier, error) {

keySet, err := opts.KeySet.Clone()
if err != nil {
return nil, fmt.Errorf("accessjwt: clone key set: %w", err)
return nil, fmt.Errorf("jwt: clone key set: %w", err)
}

clock := opts.Clock
Expand Down Expand Up @@ -80,17 +80,17 @@ func (v *Verifier) VerifyToken(ctx context.Context, plaintext string) (VerifiedT
return VerifiedToken{}, unauthenticated(err.Error())
}

token, err := jwt.Parse(
token, err := jwxjwt.Parse(
[]byte(plaintext),
jwt.WithKeySet(v.keySet),
jwt.WithIssuer(v.issuer),
jwt.WithAudience(v.audience),
jwt.WithRequiredClaim(jwt.SubjectKey),
jwt.WithRequiredClaim(jwt.JwtIDKey),
jwt.WithRequiredClaim(jwt.IssuedAtKey),
jwt.WithRequiredClaim(jwt.ExpirationKey),
jwt.WithClock(jwt.ClockFunc(v.clock)),
jwt.WithAcceptableSkew(v.acceptableSkew),
jwxjwt.WithKeySet(v.keySet),
jwxjwt.WithIssuer(v.issuer),
jwxjwt.WithAudience(v.audience),
jwxjwt.WithRequiredClaim(jwxjwt.SubjectKey),
jwxjwt.WithRequiredClaim(jwxjwt.JwtIDKey),
jwxjwt.WithRequiredClaim(jwxjwt.IssuedAtKey),
jwxjwt.WithRequiredClaim(jwxjwt.ExpirationKey),
jwxjwt.WithClock(jwxjwt.ClockFunc(v.clock)),
jwxjwt.WithAcceptableSkew(v.acceptableSkew),
)
if err != nil {
return VerifiedToken{}, unauthenticated("JWT verification failed")
Expand Down Expand Up @@ -141,7 +141,7 @@ func (v *Verifier) validateProtectedHeaders(raw []byte) error {
return nil
}

func (v *Verifier) verifiedToken(token jwt.Token) (VerifiedToken, error) {
func (v *Verifier) verifiedToken(token jwxjwt.Token) (VerifiedToken, error) {
principalID, ok := token.Subject()
if !ok || principalID == "" {
return VerifiedToken{}, errors.New("subject claim is required")
Expand Down
Loading