Skip to content

Commit

Permalink
修改
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuwenliang committed Jul 22, 2024
1 parent 67fb086 commit 0db346f
Show file tree
Hide file tree
Showing 9 changed files with 144 additions and 121 deletions.
21 changes: 21 additions & 0 deletions internal/pkg/ectx/ctx.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package ectx

import "context"

type appContextType string

var (
appCtxKey appContextType = "app"
)
func GetAppIdFromCtx(ctx context.Context) (uint, bool) {
app := ctx.Value(appCtxKey)
if app == nil {
return 0, false
}
v, ok := app.(uint)
return v, ok
}

func CtxWithAppId(ctx context.Context, appid uint) context.Context {
return context.WithValue(ctx, appCtxKey, appid)
}
54 changes: 0 additions & 54 deletions internal/pkg/middleware/add_appid_builder.go

This file was deleted.

43 changes: 43 additions & 0 deletions internal/pkg/middleware/check_appid_builder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package middleware

import (
"github.com/ecodeclub/webook/internal/pkg/ectx"
"net/http"
"strconv"

"github.com/ecodeclub/ginx"
"github.com/gin-gonic/gin"
"github.com/gotomicro/ego/core/elog"
)

type CheckAppIdBuilder struct {
}


const (
appIDHeader = "app"
)

func NewCheckAppIdBuilder() *CheckAppIdBuilder {
return &CheckAppIdBuilder{}
}
func (a *CheckAppIdBuilder) Build() gin.HandlerFunc {
return func(ctx *gin.Context) {
gctx := &ginx.Context{Context: ctx}
appid := ctx.GetHeader(appIDHeader)
if appid == "" {
return
}
c := ctx.Request.Context()
app, err := strconv.Atoi(appid)
if err != nil {
gctx.AbortWithStatus(http.StatusBadRequest)
elog.Error("appid设置失败", elog.FieldErr(err))
return
}
newCtx := ectx.CtxWithAppId(c, uint(app))
ctx.Request = ctx.Request.WithContext(newCtx)
}
}


Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package middleware

import (
"github.com/ecodeclub/webook/internal/pkg/ectx"
"net/http"
"net/http/httptest"
"testing"
Expand All @@ -22,14 +23,13 @@ func TestAddAppId(t *testing.T) {
wantCode: 200,
before: func(t *testing.T, ctx *gin.Context) {
header := make(http.Header)
header.Set(string(AppCtxKey), "1")
header.Set(appIDHeader, "1")
ctx.Request = httptest.NewRequest(http.MethodPost, "/users/profile", nil)
ctx.Request.Header = header
},
afterFunc: func(t *testing.T, ctx *gin.Context) {
c := ctx.Request.Context()
v := c.Value(AppCtxKey)
res, ok := v.(uint)
res, ok := ectx.GetAppIdFromCtx(c)
require.True(t, ok)
assert.Equal(t, uint(1), res)
},
Expand All @@ -44,16 +44,16 @@ func TestAddAppId(t *testing.T) {
},
afterFunc: func(t *testing.T, ctx *gin.Context) {
c := ctx.Request.Context()
v := c.Value(AppCtxKey)
require.Nil(t, v)
_, ok := ectx.GetAppIdFromCtx(c)
require.False(t, ok)
},
},
{
name: "appid 设置为不是数字",
wantCode: 400,
before: func(t *testing.T, ctx *gin.Context) {
header := make(http.Header)
header.Set(string(AppCtxKey), "dasdsa")
header.Set(appIDHeader, "dasdsa")
ctx.Request = httptest.NewRequest(http.MethodPost, "/users/profile", nil)
ctx.Request.Header = header
},
Expand All @@ -67,7 +67,7 @@ func TestAddAppId(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
tc.before(t, c)
builder := NewAddAppIdBuilder()
builder := NewCheckAppIdBuilder()
hdl := builder.Build()
hdl(c)
assert.Equal(t, tc.wantCode, c.Writer.Status())
Expand Down
55 changes: 28 additions & 27 deletions internal/pkg/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,37 +8,49 @@ import (
"github.com/ecodeclub/ekit/syncx"
)

type SnowFlake interface {
type ID int64

// AppID 返回生成时输入的appid
func (f ID) AppID() uint {
node := snowflake.ID(f).Node()
return uint(node >> 5)
}

func (f ID) Int64() int64 {
return int64(f)
}

type AppIDGenerator interface {
// Generate 功能生成雪花id(ID)。返回雪花id(ID)的每一位的组成如下。返回值ID可以通过AppID()返回生成时输入的appid。
// +---------------------------------------------------------------------------------------+
// | 1 Bit Unused | 41 Bit Timestamp | 5 Bit APPID | 5 Bit NodeID | 12 Bit Sequence ID |
// +---------------------------------------------------------------------------------------+
Generate(appid uint) (ID, error)
}

type CustomSnowFlake struct {
type MeoyingIDGenerator struct {
// 键为appid
nodes *syncx.Map[uint, *snowflake.Node]
}

const (
maxNode uint = 31
maxApp uint = 31
maxNodeNum uint = 31
maxAppNum uint = 31
)

var (
ErrExceedNode = errors.New("node超出限制")
ErrExceedApp = errors.New("app超出限制")
ErrExceedNode = errors.New("node编号超出限制")
ErrExceedApp = errors.New("app编号超出限制")
ErrUnknownApp = errors.New("未知的app")
)

// +---------------------------------------------------------------------------------------+
// | 1 Bit Unused | 41 Bit Timestamp | 5 Bit APPID | 5 Bit NodeID | 12 Bit Sequence ID |
// +---------------------------------------------------------------------------------------+

// node表示第几个节点,appid表示有几个应用 从0开始排序 0-ietls 最多到31
func NewCustomSnowFlake(nodeId uint, apps uint) (*CustomSnowFlake, error) {
// NewMeoyingIDGenerator nodeId表示第几个节点,apps表示有几个应用 从0开始排序 0-webook 1-ielts 最多到31
func NewMeoyingIDGenerator(nodeId uint, apps uint) (*MeoyingIDGenerator, error) {
nodeMap := &syncx.Map[uint, *snowflake.Node]{}
if nodeId > maxNode {
if nodeId > maxNodeNum {
return nil, fmt.Errorf("%w", ErrExceedNode)
}
if apps > maxApp+1 {
if apps > maxAppNum+1 {
return nil, fmt.Errorf("%w", ErrExceedApp)
}
for i := 0; i < int(apps); i++ {
Expand All @@ -49,28 +61,17 @@ func NewCustomSnowFlake(nodeId uint, apps uint) (*CustomSnowFlake, error) {
}
nodeMap.Store(uint(i), n)
}
return &CustomSnowFlake{
return &MeoyingIDGenerator{
nodes: nodeMap,
}, nil

}

type ID int64

func (c *CustomSnowFlake) Generate(appid uint) (ID, error) {
func (c *MeoyingIDGenerator) Generate(appid uint) (ID, error) {
n, ok := c.nodes.Load(appid)
if !ok {
return 0, fmt.Errorf("%w", ErrUnknownApp)
}
id := n.Generate()
return ID(id), nil
}

func (f ID) AppID() uint {
node := snowflake.ID(f).Node()
return uint(node >> 5)
}

func (f ID) Int64() int64 {
return int64(f)
}
58 changes: 36 additions & 22 deletions internal/pkg/snowflake/snowflake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,51 @@ package snowflake
import (
"testing"

"github.com/bwmarrin/snowflake"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_NewGenerate(t *testing.T) {
t.Run("nodeId超出限制", func(t *testing.T) {
_, err := NewCustomSnowFlake(32, 6)
require.ErrorIs(t, err, ErrExceedNode)
})
t.Run("app数量超出限制", func(t *testing.T) {
_, err := NewCustomSnowFlake(3, 33)
require.ErrorIs(t, err, ErrExceedApp)
})
t.Run("正常生成", func(t *testing.T) {
idMaker, err := NewCustomSnowFlake(0, 4)
require.NoError(t, err)
appids := make([]uint, 0, 4)
idMaker.nodes.Range(func(key uint, value *snowflake.Node) bool {
appids = append(appids, key)
return true
testcases := []struct {
name string
nodeId uint
apps uint
wantErrFunc require.ErrorAssertionFunc
}{
{
name: "nodeId超出限制",
nodeId: 32,
apps: 6,
wantErrFunc: func(t require.TestingT, err error, _ ...interface{}) {
require.ErrorIs(t, err, ErrExceedNode)
},
},
{
name: "appId超出限制",
nodeId: 3,
apps: 33,
wantErrFunc: func(t require.TestingT, err error, _ ...interface{}) {
require.ErrorIs(t, err, ErrExceedApp)
},
},
{
name: "生成正常",
nodeId: 0,
apps: 6,
wantErrFunc: require.NoError,
},
}
for _, tt := range testcases {
t.Run(tt.name, func(t *testing.T) {
_, err := NewMeoyingIDGenerator(tt.nodeId, tt.apps)
tt.wantErrFunc(t, err)
})
assert.ElementsMatch(t, []uint{
0, 1, 2, 3,
}, appids)
})
}

}

func Test_Generate(t *testing.T) {
idmaker, err := NewCustomSnowFlake(1, 6)
idmaker, err := NewMeoyingIDGenerator(1, 6)
require.NoError(t, err)
ids := make([]int64, 0)
for i := 0; i < 6; i++ {
Expand All @@ -54,7 +68,7 @@ func Test_Generate(t *testing.T) {
}

func Test_GenerateAppId(t *testing.T) {
idmaker, err := NewCustomSnowFlake(1, 16)
idmaker, err := NewMeoyingIDGenerator(1, 16)
require.NoError(t, err)
testcases := []struct {
name string
Expand Down
5 changes: 2 additions & 3 deletions internal/user/internal/integration/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@ import (
"context"
"database/sql"
"fmt"
"github.com/ecodeclub/webook/internal/pkg/ectx"
"net/http"
"testing"

"github.com/ecodeclub/webook/internal/pkg/middleware"

"github.com/ecodeclub/webook/internal/pkg/snowflake"
"github.com/ecodeclub/webook/internal/user"

Expand Down Expand Up @@ -1162,7 +1161,7 @@ func (s *HandlerWithAppTestSuite) TestFindOrCreateByWechat() {
},
before: func(t *testing.T) {},
ctx: func() context.Context {
return middleware.CtxWithAppId(context.Background(), uint(1))
return ectx.CtxWithAppId(context.Background(), uint(1))
},
after: func(t *testing.T) {
var u dao.User
Expand Down
Loading

0 comments on commit 0db346f

Please sign in to comment.