Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert lib/auth to use slog #50577

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
2ca18a9
Convert auth_with_roles to use slog
rosstimothy Dec 21, 2024
6f04a56
allow benchmark tests to log
rosstimothy Dec 21, 2024
2188b73
remove last use of logrus in github.go
rosstimothy Dec 21, 2024
e8ebda7
remove last use of logrus in rotate.go
rosstimothy Dec 21, 2024
3c90823
remove auth package logrus logger
rosstimothy Dec 21, 2024
475ea24
convert auth.go to exclusively use slog
rosstimothy Dec 21, 2024
a507ec8
convert db.go to exclusively use slog
rosstimothy Dec 21, 2024
85342cf
convert methods.go to exclusively use slog
rosstimothy Dec 26, 2024
af955cf
convert password.go to exclusively use slog
rosstimothy Dec 26, 2024
1f8cf97
convert join_ec2.go to exclusively use slog
rosstimothy Dec 26, 2024
8963371
convert join_iam.go to exclusively use slog
rosstimothy Dec 26, 2024
7cdf744
convert kube.go to exclusively use slog
rosstimothy Dec 26, 2024
63a2f3c
convert oidc.go to exclusively use slog
rosstimothy Dec 26, 2024
7d03e16
convert saml.go to exclusively use slog
rosstimothy Dec 26, 2024
75703e9
convert sso_diag_context.go to exclusively use slog
rosstimothy Dec 26, 2024
421aa8a
convert transport_credentials.go to exclusively use slog
rosstimothy Dec 26, 2024
1ef4f17
convert trustedcluster.go to exclusively use slog
rosstimothy Dec 26, 2024
aab97b3
convert usertoken.go to exclusively use slog
rosstimothy Dec 26, 2024
3b01850
convert user.go to exclusively use slog
rosstimothy Dec 26, 2024
a690266
remove logrus import
rosstimothy Dec 26, 2024
256dacc
fix: pass in context
rosstimothy Dec 26, 2024
8d00511
fix: remove new use of package logrus logger
rosstimothy Dec 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions lib/auth/accountrecovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ func (a *Server) VerifyAccountRecovery(ctx context.Context, req *proto.VerifyAcc
return nil, trace.AccessDenied(verifyRecoveryBadAuthnErrMsg)
}

if err := a.verifyUserToken(startToken, authclient.UserTokenTypeRecoveryStart); err != nil {
if err := a.verifyUserToken(ctx, startToken, authclient.UserTokenTypeRecoveryStart); err != nil {
return nil, trace.Wrap(err)
}

Expand Down Expand Up @@ -304,7 +304,7 @@ func (a *Server) CompleteAccountRecovery(ctx context.Context, req *proto.Complet
return trace.AccessDenied(completeRecoveryGenericErrMsg)
}

if err := a.verifyUserToken(approvedToken, authclient.UserTokenTypeRecoveryApproved); err != nil {
if err := a.verifyUserToken(ctx, approvedToken, authclient.UserTokenTypeRecoveryApproved); err != nil {
return trace.Wrap(err)
}

Expand Down Expand Up @@ -403,7 +403,7 @@ func (a *Server) CreateAccountRecoveryCodes(ctx context.Context, req *proto.Crea
return nil, trace.AccessDenied("only local users may create recovery codes")
}

if err := a.verifyUserToken(token, authclient.UserTokenTypeRecoveryApproved, authclient.UserTokenTypePrivilege); err != nil {
if err := a.verifyUserToken(ctx, token, authclient.UserTokenTypeRecoveryApproved, authclient.UserTokenTypePrivilege); err != nil {
return nil, trace.Wrap(err)
}

Expand All @@ -428,7 +428,7 @@ func (a *Server) GetAccountRecoveryToken(ctx context.Context, req *proto.GetAcco
return nil, trace.AccessDenied("access denied")
}

if err := a.verifyUserToken(token, authclient.UserTokenTypeRecoveryStart, authclient.UserTokenTypeRecoveryApproved); err != nil {
if err := a.verifyUserToken(ctx, token, authclient.UserTokenTypeRecoveryStart, authclient.UserTokenTypeRecoveryApproved); err != nil {
return nil, trace.Wrap(err)
}

Expand Down
230 changes: 133 additions & 97 deletions lib/auth/auth.go

Large diffs are not rendered by default.

245 changes: 159 additions & 86 deletions lib/auth/auth_with_roles.go

Large diffs are not rendered by default.

21 changes: 0 additions & 21 deletions lib/auth/auth_with_roles_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"crypto/tls"
"crypto/x509/pkix"
"fmt"
"io"
"net/url"
"slices"
"strconv"
Expand All @@ -38,7 +37,6 @@ import (
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/pquerna/otp/totp"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
Expand Down Expand Up @@ -77,7 +75,6 @@ import (
"github.com/gravitational/teleport/lib/session"
"github.com/gravitational/teleport/lib/srv/discovery/common"
"github.com/gravitational/teleport/lib/tlsca"
logutils "github.com/gravitational/teleport/lib/utils/log"
"github.com/gravitational/teleport/lib/utils/pagination"
)

Expand Down Expand Up @@ -1824,12 +1821,6 @@ func BenchmarkListNodes(b *testing.B) {
const nodeCount = 50_000
const roleCount = 32

logger := logrus.StandardLogger()
logger.ReplaceHooks(make(logrus.LevelHooks))
logrus.SetFormatter(logutils.NewTestJSONFormatter())
logger.SetLevel(logrus.DebugLevel)
logger.SetOutput(io.Discard)

ctx := context.Background()
srv := newTestTLSServer(b)

Expand Down Expand Up @@ -6124,12 +6115,6 @@ func BenchmarkListUnifiedResourcesFilter(b *testing.B) {
const nodeCount = 150_000
const roleCount = 32

logger := logrus.StandardLogger()
logger.ReplaceHooks(make(logrus.LevelHooks))
logrus.SetFormatter(logutils.NewTestJSONFormatter())
logger.SetLevel(logrus.PanicLevel)
logger.SetOutput(io.Discard)

ctx := context.Background()
srv := newTestTLSServer(b)

Expand Down Expand Up @@ -6257,12 +6242,6 @@ func BenchmarkListUnifiedResources(b *testing.B) {
const nodeCount = 150_000
const roleCount = 32

logger := logrus.StandardLogger()
logger.ReplaceHooks(make(logrus.LevelHooks))
logrus.SetFormatter(logutils.NewTestJSONFormatter())
logger.SetLevel(logrus.DebugLevel)
logger.SetOutput(io.Discard)

ctx := context.Background()
srv := newTestTLSServer(b)

Expand Down
11 changes: 7 additions & 4 deletions lib/auth/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ func (a *Server) SignDatabaseCSR(ctx context.Context, req *proto.DatabaseCSRRequ
"this Teleport cluster is not licensed for database access, please contact the cluster administrator")
}

log.Debugf("Signing database CSR for cluster %v.", req.ClusterName)
a.logger.DebugContext(ctx, "Signing database CSR for cluster", "cluster", req.ClusterName)

clusterName, err := a.GetClusterName()
if err != nil {
Expand Down Expand Up @@ -348,7 +348,7 @@ func (a *Server) GenerateSnowflakeJWT(ctx context.Context, req *proto.SnowflakeJ
return nil, trace.Wrap(err)
}

subject, issuer := getSnowflakeJWTParams(req.AccountName, req.UserName, pubKey)
subject, issuer := getSnowflakeJWTParams(ctx, req.AccountName, req.UserName, pubKey)

_, signer, err := a.GetKeyStore().GetTLSCertAndSigner(ctx, ca)
if err != nil {
Expand All @@ -371,7 +371,7 @@ func (a *Server) GenerateSnowflakeJWT(ctx context.Context, req *proto.SnowflakeJ
}, nil
}

func getSnowflakeJWTParams(accountName, userName string, publicKey []byte) (string, string) {
func getSnowflakeJWTParams(ctx context.Context, accountName, userName string, publicKey []byte) (string, string) {
// Use only the first part of the account name to generate JWT
// Based on:
// https://github.com/snowflakedb/snowflake-connector-python/blob/f2f7e6f35a162484328399c8a50a5015825a5573/src/snowflake/connector/auth_keypair.py#L83
Expand All @@ -383,7 +383,10 @@ func getSnowflakeJWTParams(accountName, userName string, publicKey []byte) (stri
accnToken, _, _ := strings.Cut(accountName, accNameSeparator)
accnTokenCap := strings.ToUpper(accnToken)
userNameCap := strings.ToUpper(userName)
log.Debugf("Signing database JWT token for %s %s", accnTokenCap, userNameCap)
logger.DebugContext(ctx, "Signing database JWT token",
"account_name", accnTokenCap,
"user_name", userNameCap,
)

subject := fmt.Sprintf("%s.%s", accnTokenCap, userNameCap)

Expand Down
2 changes: 1 addition & 1 deletion lib/auth/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func Test_getSnowflakeJWTParams(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
subject, issuer := getSnowflakeJWTParams(tt.args.accountName, tt.args.userName, tt.args.publicKey)
subject, issuer := getSnowflakeJWTParams(context.Background(), tt.args.accountName, tt.args.userName, tt.args.publicKey)

require.Equal(t, tt.wantSubject, subject)
require.Equal(t, tt.wantIssuer, issuer)
Expand Down
3 changes: 1 addition & 2 deletions lib/auth/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ import (
"time"

"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
"golang.org/x/oauth2"

"github.com/gravitational/teleport"
Expand Down Expand Up @@ -349,7 +348,7 @@ func orgUsesExternalSSO(ctx context.Context, endpointURL, org string, client htt
if resp != nil {
io.Copy(io.Discard, resp.Body)
if bodyErr := resp.Body.Close(); bodyErr != nil {
logrus.WithError(bodyErr).Error("Error closing response body.")
logger.ErrorContext(ctx, "Error closing response body", "error", bodyErr)
}
}
// Handle makeHTTPGetReq errors.
Expand Down
2 changes: 0 additions & 2 deletions lib/auth/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ import (
"github.com/coreos/go-semver/semver"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace"
oteltrace "go.opentelemetry.io/otel/trace"
Expand Down Expand Up @@ -75,7 +74,6 @@ import (
)

var (
log = logrus.WithField(teleport.ComponentKey, teleport.ComponentAuth)
logger = logutils.NewPackageLogger(teleport.ComponentKey, teleport.ComponentAuth)
)

Expand Down
10 changes: 7 additions & 3 deletions lib/auth/join_ec2.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,12 @@ func (a *Server) tryToDetectIdentityReuse(ctx context.Context, req *types.Regist
return trace.Wrap(err)
}
if instanceExists {
log.Warnf("Server with ID %q and role %q is attempting to join the cluster with a Simplified Node Joining request, but"+
" a server with this ID is already present in the cluster.", req.HostID, req.Role)
const msg = "Server is attempting to join the cluster with a Simplified Node Joining request, but" +
" a server with this ID is already present in the cluster"
a.logger.WarnContext(ctx, msg,
"host_id", req.HostID,
"role", req.Role,
)
return trace.AccessDenied("server with host ID %q and role %q already exists", req.HostID, req.Role)
}
return nil
Expand All @@ -363,7 +367,7 @@ func (a *Server) checkEC2JoinRequest(ctx context.Context, req *types.RegisterUsi
return trace.Wrap(err)
}

log.Debugf("Received Simplified Node Joining request for host %q", req.HostID)
a.logger.DebugContext(ctx, "Received Simplified Node Joining request", "host_id", req.HostID)

if len(req.EC2IdentityDocument) == 0 {
return trace.AccessDenied("this token is only valid for the EC2 join " +
Expand Down
2 changes: 1 addition & 1 deletion lib/auth/join_iam.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func validateSTSIdentityRequest(req *http.Request, challenge string, cfg *iamReg
// invalid sts:GetCallerIdentity request, it's either going to be caused
// by a node in a unknown region or an attacker.
if err != nil {
log.WithError(err).Warn("Detected an invalid sts:GetCallerIdentity used by a client attempting to use the IAM join method.")
logger.WarnContext(req.Context(), "Detected an invalid sts:GetCallerIdentity used by a client attempting to use the IAM join method", "error", err)
}
}()

Expand Down
2 changes: 1 addition & 1 deletion lib/auth/kube.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func (a *Server) ProcessKubeCSR(req authclient.KubeCSR) (*authclient.KubeCSRResp

// Certificate for remote cluster is a user certificate
// with special provisions.
log.Debugf("Generating certificate to access remote Kubernetes clusters.")
a.logger.DebugContext(ctx, "Generating certificate to access remote Kubernetes clusters")

hostCA, err := a.GetCertAuthority(ctx, types.CertAuthID{
Type: types.HostCA,
Expand Down
41 changes: 23 additions & 18 deletions lib/auth/methods.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"bytes"
"context"
"errors"
"log/slog"
"net"
"time"

Expand Down Expand Up @@ -71,16 +70,19 @@ func (a *Server) authenticateUserLogin(ctx context.Context, req authclient.Authe
clientMetadata: req.ClientMetadata,
authErr: err,
}); err != nil {
log.WithError(err).Warn("Failed to emit login event")
a.logger.WarnContext(ctx, "Failed to emit login event", "error", err)
}
return nil, nil, trace.Wrap(err)
}

switch {
case username != "" && actualUsername != "" && username != actualUsername:
log.Warnf("Authenticate user mismatch (%q vs %q). Using request user (%q)", username, actualUsername, username)
a.logger.WarnContext(ctx, "Authenticate user mismatch, using request user",
"username", username,
"request_user", actualUsername,
)
case username == "" && actualUsername != "":
log.Debugf("User %q authenticated via passwordless", actualUsername)
a.logger.DebugContext(ctx, "User authenticated via passwordless", "username", actualUsername)
username = actualUsername
}

Expand Down Expand Up @@ -123,7 +125,7 @@ func (a *Server) authenticateUserLogin(ctx context.Context, req authclient.Authe
checker: checker,
authErr: err,
}); err != nil {
log.WithError(err).Warn("Failed to emit login event")
a.logger.WarnContext(ctx, "Failed to emit login event", "error", err)
}
return nil, nil, trace.Wrap(err)
}
Expand All @@ -135,7 +137,7 @@ func (a *Server) authenticateUserLogin(ctx context.Context, req authclient.Authe
mfaDevice: mfaDev,
checker: checker,
}); err != nil {
log.WithError(err).Warn("Failed to emit login event")
a.logger.WarnContext(ctx, "Failed to emit login event", "error", err)
}

return userState, checker, trace.Wrap(err)
Expand Down Expand Up @@ -303,7 +305,7 @@ func (a *Server) authenticateUserInternal(
if req.HeadlessAuthenticationID != "" {
mfaDev, err = a.authenticateHeadless(ctx, req)
if err != nil {
slog.DebugContext(ctx, "Headless authenticate failed while waiting for approval",
a.logger.DebugContext(ctx, "Headless authenticate failed while waiting for approval",
"user", user,
"error", err,
)
Expand All @@ -330,7 +332,7 @@ func (a *Server) authenticateUserInternal(
case err != nil:
return nil, "", trace.Wrap(err)
case u.GetUserType() != types.UserTypeLocal:
slog.WarnContext(ctx, "Non-local user attempted local authentication",
a.logger.WarnContext(ctx, "Non-local user attempted local authentication",
"user", user,
"user_type", u.GetUserType(),
)
Expand Down Expand Up @@ -381,7 +383,7 @@ func (a *Server) authenticateUserInternal(
})
switch {
case err != nil:
slog.DebugContext(ctx, "User failed to authenticate.",
a.logger.DebugContext(ctx, "User failed to authenticate",
"user", user,
"error", err,
)
Expand All @@ -391,7 +393,7 @@ func (a *Server) authenticateUserInternal(

return nil, "", trace.Wrap(authErr)
case mfaDev == nil:
slog.DebugContext(ctx, "MFA authentication returned nil device.",
a.logger.DebugContext(ctx, "MFA authentication returned nil device",
"webauthn", req.Webauthn != nil,
"totp", req.OTP != nil,
"headless", req.HeadlessAuthenticationID != "",
Expand Down Expand Up @@ -420,7 +422,7 @@ func (a *Server) authenticateUserInternal(
// Some form of MFA is required but none provided. Either client is
// buggy (didn't send MFA response) or someone is trying to bypass
// MFA.
slog.WarnContext(ctx, "MFA bypass attempt, access denied.", "user", user)
a.logger.WarnContext(ctx, "MFA bypass attempt, access denied", "user", user)
return nil, "", trace.AccessDenied("missing second factor")
case authPreference.IsSecondFactorEnabled():
// 2FA is optional. Make sure that a user does not have MFA devices
Expand All @@ -430,7 +432,7 @@ func (a *Server) authenticateUserInternal(
return nil, "", trace.Wrap(err)
}
if len(devs) != 0 {
slog.WarnContext(ctx, "MFA bypass attempt, access denied.", "user", user)
a.logger.WarnContext(ctx, "MFA bypass attempt, access denied", "user", user)
return nil, "", trace.AccessDenied("missing second factor authentication")
}
default:
Expand All @@ -444,7 +446,7 @@ func (a *Server) authenticateUserInternal(
}
// provide obscure message on purpose, while logging the real
// error server side
slog.DebugContext(ctx, "User failed to authenticate.",
a.logger.DebugContext(ctx, "User failed to authenticate",
"user", user,
"error", err,
)
Expand All @@ -467,15 +469,18 @@ func (a *Server) authenticatePasswordless(ctx context.Context, req authclient.Au
case errors.Is(err, types.ErrPassswordlessLoginBySSOUser):
return nil, "", trace.Wrap(err)
case err != nil:
log.Debugf("Passwordless authentication failed: %v", err)
a.logger.DebugContext(ctx, "Passwordless authentication failed", "error", err)
return nil, "", trace.Wrap(authenticateWebauthnError)
}

// A distinction between passwordless and "plain" MFA is that we can't
// acquire the user lock beforehand (or at all on failures!)
// We do grab it here so successful logins go through the regular process.
if err := a.WithUserLock(ctx, mfaData.User, func() error { return nil }); err != nil {
log.Debugf("WithUserLock for user %q failed during passwordless authentication: %v", mfaData.User, err)
a.logger.DebugContext(ctx, "WithUserLock failed during passwordless authentication",
"user", mfaData.User,
"error", err,
)
return nil, mfaData.User, trace.Wrap(authenticateWebauthnError)
}

Expand All @@ -487,7 +492,7 @@ func (a *Server) authenticateHeadless(ctx context.Context, req authclient.Authen
defer func() {
if err != nil {
if err := a.DeleteHeadlessAuthentication(a.CloseContext(), req.Username, req.HeadlessAuthenticationID); err != nil && !trace.IsNotFound(err) {
log.Debugf("Failed to delete headless authentication: %v", err)
a.logger.DebugContext(ctx, "Failed to delete headless authentication", "error", err)
}
}
}()
Expand Down Expand Up @@ -773,7 +778,7 @@ func (a *Server) emitNoLocalAuthEvent(username string) {
Error: noLocalAuth,
},
}); err != nil {
log.WithError(err).Warn("Failed to emit no local auth event.")
a.logger.WarnContext(a.closeCtx, "Failed to emit no local auth event", "error", err)
}
}

Expand All @@ -794,7 +799,7 @@ func getErrorByTraceField(err error) error {
ok := errors.As(err, &traceErr)
switch {
case !ok:
log.WithError(err).Warn("Unexpected error type, wanted TraceError")
logger.WarnContext(context.Background(), "Unexpected error type, wanted TraceError", "error", err)
return trace.AccessDenied("an error has occurred")
case traceErr.GetFields()[ErrFieldKeyUserMaxedAttempts] != nil:
return trace.AccessDenied(MaxFailedAttemptsErrMsg)
Expand Down
Loading
Loading