diff --git a/internal/pkg/ectx/ctx.go b/internal/pkg/ectx/ctx.go new file mode 100644 index 00000000..ac2bad5c --- /dev/null +++ b/internal/pkg/ectx/ctx.go @@ -0,0 +1,21 @@ +package ectx + +import "context" + +type appContextType string + +var ( + appCtxKey appContextType = "app" +) +func GetAppIdFromCtx(ctx context.Context) (uint, bool) { + app := ctx.Value(appCtxKey) + if app == nil { + return 0, false + } + v, ok := app.(uint) + return v, ok +} + +func CtxWithAppId(ctx context.Context, appid uint) context.Context { + return context.WithValue(ctx, appCtxKey, appid) +} \ No newline at end of file diff --git a/internal/pkg/middleware/add_appid_builder.go b/internal/pkg/middleware/add_appid_builder.go deleted file mode 100644 index bf1c9ef9..00000000 --- a/internal/pkg/middleware/add_appid_builder.go +++ /dev/null @@ -1,54 +0,0 @@ -package middleware - -import ( - "context" - "net/http" - "strconv" - - "github.com/ecodeclub/ginx" - "github.com/gin-gonic/gin" - "github.com/gotomicro/ego/core/elog" -) - -type AddAppIdBuilder struct { -} - -type AppContextType string - -const ( - AppCtxKey AppContextType = "app" -) - -func NewAddAppIdBuilder() *AddAppIdBuilder { - return &AddAppIdBuilder{} -} -func (a *AddAppIdBuilder) Build() gin.HandlerFunc { - return func(ctx *gin.Context) { - gctx := &ginx.Context{Context: ctx} - appid := ctx.GetHeader(string(AppCtxKey)) - if appid != "" { - c := ctx.Request.Context() - app, err := strconv.Atoi(appid) - if err != nil { - gctx.AbortWithStatus(http.StatusBadRequest) - elog.Error("appid设置失败", elog.FieldErr(err)) - return - } - newCtx := CtxWithAppId(c, uint(app)) - ctx.Request = ctx.Request.WithContext(newCtx) - } - } -} - -func AppID(ctx context.Context) (uint, bool) { - app := ctx.Value(AppCtxKey) - if app == nil { - return 0, false - } - v, ok := app.(uint) - return v, ok -} - -func CtxWithAppId(ctx context.Context, appid uint) context.Context { - return context.WithValue(ctx, AppCtxKey, appid) -} diff --git a/internal/pkg/middleware/check_appid_builder.go b/internal/pkg/middleware/check_appid_builder.go new file mode 100644 index 00000000..aab35041 --- /dev/null +++ b/internal/pkg/middleware/check_appid_builder.go @@ -0,0 +1,43 @@ +package middleware + +import ( + "github.com/ecodeclub/webook/internal/pkg/ectx" + "net/http" + "strconv" + + "github.com/ecodeclub/ginx" + "github.com/gin-gonic/gin" + "github.com/gotomicro/ego/core/elog" +) + +type CheckAppIdBuilder struct { +} + + +const ( + appIDHeader = "app" +) + +func NewCheckAppIdBuilder() *CheckAppIdBuilder { + return &CheckAppIdBuilder{} +} +func (a *CheckAppIdBuilder) Build() gin.HandlerFunc { + return func(ctx *gin.Context) { + gctx := &ginx.Context{Context: ctx} + appid := ctx.GetHeader(appIDHeader) + if appid == "" { + return + } + c := ctx.Request.Context() + app, err := strconv.Atoi(appid) + if err != nil { + gctx.AbortWithStatus(http.StatusBadRequest) + elog.Error("appid设置失败", elog.FieldErr(err)) + return + } + newCtx := ectx.CtxWithAppId(c, uint(app)) + ctx.Request = ctx.Request.WithContext(newCtx) + } +} + + diff --git a/internal/pkg/middleware/add_appid_builder_test.go b/internal/pkg/middleware/check_appid_builder_test.go similarity index 87% rename from internal/pkg/middleware/add_appid_builder_test.go rename to internal/pkg/middleware/check_appid_builder_test.go index bff6f0ae..245df81b 100644 --- a/internal/pkg/middleware/add_appid_builder_test.go +++ b/internal/pkg/middleware/check_appid_builder_test.go @@ -1,6 +1,7 @@ package middleware import ( + "github.com/ecodeclub/webook/internal/pkg/ectx" "net/http" "net/http/httptest" "testing" @@ -22,14 +23,13 @@ func TestAddAppId(t *testing.T) { wantCode: 200, before: func(t *testing.T, ctx *gin.Context) { header := make(http.Header) - header.Set(string(AppCtxKey), "1") + header.Set(appIDHeader, "1") ctx.Request = httptest.NewRequest(http.MethodPost, "/users/profile", nil) ctx.Request.Header = header }, afterFunc: func(t *testing.T, ctx *gin.Context) { c := ctx.Request.Context() - v := c.Value(AppCtxKey) - res, ok := v.(uint) + res, ok := ectx.GetAppIdFromCtx(c) require.True(t, ok) assert.Equal(t, uint(1), res) }, @@ -44,8 +44,8 @@ func TestAddAppId(t *testing.T) { }, afterFunc: func(t *testing.T, ctx *gin.Context) { c := ctx.Request.Context() - v := c.Value(AppCtxKey) - require.Nil(t, v) + _, ok := ectx.GetAppIdFromCtx(c) + require.False(t, ok) }, }, { @@ -53,7 +53,7 @@ func TestAddAppId(t *testing.T) { wantCode: 400, before: func(t *testing.T, ctx *gin.Context) { header := make(http.Header) - header.Set(string(AppCtxKey), "dasdsa") + header.Set(appIDHeader, "dasdsa") ctx.Request = httptest.NewRequest(http.MethodPost, "/users/profile", nil) ctx.Request.Header = header }, @@ -67,7 +67,7 @@ func TestAddAppId(t *testing.T) { w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) tc.before(t, c) - builder := NewAddAppIdBuilder() + builder := NewCheckAppIdBuilder() hdl := builder.Build() hdl(c) assert.Equal(t, tc.wantCode, c.Writer.Status()) diff --git a/internal/pkg/snowflake/snowflake.go b/internal/pkg/snowflake/snowflake.go index f8fc16c9..77a75ca2 100644 --- a/internal/pkg/snowflake/snowflake.go +++ b/internal/pkg/snowflake/snowflake.go @@ -8,37 +8,49 @@ import ( "github.com/ecodeclub/ekit/syncx" ) -type SnowFlake interface { +type ID int64 + +// AppID 返回生成时输入的appid +func (f ID) AppID() uint { + node := snowflake.ID(f).Node() + return uint(node >> 5) +} + +func (f ID) Int64() int64 { + return int64(f) +} + +type AppIDGenerator interface { + // Generate 功能生成雪花id(ID)。返回雪花id(ID)的每一位的组成如下。返回值ID可以通过AppID()返回生成时输入的appid。 + // +---------------------------------------------------------------------------------------+ + // | 1 Bit Unused | 41 Bit Timestamp | 5 Bit APPID | 5 Bit NodeID | 12 Bit Sequence ID | + // +---------------------------------------------------------------------------------------+ Generate(appid uint) (ID, error) } -type CustomSnowFlake struct { +type MeoyingIDGenerator struct { // 键为appid nodes *syncx.Map[uint, *snowflake.Node] } const ( - maxNode uint = 31 - maxApp uint = 31 + maxNodeNum uint = 31 + maxAppNum uint = 31 ) var ( - ErrExceedNode = errors.New("node超出限制") - ErrExceedApp = errors.New("app超出限制") + ErrExceedNode = errors.New("node编号超出限制") + ErrExceedApp = errors.New("app编号超出限制") ErrUnknownApp = errors.New("未知的app") ) -// +---------------------------------------------------------------------------------------+ -// | 1 Bit Unused | 41 Bit Timestamp | 5 Bit APPID | 5 Bit NodeID | 12 Bit Sequence ID | -// +---------------------------------------------------------------------------------------+ - -// node表示第几个节点,appid表示有几个应用 从0开始排序 0-ietls 最多到31 -func NewCustomSnowFlake(nodeId uint, apps uint) (*CustomSnowFlake, error) { +// NewMeoyingIDGenerator nodeId表示第几个节点,apps表示有几个应用 从0开始排序 0-webook 1-ielts 最多到31 +func NewMeoyingIDGenerator(nodeId uint, apps uint) (*MeoyingIDGenerator, error) { nodeMap := &syncx.Map[uint, *snowflake.Node]{} - if nodeId > maxNode { + if nodeId > maxNodeNum { return nil, fmt.Errorf("%w", ErrExceedNode) } - if apps > maxApp+1 { + if apps > maxAppNum+1 { return nil, fmt.Errorf("%w", ErrExceedApp) } for i := 0; i < int(apps); i++ { @@ -49,15 +61,13 @@ func NewCustomSnowFlake(nodeId uint, apps uint) (*CustomSnowFlake, error) { } nodeMap.Store(uint(i), n) } - return &CustomSnowFlake{ + return &MeoyingIDGenerator{ nodes: nodeMap, }, nil } -type ID int64 - -func (c *CustomSnowFlake) Generate(appid uint) (ID, error) { +func (c *MeoyingIDGenerator) Generate(appid uint) (ID, error) { n, ok := c.nodes.Load(appid) if !ok { return 0, fmt.Errorf("%w", ErrUnknownApp) @@ -65,12 +75,3 @@ func (c *CustomSnowFlake) Generate(appid uint) (ID, error) { id := n.Generate() return ID(id), nil } - -func (f ID) AppID() uint { - node := snowflake.ID(f).Node() - return uint(node >> 5) -} - -func (f ID) Int64() int64 { - return int64(f) -} diff --git a/internal/pkg/snowflake/snowflake_test.go b/internal/pkg/snowflake/snowflake_test.go index 9288b1bc..7e196629 100644 --- a/internal/pkg/snowflake/snowflake_test.go +++ b/internal/pkg/snowflake/snowflake_test.go @@ -3,37 +3,51 @@ package snowflake import ( "testing" - "github.com/bwmarrin/snowflake" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func Test_NewGenerate(t *testing.T) { - t.Run("nodeId超出限制", func(t *testing.T) { - _, err := NewCustomSnowFlake(32, 6) - require.ErrorIs(t, err, ErrExceedNode) - }) - t.Run("app数量超出限制", func(t *testing.T) { - _, err := NewCustomSnowFlake(3, 33) - require.ErrorIs(t, err, ErrExceedApp) - }) - t.Run("正常生成", func(t *testing.T) { - idMaker, err := NewCustomSnowFlake(0, 4) - require.NoError(t, err) - appids := make([]uint, 0, 4) - idMaker.nodes.Range(func(key uint, value *snowflake.Node) bool { - appids = append(appids, key) - return true + testcases := []struct { + name string + nodeId uint + apps uint + wantErrFunc require.ErrorAssertionFunc + }{ + { + name: "nodeId超出限制", + nodeId: 32, + apps: 6, + wantErrFunc: func(t require.TestingT, err error, _ ...interface{}) { + require.ErrorIs(t, err, ErrExceedNode) + }, + }, + { + name: "appId超出限制", + nodeId: 3, + apps: 33, + wantErrFunc: func(t require.TestingT, err error, _ ...interface{}) { + require.ErrorIs(t, err, ErrExceedApp) + }, + }, + { + name: "生成正常", + nodeId: 0, + apps: 6, + wantErrFunc: require.NoError, + }, + } + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + _, err := NewMeoyingIDGenerator(tt.nodeId, tt.apps) + tt.wantErrFunc(t, err) }) - assert.ElementsMatch(t, []uint{ - 0, 1, 2, 3, - }, appids) - }) + } } func Test_Generate(t *testing.T) { - idmaker, err := NewCustomSnowFlake(1, 6) + idmaker, err := NewMeoyingIDGenerator(1, 6) require.NoError(t, err) ids := make([]int64, 0) for i := 0; i < 6; i++ { @@ -54,7 +68,7 @@ func Test_Generate(t *testing.T) { } func Test_GenerateAppId(t *testing.T) { - idmaker, err := NewCustomSnowFlake(1, 16) + idmaker, err := NewMeoyingIDGenerator(1, 16) require.NoError(t, err) testcases := []struct { name string diff --git a/internal/user/internal/integration/handler_test.go b/internal/user/internal/integration/handler_test.go index c8a31481..073e8079 100644 --- a/internal/user/internal/integration/handler_test.go +++ b/internal/user/internal/integration/handler_test.go @@ -20,11 +20,10 @@ import ( "context" "database/sql" "fmt" + "github.com/ecodeclub/webook/internal/pkg/ectx" "net/http" "testing" - "github.com/ecodeclub/webook/internal/pkg/middleware" - "github.com/ecodeclub/webook/internal/pkg/snowflake" "github.com/ecodeclub/webook/internal/user" @@ -1162,7 +1161,7 @@ func (s *HandlerWithAppTestSuite) TestFindOrCreateByWechat() { }, before: func(t *testing.T) {}, ctx: func() context.Context { - return middleware.CtxWithAppId(context.Background(), uint(1)) + return ectx.CtxWithAppId(context.Background(), uint(1)) }, after: func(t *testing.T) { var u dao.User diff --git a/internal/user/internal/repository/dao/user_callback.go b/internal/user/internal/repository/dao/user_callback.go index c44f0695..2f66968e 100644 --- a/internal/user/internal/repository/dao/user_callback.go +++ b/internal/user/internal/repository/dao/user_callback.go @@ -3,8 +3,7 @@ package dao import ( "context" "fmt" - - "github.com/ecodeclub/webook/internal/pkg/middleware" + "github.com/ecodeclub/webook/internal/pkg/ectx" "github.com/ecodeclub/webook/internal/pkg/snowflake" "github.com/gotomicro/ego/core/elog" @@ -27,11 +26,11 @@ var ( type UserInsertCallBackBuilder struct { logger *elog.Component - idMaker *snowflake.CustomSnowFlake + idMaker snowflake.AppIDGenerator } func NewUserInsertCallBackBuilder(nodeid, apps uint) (*UserInsertCallBackBuilder, error) { - idMaker, err := snowflake.NewCustomSnowFlake(nodeid, apps) + idMaker, err := snowflake.NewMeoyingIDGenerator(nodeid, apps) if err != nil { return nil, err } @@ -137,7 +136,7 @@ func userId(ctx context.Context) (int64, bool) { } func appId(ctx context.Context) (uint, bool) { - return middleware.AppID(ctx) + return ectx.GetAppIdFromCtx(ctx) } func tableNameFromAppId(appid uint) (string, error) { diff --git a/internal/user/internal/web/handler.go b/internal/user/internal/web/handler.go index 51770c30..712eebb5 100644 --- a/internal/user/internal/web/handler.go +++ b/internal/user/internal/web/handler.go @@ -69,12 +69,12 @@ func NewHandler( func (h *Handler) PrivateRoutes(server *gin.Engine) { users := server.Group("/users") - users.GET("/profile", middleware.NewAddAppIdBuilder().Build(), ginx.S(h.Profile)) - users.POST("/profile", middleware.NewAddAppIdBuilder().Build(), ginx.BS[EditReq](h.Edit)) + users.GET("/profile", middleware.NewCheckAppIdBuilder().Build(), ginx.S(h.Profile)) + users.POST("/profile", middleware.NewCheckAppIdBuilder().Build(), ginx.BS[EditReq](h.Edit)) } func (h *Handler) PublicRoutes(server *gin.Engine) { - appidFunc := middleware.NewAddAppIdBuilder().Build() + appidFunc := middleware.NewCheckAppIdBuilder().Build() oauth2 := server.Group("/oauth2") oauth2.GET("/wechat/auth_url", appidFunc, ginx.W(h.WechatAuthURL)) oauth2.GET("/mock/login", appidFunc, ginx.W(h.MockLogin))