Skip to content

Commit

Permalink
*: fix security issues
Browse files Browse the repository at this point in the history
Signed-off-by: mornyx <[email protected]>
  • Loading branch information
mornyx committed Sep 18, 2024
1 parent 3d0c3db commit 3305937
Show file tree
Hide file tree
Showing 9 changed files with 205 additions and 10 deletions.
2 changes: 2 additions & 0 deletions cmd/tidb-dashboard/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ func NewCLIConfig() *DashboardCLIConfig {
flag.IntVar(&cfg.CoreConfig.NgmTimeout, "ngm-timeout", cfg.CoreConfig.NgmTimeout, "timeout secs for accessing the ngm API")
flag.BoolVar(&cfg.CoreConfig.EnableKeyVisualizer, "keyviz", true, "enable/disable key visualizer(default: true)")
flag.BoolVar(&cfg.CoreConfig.DisableCustomPromAddr, "disable-custom-prom-addr", false, "do not allow custom prometheus address")
flag.Float64Var(&cfg.CoreConfig.UnauthedAPIQpsLimit, "unauthed-api-qps-limit", cfg.CoreConfig.UnauthedAPIQpsLimit, "unauthed API qps limit")
flag.IntVar(&cfg.CoreConfig.UnauthedAPIBurstLimit, "unauthed-api-burst-limit", cfg.CoreConfig.UnauthedAPIBurstLimit, "unauthed API burst limit")

showVersion := flag.BoolP("version", "v", false, "print version information and exit")

Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ require (
go.uber.org/zap v1.19.0
golang.org/x/oauth2 v0.11.0
golang.org/x/sync v0.3.0
golang.org/x/time v0.6.0
google.golang.org/grpc v1.59.0
google.golang.org/protobuf v1.33.0
gorm.io/datatypes v1.1.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,8 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U=
golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
Expand Down
167 changes: 166 additions & 1 deletion pkg/apiserver/logsearch/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/gin-gonic/gin"
"github.com/pingcap/log"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/fx"
"go.uber.org/zap"

Expand All @@ -19,6 +20,8 @@ import (
"github.com/pingcap/tidb-dashboard/pkg/apiserver/utils"
"github.com/pingcap/tidb-dashboard/pkg/config"
"github.com/pingcap/tidb-dashboard/pkg/dbstore"
"github.com/pingcap/tidb-dashboard/pkg/pd"
"github.com/pingcap/tidb-dashboard/pkg/utils/topology"
"github.com/pingcap/tidb-dashboard/util/rest"
)

Expand All @@ -30,9 +33,11 @@ type Service struct {
logStoreDirectory string
db *dbstore.DB
scheduler *Scheduler
etcdClient *clientv3.Client
pdClient *pd.Client
}

func NewService(lc fx.Lifecycle, config *config.Config, db *dbstore.DB) *Service {
func NewService(lc fx.Lifecycle, config *config.Config, db *dbstore.DB, etcdClient *clientv3.Client, pdClient *pd.Client) *Service {
dir := config.TempDir
if dir == "" {
var err error
Expand All @@ -52,6 +57,8 @@ func NewService(lc fx.Lifecycle, config *config.Config, db *dbstore.DB) *Service
logStoreDirectory: dir,
db: db,
scheduler: nil, // will be filled after scheduler is created
etcdClient: etcdClient,
pdClient: pdClient,
}
scheduler := NewScheduler(service)
service.scheduler = scheduler
Expand Down Expand Up @@ -112,6 +119,10 @@ func (s *Service) CreateTaskGroup(c *gin.Context) {
rest.Error(c, rest.ErrBadRequest.New("Expect at least 1 target"))
return
}
if err := s.verifyTargets(c.Request.Context(), req.Targets); err != nil {
rest.Error(c, err)
return
}
stats := model.NewRequestTargetStatisticsFromArray(&req.Targets)
taskGroup := TaskGroupModel{
SearchRequest: &req.Request,
Expand Down Expand Up @@ -361,3 +372,157 @@ func (s *Service) DownloadLogs(c *gin.Context) {
serveMultipleTaskForDownload(tasks, c)
}
}

func (s *Service) verifyTargets(ctx context.Context, targets []model.RequestTargetNode) error {
kindToTargets := make(map[model.NodeKind][]model.RequestTargetNode)
for _, target := range targets {
kindToTargets[target.Kind] = append(kindToTargets[target.Kind], target)
}
var tikvInfos []topology.StoreInfo
var tiflashInfos []topology.StoreInfo
for kind, targets := range kindToTargets {
switch kind {
case model.NodeKindTiDB:
infos, err := topology.FetchTiDBTopology(ctx, s.etcdClient)
if err != nil {
log.Error("failed to fetch tidb topology", zap.Error(err))
return rest.ErrInternalServerError.NewWithNoMessage()
}
for _, target := range targets {
matched := false
for _, info := range infos {
if info.IP == target.IP && info.Port == uint(target.Port) {
matched = true
break
}
}
if !matched {
return rest.ErrInvalidEndpoint.NewWithNoMessage()
}
}
case model.NodeKindTiKV, model.NodeKindTiFlash:
if len(tikvInfos) == 0 {
var err error
tikvInfos, tiflashInfos, err = topology.FetchStoreTopology(s.pdClient)
if err != nil {
log.Error("failed to fetch tikv/tiflash topology", zap.Error(err))
return rest.ErrInternalServerError.NewWithNoMessage()
}
}
for _, target := range targets {
matched := false
if kind == model.NodeKindTiKV {
for _, info := range tikvInfos {
if info.IP == target.IP && info.Port == uint(target.Port) {
matched = true
break
}
}
} else {
for _, info := range tiflashInfos {
if info.IP == target.IP && info.Port == uint(target.Port) {
matched = true
break
}
}
}
if !matched {
return rest.ErrInvalidEndpoint.NewWithNoMessage()
}
}
case model.NodeKindPD:
infos, err := topology.FetchPDTopology(s.pdClient)
if err != nil {
log.Error("failed to fetch pd topology", zap.Error(err))
return rest.ErrInternalServerError.NewWithNoMessage()
}
for _, target := range targets {
matched := false
for _, info := range infos {
if info.IP == target.IP && info.Port == uint(target.Port) {
matched = true
break
}
}
if !matched {
return rest.ErrInvalidEndpoint.NewWithNoMessage()
}
}
case model.NodeKindTiCDC:
infos, err := topology.FetchTiCDCTopology(ctx, s.etcdClient)
if err != nil {
log.Error("failed to fetch ticdc topology", zap.Error(err))
return rest.ErrInternalServerError.NewWithNoMessage()
}
for _, target := range targets {
matched := false
for _, info := range infos {
if info.IP == target.IP && info.Port == uint(target.Port) {
matched = true
break
}
}
if !matched {
return rest.ErrInvalidEndpoint.NewWithNoMessage()
}
}
case model.NodeKindTiProxy:
infos, err := topology.FetchTiProxyTopology(ctx, s.etcdClient)
if err != nil {
log.Error("failed to fetch tiproxy topology", zap.Error(err))
return rest.ErrInternalServerError.NewWithNoMessage()
}
for _, target := range targets {
matched := false
for _, info := range infos {
if info.IP == target.IP && info.Port == uint(target.Port) {
matched = true
break
}
}
if !matched {
return rest.ErrInvalidEndpoint.NewWithNoMessage()
}
}
case model.NodeKindTSO:
infos, err := topology.FetchTSOTopology(ctx, s.pdClient)
if err != nil {
log.Error("failed to fetch tso topology", zap.Error(err))
return rest.ErrInternalServerError.NewWithNoMessage()
}
for _, target := range targets {
matched := false
for _, info := range infos {
if info.IP == target.IP && info.Port == uint(target.Port) {
matched = true
break
}
}
if !matched {
return rest.ErrInvalidEndpoint.NewWithNoMessage()
}
}
case model.NodeKindScheduling:
infos, err := topology.FetchSchedulingTopology(ctx, s.pdClient)
if err != nil {
log.Error("failed to fetch scheduling topology", zap.Error(err))
return rest.ErrInternalServerError.NewWithNoMessage()
}
for _, target := range targets {
matched := false
for _, info := range infos {
if info.IP == target.IP && info.Port == uint(target.Port) {
matched = true
break
}
}
if !matched {
return rest.ErrInvalidEndpoint.NewWithNoMessage()
}
}
default:
return rest.ErrInvalidEndpoint.NewWithNoMessage()
}
}
return nil
}
3 changes: 1 addition & 2 deletions pkg/apiserver/logsearch/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
"strconv"
"sync"
"time"
"unsafe"

"github.com/pingcap/kvproto/pkg/diagnosticspb"
"github.com/pingcap/log"
Expand Down Expand Up @@ -252,7 +251,7 @@ func (t *Task) searchLog(client diagnosticspb.DiagnosticsClient, targetType diag
}
for _, msg := range res.Messages {
line := logMessageToString(msg)
_, err := bufWriter.Write(*(*[]byte)(unsafe.Pointer(&line))) // #nosec
_, err := bufWriter.WriteString(line)
if err != nil {
t.setError(err)
return
Expand Down
2 changes: 1 addition & 1 deletion pkg/apiserver/statement/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ func (s *Service) createPlanBinding(db *gorm.DB, planDigest string) (err error)
return errors.New("invalid planDigest")
}

query := db.Exec(fmt.Sprintf("CREATE GLOBAL BINDING FROM HISTORY USING PLAN DIGEST '%s'", planDigest))
query := db.Exec("CREATE GLOBAL BINDING FROM HISTORY USING PLAN DIGEST ?", planDigest)
return query.Error
}

Expand Down
22 changes: 20 additions & 2 deletions pkg/apiserver/user/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ import (
"github.com/joomcode/errorx"
"github.com/pingcap/log"
"go.uber.org/zap"
"golang.org/x/time/rate"

"github.com/pingcap/tidb-dashboard/pkg/apiserver/utils"
"github.com/pingcap/tidb-dashboard/pkg/config"
"github.com/pingcap/tidb-dashboard/util/featureflag"
"github.com/pingcap/tidb-dashboard/util/rest"
)
Expand Down Expand Up @@ -244,10 +246,14 @@ func (s *AuthService) authForm(f AuthenticateForm) (*utils.SessionUser, error) {
return u, nil
}

func registerRouter(r *gin.RouterGroup, s *AuthService) {
func registerRouter(r *gin.RouterGroup, s *AuthService, cfg *config.Config) {
endpoint := r.Group("/user")
endpoint.GET("/login_info", s.GetLoginInfoHandler)
endpoint.POST("/login", s.LoginHandler)
if cfg.UnauthedAPIQpsLimit > 0 && cfg.UnauthedAPIBurstLimit > 0 {
endpoint.POST("/login", s.MWRateLimited(rate.Limit(cfg.UnauthedAPIQpsLimit), cfg.UnauthedAPIBurstLimit), s.LoginHandler)
} else {
endpoint.POST("/login", s.LoginHandler)
}
endpoint.GET("/sign_out_info", s.MWAuthRequired(), s.getSignOutInfoHandler)
}

Expand Down Expand Up @@ -293,6 +299,18 @@ func (s *AuthService) MWRequireWritePriv() gin.HandlerFunc {
}
}

func (s *AuthService) MWRateLimited(r rate.Limit, b int) gin.HandlerFunc {
limiter := rate.NewLimiter(r, b)
return func(ctx *gin.Context) {
if !limiter.Allow() {
rest.Error(ctx, rest.ErrTooManyRequests.NewWithNoMessage())
ctx.Abort()
return
}
ctx.Next()
}
}

// RegisterAuthenticator registers an authenticator in the authenticate pipeline.
func (s *AuthService) RegisterAuthenticator(typeID utils.AuthType, a Authenticator) {
s.authenticators[typeID] = a
Expand Down
5 changes: 5 additions & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ type Config struct {
FeatureVersion string // assign the target TiDB version when running TiDB Dashboard as standalone mode

NgmTimeout int // in seconds

UnauthedAPIQpsLimit float64
UnauthedAPIBurstLimit int
}

func Default() *Config {
Expand All @@ -54,6 +57,8 @@ func Default() *Config {
DisableCustomPromAddr: false,
FeatureVersion: version.PDVersion,
NgmTimeout: 30, // s
UnauthedAPIQpsLimit: 0,
UnauthedAPIBurstLimit: 0,
}
}

Expand Down
11 changes: 7 additions & 4 deletions util/rest/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@ import (
)

var (
ErrUnauthenticated = errorx.CommonErrors.NewType("unauthenticated")
ErrForbidden = errorx.CommonErrors.NewType("forbidden")
ErrBadRequest = errorx.CommonErrors.NewType("bad_request")
ErrNotFound = errorx.CommonErrors.NewType("not_found")
ErrUnauthenticated = errorx.CommonErrors.NewType("unauthenticated")
ErrForbidden = errorx.CommonErrors.NewType("forbidden")
ErrBadRequest = errorx.CommonErrors.NewType("bad_request")
ErrNotFound = errorx.CommonErrors.NewType("not_found")
ErrTooManyRequests = errorx.CommonErrors.NewType("too_many_requests")
ErrInvalidEndpoint = errorx.CommonErrors.NewType("invalid_endpoint")
ErrInternalServerError = errorx.CommonErrors.NewType("internal_server_error")

errInternal = errorx.CommonErrors.NewType("internal")
propHTTPCode = errorx.RegisterProperty("http_code")
Expand Down

0 comments on commit 3305937

Please sign in to comment.