IAM: Add caching to HTTP client #3148

Merged · 12 commits · Jun 4, 2024
3 changes: 2 additions & 1 deletion README.rst
@@ -199,7 +199,7 @@ The following options can be configured on the server:
http.internal.auth.type Whether to enable authentication for /internal endpoints, specify 'token_v2' for bearer token mode or 'token' for legacy bearer token mode.
http.public.address \:8080 Address and port the server will be listening to for public-facing endpoints.
**JSONLD**
jsonld.contexts.localmapping [https://w3id.org/vc/status-list/2021/v1=assets/contexts/w3c-statuslist2021.ldjson,https://w3c-ccg.github.io/lds-jws2020/contexts/lds-jws2020-v1.json=assets/contexts/lds-jws2020-v1.ldjson,https://schema.org=assets/contexts/schema-org-v13.ldjson,https://nuts.nl/credentials/v1=assets/contexts/nuts.ldjson,https://www.w3.org/2018/credentials/v1=assets/contexts/w3c-credentials-v1.ldjson] This setting allows mapping external URLs to local files for e.g. preventing external dependencies. These mappings have precedence over those in remoteallowlist.
jsonld.contexts.localmapping [https://w3c-ccg.github.io/lds-jws2020/contexts/lds-jws2020-v1.json=assets/contexts/lds-jws2020-v1.ldjson,https://schema.org=assets/contexts/schema-org-v13.ldjson,https://nuts.nl/credentials/v1=assets/contexts/nuts.ldjson,https://www.w3.org/2018/credentials/v1=assets/contexts/w3c-credentials-v1.ldjson,https://w3id.org/vc/status-list/2021/v1=assets/contexts/w3c-statuslist2021.ldjson] This setting allows mapping external URLs to local files for e.g. preventing external dependencies. These mappings have precedence over those in remoteallowlist.
jsonld.contexts.remoteallowlist [https://schema.org,https://www.w3.org/2018/credentials/v1,https://w3c-ccg.github.io/lds-jws2020/contexts/lds-jws2020-v1.json,https://w3id.org/vc/status-list/2021/v1] In strict mode, fetching external JSON-LD contexts is not allowed except for context-URLs listed here.
**PKI**
pki.maxupdatefailhours 4 Maximum number of hours that a denylist update can fail
@@ -238,6 +238,7 @@ If your use case does not use these features, you can ignore this table.
auth.accesstokenlifespan 60 defines how long (in seconds) an access token is valid. Uses default in strict mode.
auth.clockskew 5000 allowed JWT Clock skew in milliseconds
auth.contractvalidators [irma,dummy,employeeid] sets the different contract validators to use
auth.http.cache.maxbytes 10485760 HTTP client maximum size of the response cache in bytes. If 0, the HTTP client does not cache responses.
auth.irma.autoupdateschemas true set if you want to automatically update the IRMA schemas every 60 minutes.
auth.irma.schememanager pbdf IRMA schemeManager to use for attributes. Can be either 'pbdf' or 'irma-demo'.
**Events**
15 changes: 10 additions & 5 deletions auth/auth.go
@@ -60,6 +60,7 @@ type Auth struct {
strictMode bool
httpClientTimeout time.Duration
tlsConfig *tls.Config
iamClient *iam.OpenID4VPClient
}

// Name returns the name of the module.
@@ -107,8 +108,7 @@ func (auth *Auth) RelyingParty() oauth.RelyingParty {
}

func (auth *Auth) IAMClient() iam.Client {
keyResolver := resolver.DIDKeyResolver{Resolver: auth.vdrInstance.Resolver()}
return iam.NewClient(auth.vcr.Wallet(), keyResolver, auth.keyStore, auth.strictMode, auth.httpClientTimeout)
return auth.iamClient
}

// Configure the Auth struct by creating a validator and create an Irma server
@@ -146,23 +146,28 @@ func (auth *Auth) Configure(config core.ServerConfig) error {
return err
}

var httpClientTimeout time.Duration
if auth.config.HTTPTimeout >= 0 {
auth.httpClientTimeout = time.Duration(auth.config.HTTPTimeout) * time.Second
httpClientTimeout = time.Duration(auth.config.HTTPTimeout) * time.Second
} else {
// auth.http.timeout got deprecated in favor of httpclient.timeout
auth.httpClientTimeout = config.HTTPClient.Timeout
httpClientTimeout = config.HTTPClient.Timeout
}
// V1 API related stuff
accessTokenLifeSpan := time.Duration(auth.config.AccessTokenLifeSpan) * time.Second
auth.authzServer = oauth.NewAuthorizationServer(auth.vdrInstance.Resolver(), auth.vcr, auth.vcr.Verifier(), auth.serviceResolver,
auth.keyStore, auth.contractNotary, auth.jsonldManager, accessTokenLifeSpan)
auth.relyingParty = oauth.NewRelyingParty(auth.vdrInstance.Resolver(), auth.serviceResolver,
auth.keyStore, auth.vcr.Wallet(), auth.httpClientTimeout, auth.tlsConfig, config.Strictmode)
auth.keyStore, auth.vcr.Wallet(), httpClientTimeout, auth.tlsConfig, config.Strictmode)

if err := auth.authzServer.Configure(auth.config.ClockSkew, config.Strictmode); err != nil {
return err
}

keyResolver := resolver.DIDKeyResolver{Resolver: auth.vdrInstance.Resolver()}
auth.iamClient = iam.NewClient(auth.vcr.Wallet(), keyResolver, auth.keyStore, auth.strictMode,
httpClientTimeout, auth.config.HTTPResponseCacheSize)

return nil
}

4 changes: 3 additions & 1 deletion auth/auth_test.go
@@ -114,11 +114,13 @@ func TestAuth_IAMClient(t *testing.T) {
config := DefaultConfig()
config.ContractValidators = []string{"dummy"}
ctrl := gomock.NewController(t)
pkiMock := pki.NewMockProvider(ctrl) // no calls are expected
pkiMock := pki.NewMockProvider(ctrl)
pkiMock.EXPECT().CreateTLSConfig(gomock.Any()) // for v5 HTTP client
vdrInstance := vdr.NewMockVDR(ctrl)
vdrInstance.EXPECT().Resolver().AnyTimes()

i := NewAuthInstance(config, vdrInstance, vcr.NewTestVCRInstance(t), crypto.NewMemoryCryptoInstance(), nil, nil, pkiMock)
require.NoError(t, i.Configure(core.TestServerConfig()))

assert.NotNil(t, i.IAMClient())
})
214 changes: 214 additions & 0 deletions auth/client/iam/caching.go
@@ -0,0 +1,214 @@
/*
* Copyright (C) 2024 Nuts community
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/

package iam

import (
"bytes"
"fmt"
"github.com/nuts-foundation/nuts-node/auth/log"
"github.com/nuts-foundation/nuts-node/core"
"github.com/pquerna/cachecontrol"
"io"
"net/http"
"net/url"
"sync"
"time"
)

// maxCacheTime is the maximum time responses are cached.
// Even if the server responds with a longer cache time, responses are never cached longer than maxCacheTime.
const maxCacheTime = time.Hour

// CachingHTTPRequestDoer is a cache for HTTP responses for DID/OAuth2/OpenID clients.
// It only caches GET requests (since generally only metadata is cacheable), and only if the response is cacheable.
// It only works on expiration time and does not respect ETag headers.
// When maxBytes is reached, the entries that expire first are removed to make room for new entries (since those would be the first ones to be pruned anyway).
type CachingHTTPRequestDoer struct {
maxBytes int
requestDoer core.HTTPRequestDoer

// currentSizeBytes is the current size of the cache in bytes.
// It's used to make room for new entries when the cache is full.
currentSizeBytes int
// head is the first entry of a linked list of cache entries, ordered by expiration time.
// The first entry is the one that will expire first, which optimizes the removal of expired entries.
// When an entry is inserted in the cache, it's inserted in the right place in the linked list (ordered by expiry).
head *cacheEntry
// entriesByURL is a map of cache entries, indexed by the URL of the request.
// This optimizes the lookup of cache entries by URL.
entriesByURL map[string][]*cacheEntry
mux sync.RWMutex
}

type cacheEntry struct {
responseData []byte
requestURL *url.URL
requestMethod string
requestRawQuery string
expirationTime time.Time
next *cacheEntry
responseStatus int
responseHeaders http.Header
}

func (h *CachingHTTPRequestDoer) Do(httpRequest *http.Request) (*http.Response, error) {
if httpRequest.Method == http.MethodGet {
if response := h.cachedEntry(httpRequest); response != nil {
return response, nil
}
}
httpResponse, err := h.requestDoer.Do(httpRequest)
if err != nil {
return nil, err
}
err = h.cacheResponse(httpRequest, httpResponse)
if err != nil {
return nil, err
}
return httpResponse, nil
}

// cacheResponse caches the response if it's cacheable.
func (h *CachingHTTPRequestDoer) cacheResponse(httpRequest *http.Request, httpResponse *http.Response) error {
if httpRequest.Method != http.MethodGet {
return nil
}
reasons, expirationTime, err := cachecontrol.CachableResponse(httpRequest, httpResponse, cachecontrol.Options{PrivateCache: false})
if err != nil {
log.Logger().WithError(err).Infof("error while checking cacheability of response (url=%s), not caching", httpRequest.URL.String())
return nil
}
// We don't want to cache responses for too long, as that increases the risk of staleness,
// and could cause very long-lived entries to never be pruned.
maxExpirationTime := time.Now().Add(maxCacheTime)
if expirationTime.After(maxExpirationTime) {
expirationTime = maxExpirationTime
}
if len(reasons) > 0 || expirationTime.IsZero() {
log.Logger().Debugf("response (url=%s) is not cacheable: %v", httpRequest.URL.String(), reasons)
return nil
}
responseBytes, err := io.ReadAll(httpResponse.Body)
if err != nil {
return fmt.Errorf("error while reading response body for caching: %w", err)
}
h.mux.Lock()
defer h.mux.Unlock()
if len(responseBytes) <= h.maxBytes { // sanity check
h.insert(&cacheEntry{
responseData: responseBytes,
requestMethod: httpRequest.Method,
requestURL: httpRequest.URL,
requestRawQuery: httpRequest.URL.RawQuery,
responseStatus: httpResponse.StatusCode,
responseHeaders: httpResponse.Header,
expirationTime: expirationTime,
})
}
httpResponse.Body = io.NopCloser(bytes.NewReader(responseBytes))
return nil
}

// cachedEntry returns a cached response if it exists.
func (h *CachingHTTPRequestDoer) cachedEntry(httpRequest *http.Request) *http.Response {
h.mux.Lock()
defer h.mux.Unlock()
h.removeExpiredEntries()
// Find cached response
entries := h.entriesByURL[httpRequest.URL.String()]
for _, entry := range entries {
if entry.requestMethod == httpRequest.Method && entry.requestRawQuery == httpRequest.URL.RawQuery {
return &http.Response{
StatusCode: entry.responseStatus,
Header: entry.responseHeaders,
Body: io.NopCloser(bytes.NewReader(entry.responseData)),
}
}
}
return nil
}

func (h *CachingHTTPRequestDoer) removeExpiredEntries() {
var current = h.head
for current != nil {
if current.expirationTime.Before(time.Now()) {
current = h.pop()
} else {
break
}
}
}

// insert adds a new entry to the cache.
func (h *CachingHTTPRequestDoer) insert(entry *cacheEntry) {
// Evict the soonest-expiring entries until there's room for the new entry.
// The h.head != nil guard prevents an endless loop when the entry is as large as the cache itself.
for h.head != nil && h.currentSizeBytes+len(entry.responseData) >= h.maxBytes {
_ = h.pop()
}
if h.head == nil {
// First entry
h.head = entry
} else if entry.expirationTime.Before(h.head.expirationTime) {
// New entry expires before the current head, insert it at the front
entry.next = h.head
h.head = entry
} else {
// Insert in the linked list, ordered by expiration time
var current = h.head
for current.next != nil && current.next.expirationTime.Before(entry.expirationTime) {
current = current.next
}
entry.next = current.next
current.next = entry
}
// Insert in the URL map for quick lookup
h.entriesByURL[entry.requestURL.String()] = append(h.entriesByURL[entry.requestURL.String()], entry)

h.currentSizeBytes += len(entry.responseData)
}

// pop removes the first (soonest-expiring) entry from the linked list and returns the new head.
func (h *CachingHTTPRequestDoer) pop() *cacheEntry {
if h.head == nil {
return nil
}
requestURL := h.head.requestURL.String()
entries := h.entriesByURL[requestURL]
for i, entry := range entries {
if entry == h.head {
h.entriesByURL[requestURL] = append(entries[:i], entries[i+1:]...)
if len(h.entriesByURL[requestURL]) == 0 {
delete(h.entriesByURL, requestURL)
}
break
}
}
h.currentSizeBytes -= len(h.head.responseData)
h.head = h.head.next
return h.head
}

// cachingHTTPClient creates a CachingHTTPRequestDoer that wraps the given requestDoer, caching at most responsesCacheSize bytes of responses.
func cachingHTTPClient(requestDoer core.HTTPRequestDoer, responsesCacheSize int) *CachingHTTPRequestDoer {
return &CachingHTTPRequestDoer{
maxBytes: responsesCacheSize,
requestDoer: requestDoer,
entriesByURL: map[string][]*cacheEntry{},
mux: sync.RWMutex{},
}
}
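
For reviewers who want to see the cache in action, below is a minimal test sketch. It is not part of this PR: the httptest server, handler, and assertions are illustrative assumptions written against the code in this diff, and it assumes it lives in the iam package (so it can call the unexported cachingHTTPClient) and that *http.Client satisfies core.HTTPRequestDoer. The handler marks its response as cacheable, so the second identical GET should be answered from the cache and the server should be hit only once.

package iam

import (
	"net/http"
	"net/http/httptest"
	"sync/atomic"
	"testing"

	"github.com/stretchr/testify/assert"
)

// TestCachingHTTPRequestDoer_sketch is a hypothetical test, not part of this PR.
func TestCachingHTTPRequestDoer_sketch(t *testing.T) {
	var hits atomic.Int32
	// Test server that marks its response as cacheable for 5 minutes.
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		hits.Add(1)
		w.Header().Set("Cache-Control", "public, max-age=300")
		_, _ = w.Write([]byte(`{"status":"ok"}`))
	}))
	defer server.Close()

	// Wrap a plain *http.Client (which implements Do) with a 1 MB response cache.
	client := cachingHTTPClient(&http.Client{}, 1024*1024)

	for i := 0; i < 2; i++ {
		req, err := http.NewRequest(http.MethodGet, server.URL, nil)
		assert.NoError(t, err)
		resp, err := client.Do(req)
		assert.NoError(t, err)
		assert.Equal(t, http.StatusOK, resp.StatusCode)
	}
	// Only the first request should have reached the server; the second one is served from the cache.
	assert.EqualValues(t, 1, hits.Load())
}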