Skip to content

Commit

Permalink
add validator for x509 credential attributes (#3617)
Browse files Browse the repository at this point in the history
  • Loading branch information
woutslakhorst authored Dec 20, 2024
1 parent e1f5c2b commit 3e179ce
Show file tree
Hide file tree
Showing 14 changed files with 326 additions and 22 deletions.
2 changes: 1 addition & 1 deletion auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func (auth *Auth) Configure(config core.ServerConfig) error {
auth.authzServer = oauth.NewAuthorizationServer(auth.vdrInstance.Resolver(), auth.vcr, auth.vcr.Verifier(), auth.serviceResolver,
auth.keyStore, auth.contractNotary, auth.jsonldManager, accessTokenLifeSpan)
auth.relyingParty = oauth.NewRelyingParty(auth.vdrInstance.Resolver(), auth.serviceResolver,
auth.keyStore, auth.vcr.Wallet(), auth.httpClientTimeout, auth.tlsConfig, config.Strictmode)
auth.keyStore, auth.vcr.Wallet(), auth.httpClientTimeout, auth.tlsConfig, config.Strictmode, auth.pkiProvider)

if err := auth.authzServer.Configure(auth.config.ClockSkew, config.Strictmode); err != nil {
return err
Expand Down
7 changes: 5 additions & 2 deletions auth/services/oauth/relying_party.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"context"
"crypto/tls"
"fmt"
"github.com/nuts-foundation/nuts-node/pki"
"net/url"
"strings"
"time"
Expand Down Expand Up @@ -50,12 +51,13 @@ type relyingParty struct {
httpClientTimeout time.Duration
httpClientTLS *tls.Config
wallet holder.Wallet
pkiValidator pki.Validator
}

// NewRelyingParty returns an implementation of RelyingParty
func NewRelyingParty(
didResolver resolver.DIDResolver, serviceResolver didman.CompoundServiceResolver, privateKeyStore nutsCrypto.KeyStore,
wallet holder.Wallet, httpClientTimeout time.Duration, httpClientTLS *tls.Config, strictMode bool) RelyingParty {
wallet holder.Wallet, httpClientTimeout time.Duration, httpClientTLS *tls.Config, strictMode bool, pkiValidator pki.Validator) RelyingParty {
return &relyingParty{
keyResolver: resolver.DIDKeyResolver{Resolver: didResolver},
serviceResolver: serviceResolver,
Expand All @@ -64,6 +66,7 @@ func NewRelyingParty(
httpClientTLS: httpClientTLS,
strictMode: strictMode,
wallet: wallet,
pkiValidator: pkiValidator,
}
}

Expand All @@ -81,7 +84,7 @@ func (s *relyingParty) CreateJwtGrant(ctx context.Context, request services.Crea
}

for _, verifiableCredential := range request.Credentials {
validator := credential.FindValidator(verifiableCredential)
validator := credential.FindValidator(verifiableCredential, s.pkiValidator)
if err := validator.Validate(verifiableCredential); err != nil {
return nil, fmt.Errorf("invalid VerifiableCredential: %w", err)
}
Expand Down
5 changes: 4 additions & 1 deletion vcr/credential/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,22 @@ import (
"github.com/nuts-foundation/go-did/did"
"github.com/nuts-foundation/go-did/vc"
"github.com/nuts-foundation/nuts-node/crypto"
"github.com/nuts-foundation/nuts-node/pki"
"github.com/nuts-foundation/nuts-node/vcr/signature/proof"
)

// FindValidator finds the Validator the provided credential based on its Type
// When no additional type is provided, it returns the default validator
func FindValidator(credential vc.VerifiableCredential) Validator {
func FindValidator(credential vc.VerifiableCredential, pkiValidator pki.Validator) Validator {
if vcTypes := ExtractTypes(credential); len(vcTypes) > 0 {
for _, t := range vcTypes {
switch t {
case NutsOrganizationCredentialType:
return nutsOrganizationCredentialValidator{}
case NutsAuthorizationCredentialType:
return nutsAuthorizationCredentialValidator{}
case X509CredentialType:
return x509CredentialValidator{pkiValidator: pkiValidator}
}
}
}
Expand Down
14 changes: 9 additions & 5 deletions vcr/credential/resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,19 @@ import (

func TestFindValidator(t *testing.T) {
t.Run("an unknown type returns the default validator", func(t *testing.T) {
assert.IsType(t, defaultCredentialValidator{}, FindValidator(vc.VerifiableCredential{}))
assert.IsType(t, defaultCredentialValidator{}, FindValidator(vc.VerifiableCredential{}, nil))
})

t.Run("validator and builder found for NutsOrganizationCredential", func(t *testing.T) {
assert.IsType(t, nutsOrganizationCredentialValidator{}, FindValidator(test.ValidNutsOrganizationCredential(t)))
t.Run("validator found for NutsOrganizationCredential", func(t *testing.T) {
assert.IsType(t, nutsOrganizationCredentialValidator{}, FindValidator(test.ValidNutsOrganizationCredential(t), nil))
})

t.Run("validator and builder found for NutsAuthorizationCredential", func(t *testing.T) {
assert.IsType(t, nutsAuthorizationCredentialValidator{}, FindValidator(test.ValidNutsAuthorizationCredential(t)))
t.Run("validator found for NutsAuthorizationCredential", func(t *testing.T) {
assert.IsType(t, nutsAuthorizationCredentialValidator{}, FindValidator(test.ValidNutsAuthorizationCredential(t), nil))
})

t.Run("validator found for X509Credential", func(t *testing.T) {
assert.IsType(t, x509CredentialValidator{}, FindValidator(test.ValidX509Credential(t), nil))
})
}

Expand Down
2 changes: 2 additions & 0 deletions vcr/credential/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ const (
NutsOrganizationCredentialType = "NutsOrganizationCredential"
// NutsAuthorizationCredentialType is the VC type for a NutsAuthorizationCredential
NutsAuthorizationCredentialType = "NutsAuthorizationCredential"
// X509CredentialType is the VC type for a X509Credential
X509CredentialType = "X509Credential"
// NutsV1Context is the nuts V1 json-ld context
NutsV1Context = "https://nuts.nl/credentials/v1"
)
Expand Down
111 changes: 109 additions & 2 deletions vcr/credential/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,21 @@
package credential

import (
"crypto/x509"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"strings"

"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/nuts-foundation/go-did/did"
"github.com/nuts-foundation/go-did/vc"
"github.com/nuts-foundation/nuts-node/crypto"
"github.com/nuts-foundation/nuts-node/pki"
"github.com/nuts-foundation/nuts-node/vcr/revocation"
"github.com/nuts-foundation/nuts-node/vdr/didx509"
"github.com/nuts-foundation/nuts-node/vdr/resolver"
"net/url"
"strings"
)

// Validator is the interface specific VC verification.
Expand Down Expand Up @@ -253,3 +259,104 @@ func validateNutsCredentialID(credential vc.VerifiableCredential) error {
}
return nil
}

// x509CredentialValidator checks the did:x509 issuer and if the credentialSubject claims match the x509 certificate
type x509CredentialValidator struct {
pkiValidator pki.Validator
}

func (d x509CredentialValidator) Validate(credential vc.VerifiableCredential) error {
didX509Issuer, err := did.ParseDID(credential.Issuer.String())
if err != nil {
return errors.Join(errValidation, err)
}
x509resolver := didx509.NewResolver()
resolveMetadata := resolver.ResolveMetadata{}
if credential.Format() == vc.JWTCredentialProofFormat {
headers, err := crypto.ExtractProtectedHeaders(credential.Raw())
if err != nil {
// theoretically impossible, since the credential is already parsed
return fmt.Errorf("%w: invalid JWT headers: %w", errValidation, err)
}
resolveMetadata.JwtProtectedHeaders = headers
} else {
// unsupported format
return fmt.Errorf("%w: unsupported credential format: %s", errValidation, credential.Format())
}
_, _, err = x509resolver.Resolve(*didX509Issuer, &resolveMetadata)
if err != nil {
return fmt.Errorf("%w: invalid issuer: %w", errValidation, err)
}

if err = validatePolicyAssertions(*didX509Issuer, credential); err != nil {
return fmt.Errorf("%w: %w", errValidation, err)
}

chainHeader, _ := resolveMetadata.GetProtectedHeaderChain(jwk.X509CertChainKey) // already succeeded for resolve
// convert cert.Chain to []*x509.Certificate
chain := make([]*x509.Certificate, chainHeader.Len())
for i := 0; i < chainHeader.Len(); i++ {
base64Cert, _ := chainHeader.Get(i)
// these two operations can't fail since the resolve earlier already succeeded
der, _ := base64.StdEncoding.DecodeString(string(base64Cert))
cert, _ := x509.ParseCertificate(der)
chain[i] = cert
}
if err = d.pkiValidator.CheckCRL(chain); err != nil {
return fmt.Errorf("%w: %w", errValidation, err)
}

return (defaultCredentialValidator{}).Validate(credential)
}

// validatePolicyAssertions checks if the credentialSubject claims match the did issuer policies
func validatePolicyAssertions(issuer did.DID, credential vc.VerifiableCredential) error {
// get base form of all credentialSubject
var target = make([]map[string]interface{}, 1)
if err := credential.UnmarshalCredentialSubject(&target); err != nil {
return err
}

// we create a map of policyName to policyValue, then we split the policyValue into another map
// no checks required, this has been done by the did:x509 resolver
x509DID, _ := didx509.ParseX509Did(issuer)
policyMap := make(map[string]map[string]string)
for _, policy := range x509DID.Policies {
policySplit := strings.Split(policy.Value, ":")
policyName := string(policy.Name)
policyMap[policyName] = make(map[string]string)
// bounds checked by ParseX509Did
for i := 0; i < len(policySplit); i += 2 {
unscaped, _ := url.PathUnescape(policySplit[i+1])
policyMap[policyName][policySplit[i]] = unscaped
}
}

// we usually don't use multiple credentialSubjects, but for this validation it doesn't matter
for _, credentialSubject := range target {
// remove id from target
delete(credentialSubject, "id")

// for each assertion create a string as "%s:%s" with key/value
// check if the resulting string is present in the policyString
for key, value := range credentialSubject {
split := strings.Split(key, ":")
if len(split) != 2 {
return fmt.Errorf("invalid credentialSubject assertion name '%s'", key)
}
policyValueMap, ok := policyMap[split[0]]
if !ok {
return fmt.Errorf("policy '%s' not found in did:x509 policy", split[0])
}
policyValue, ok := policyValueMap[split[1]]
if !ok {
return fmt.Errorf("assertion '%s' not found in did:x509 policy", key)
}
if value != policyValue {
return fmt.Errorf("invalid assertion value '%s' for '%s' did:x509 policy", value, key)
}
}
}

return nil
}
130 changes: 128 additions & 2 deletions vcr/credential/validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,21 @@
package credential

import (
"github.com/nuts-foundation/nuts-node/pki"
"go.uber.org/mock/gomock"
"testing"
"time"

"github.com/lestrrat-go/jwx/v2/jwt"
ssi "github.com/nuts-foundation/go-did"
"github.com/nuts-foundation/go-did/did"
"github.com/nuts-foundation/go-did/vc"
"github.com/nuts-foundation/nuts-node/jsonld"
"github.com/nuts-foundation/nuts-node/vcr/revocation"
"github.com/nuts-foundation/nuts-node/vcr/test"
"github.com/nuts-foundation/nuts-node/vdr"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"testing"
"time"
)

func init() {
Expand Down Expand Up @@ -493,3 +498,124 @@ func Test_validateCredentialStatus(t *testing.T) {
})
})
}

func TestX509CredentialValidator_Validate(t *testing.T) {
ctx := createTestContext(t)

t.Run("ok", func(t *testing.T) {
x509credential := test.ValidX509Credential(t)
ctx := createTestContext(t)
ctx.pkiValidator.EXPECT().CheckCRL(gomock.Any()).Return(nil)

err := ctx.validator.Validate(x509credential)

assert.NoError(t, err)
})
t.Run("CRL check failed", func(t *testing.T) {
x509credential := test.ValidX509Credential(t)
ctx := createTestContext(t)
ctx.pkiValidator.EXPECT().CheckCRL(gomock.Any()).Return(assert.AnError)

err := ctx.validator.Validate(x509credential)

assert.ErrorIs(t, err, errValidation)
assert.ErrorIs(t, err, assert.AnError)
})
t.Run("invalid did", func(t *testing.T) {
x509credential := vc.VerifiableCredential{Issuer: ssi.MustParseURI("not_a_did")}

err := ctx.validator.Validate(x509credential)

assert.ErrorIs(t, err, errValidation)
assert.ErrorIs(t, err, did.ErrInvalidDID)
})
t.Run("invalid format", func(t *testing.T) {
x509credential := vc.VerifiableCredential{Issuer: ssi.MustParseURI("did:example:123")}

err := ctx.validator.Validate(x509credential)

assert.ErrorIs(t, err, errValidation)
assert.ErrorContains(t, err, "unsupported credential format")
})
t.Run("invalid did:x509", func(t *testing.T) {
x509credential := test.ValidX509Credential(t, func(builder *jwt.Builder) *jwt.Builder {
builder.Issuer("did:example:123")
return builder
})

err := ctx.validator.Validate(x509credential)

assert.ErrorIs(t, err, errValidation)
assert.ErrorContains(t, err, "invalid issuer")
})

t.Run("failed validation", func(t *testing.T) {

testCases := []struct {
name string
claim interface{}
expectedError string
}{
{
name: "invalid assertion value",
claim: map[string]interface{}{
"san:otherName": "A_BIG_STRIN",
},
expectedError: "invalid assertion value 'A_BIG_STRIN' for 'san:otherName' did:x509 policy",
},
{
name: "invalid assertion name",
claim: map[string]interface{}{
"san": "A_BIG_STRING",
},
expectedError: "invalid credentialSubject assertion name 'san'",
},
{
name: "unknown assertion",
claim: map[string]interface{}{
"san:ip": "10.0.0.1",
},
expectedError: "assertion 'san:ip' not found in did:x509 policy",
},
{
name: "unknown policy",
claim: map[string]interface{}{
"stan:ip": "10.0.0.1",
},
expectedError: "policy 'stan' not found in did:x509 policy",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
x509credential := test.ValidX509Credential(t, func(builder *jwt.Builder) *jwt.Builder {
builder.Claim("vc", map[string]interface{}{
"credentialSubject": tc.claim,
})
return builder
})

err := ctx.validator.Validate(x509credential)

assert.ErrorIs(t, err, errValidation)
assert.ErrorContains(t, err, tc.expectedError)
})
}
})
}

type testContext struct {
ctrl *gomock.Controller
validator x509CredentialValidator
pkiValidator *pki.MockValidator
}

func createTestContext(t *testing.T) testContext {
ctrl := gomock.NewController(t)
pkiValidator := pki.NewMockValidator(ctrl)
return testContext{
ctrl: ctrl,
validator: x509CredentialValidator{pkiValidator: pkiValidator},
pkiValidator: pkiValidator,
}
}
3 changes: 2 additions & 1 deletion vcr/issuer/issuer.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ func (i issuer) Issue(ctx context.Context, template vc.VerifiableCredential, opt
}

// Validate the VC using the type-specific validator
validator := credential.FindValidator(*createdVC)
// we don't pass a pki.Validator since we don't issue x509 certs
validator := credential.FindValidator(*createdVC, nil)
if err := validator.Validate(*createdVC); err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 3e179ce

Please sign in to comment.