-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
1,285 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
package middleware | ||
|
||
import ( | ||
"context" | ||
"net/http" | ||
"strconv" | ||
|
||
"github.com/ecodeclub/ginx" | ||
"github.com/gin-gonic/gin" | ||
"github.com/gotomicro/ego/core/elog" | ||
) | ||
|
||
type AddAppIdBuilder struct { | ||
} | ||
|
||
type AppContextType string | ||
|
||
const ( | ||
AppCtxKey AppContextType = "app" | ||
) | ||
|
||
func NewAddAppIdBuilder() *AddAppIdBuilder { | ||
return &AddAppIdBuilder{} | ||
} | ||
func (a *AddAppIdBuilder) Build() gin.HandlerFunc { | ||
return func(ctx *gin.Context) { | ||
gctx := &ginx.Context{Context: ctx} | ||
appid := ctx.GetHeader(string(AppCtxKey)) | ||
if appid != "" { | ||
c := ctx.Request.Context() | ||
app, err := strconv.Atoi(appid) | ||
if err != nil { | ||
gctx.AbortWithStatus(http.StatusBadRequest) | ||
elog.Error("appid设置失败", elog.FieldErr(err)) | ||
return | ||
} | ||
newCtx := CtxWithAppId(c, uint(app)) | ||
ctx.Request = ctx.Request.WithContext(newCtx) | ||
} | ||
} | ||
} | ||
|
||
func AppID(ctx context.Context) (uint, bool) { | ||
app := ctx.Value(AppCtxKey) | ||
if app == nil { | ||
return 0, false | ||
} | ||
v, ok := app.(uint) | ||
return v, ok | ||
} | ||
|
||
func CtxWithAppId(ctx context.Context, appid uint) context.Context { | ||
return context.WithValue(ctx, AppCtxKey, appid) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
package middleware | ||
|
||
import ( | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
|
||
"github.com/gin-gonic/gin" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestAddAppId(t *testing.T) { | ||
testCases := []struct { | ||
name string | ||
wantCode int | ||
before func(t *testing.T, ctx *gin.Context) | ||
afterFunc func(t *testing.T, ctx *gin.Context) | ||
}{ | ||
{ | ||
name: "appid 为 1", | ||
wantCode: 200, | ||
before: func(t *testing.T, ctx *gin.Context) { | ||
header := make(http.Header) | ||
header.Set(string(AppCtxKey), "1") | ||
ctx.Request = httptest.NewRequest(http.MethodPost, "/users/profile", nil) | ||
ctx.Request.Header = header | ||
}, | ||
afterFunc: func(t *testing.T, ctx *gin.Context) { | ||
c := ctx.Request.Context() | ||
v := c.Value(AppCtxKey) | ||
res, ok := v.(uint) | ||
require.True(t, ok) | ||
assert.Equal(t, uint(1), res) | ||
}, | ||
}, | ||
{ | ||
name: "appid没设置", | ||
wantCode: 200, | ||
before: func(t *testing.T, ctx *gin.Context) { | ||
header := make(http.Header) | ||
ctx.Request = httptest.NewRequest(http.MethodPost, "/users/profile", nil) | ||
ctx.Request.Header = header | ||
}, | ||
afterFunc: func(t *testing.T, ctx *gin.Context) { | ||
c := ctx.Request.Context() | ||
v := c.Value(AppCtxKey) | ||
require.Nil(t, v) | ||
}, | ||
}, | ||
{ | ||
name: "appid 设置为不是数字", | ||
wantCode: 400, | ||
before: func(t *testing.T, ctx *gin.Context) { | ||
header := make(http.Header) | ||
header.Set(string(AppCtxKey), "dasdsa") | ||
ctx.Request = httptest.NewRequest(http.MethodPost, "/users/profile", nil) | ||
ctx.Request.Header = header | ||
}, | ||
afterFunc: func(t *testing.T, ctx *gin.Context) { | ||
}, | ||
}, | ||
} | ||
|
||
for _, tc := range testCases { | ||
t.Run(tc.name, func(t *testing.T) { | ||
w := httptest.NewRecorder() | ||
c, _ := gin.CreateTestContext(w) | ||
tc.before(t, c) | ||
builder := NewAddAppIdBuilder() | ||
hdl := builder.Build() | ||
hdl(c) | ||
assert.Equal(t, tc.wantCode, c.Writer.Status()) | ||
tc.afterFunc(t, c) | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
package snowflake | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
|
||
"github.com/bwmarrin/snowflake" | ||
"github.com/ecodeclub/ekit/syncx" | ||
) | ||
|
||
type SnowFlake interface { | ||
Generate(appid uint) (ID, error) | ||
} | ||
|
||
type CustomSnowFlake struct { | ||
// 键为appid | ||
nodes *syncx.Map[uint, *snowflake.Node] | ||
} | ||
|
||
const ( | ||
maxNode uint = 31 | ||
maxApp uint = 31 | ||
) | ||
|
||
var ( | ||
ErrExceedNode = errors.New("node超出限制") | ||
ErrExceedApp = errors.New("app超出限制") | ||
ErrUnknownApp = errors.New("未知的app") | ||
) | ||
|
||
// +---------------------------------------------------------------------------------------+ | ||
// | 1 Bit Unused | 41 Bit Timestamp | 5 Bit APPID | 5 Bit NodeID | 12 Bit Sequence ID | | ||
// +---------------------------------------------------------------------------------------+ | ||
|
||
// node表示第几个节点,appid表示有几个应用 从0开始排序 0-ietls 最多到31 | ||
func NewCustomSnowFlake(nodeId uint, apps uint) (*CustomSnowFlake, error) { | ||
nodeMap := &syncx.Map[uint, *snowflake.Node]{} | ||
if nodeId > maxNode { | ||
return nil, fmt.Errorf("%w", ErrExceedNode) | ||
} | ||
if apps > maxApp+1 { | ||
return nil, fmt.Errorf("%w", ErrExceedApp) | ||
} | ||
for i := 0; i < int(apps); i++ { | ||
nid := (i << 5) | int(nodeId) | ||
n, err := snowflake.NewNode(int64(nid)) | ||
if err != nil { | ||
return nil, err | ||
} | ||
nodeMap.Store(uint(i), n) | ||
} | ||
return &CustomSnowFlake{ | ||
nodes: nodeMap, | ||
}, nil | ||
|
||
} | ||
|
||
type ID int64 | ||
|
||
func (c *CustomSnowFlake) Generate(appid uint) (ID, error) { | ||
n, ok := c.nodes.Load(appid) | ||
if !ok { | ||
return 0, fmt.Errorf("%w", ErrUnknownApp) | ||
} | ||
id := n.Generate() | ||
return ID(id), nil | ||
} | ||
|
||
func (f ID) AppID() uint { | ||
node := snowflake.ID(f).Node() | ||
return uint(node >> 5) | ||
} | ||
|
||
func (f ID) Int64() int64 { | ||
return int64(f) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
package snowflake | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/bwmarrin/snowflake" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func Test_NewGenerate(t *testing.T) { | ||
t.Run("nodeId超出限制", func(t *testing.T) { | ||
_, err := NewCustomSnowFlake(32, 6) | ||
require.ErrorIs(t, err, ErrExceedNode) | ||
}) | ||
t.Run("app数量超出限制", func(t *testing.T) { | ||
_, err := NewCustomSnowFlake(3, 33) | ||
require.ErrorIs(t, err, ErrExceedApp) | ||
}) | ||
t.Run("正常生成", func(t *testing.T) { | ||
idMaker, err := NewCustomSnowFlake(0, 4) | ||
require.NoError(t, err) | ||
appids := make([]uint, 0, 4) | ||
idMaker.nodes.Range(func(key uint, value *snowflake.Node) bool { | ||
appids = append(appids, key) | ||
return true | ||
}) | ||
assert.ElementsMatch(t, []uint{ | ||
0, 1, 2, 3, | ||
}, appids) | ||
}) | ||
|
||
} | ||
|
||
func Test_Generate(t *testing.T) { | ||
idmaker, err := NewCustomSnowFlake(1, 6) | ||
require.NoError(t, err) | ||
ids := make([]int64, 0) | ||
for i := 0; i < 6; i++ { | ||
for j := 0; j < 100000; j++ { | ||
id, err := idmaker.Generate(uint(i)) | ||
require.NoError(t, err) | ||
ids = append(ids, id.Int64()) | ||
} | ||
} | ||
// 校验生成的id是否重复 | ||
idmap := make(map[int64]struct{}, len(ids)) | ||
for i := 0; i < len(ids); i++ { | ||
_, ok := idmap[ids[i]] | ||
require.False(t, ok) | ||
idmap[ids[i]] = struct{}{} | ||
} | ||
|
||
} | ||
|
||
func Test_GenerateAppId(t *testing.T) { | ||
idmaker, err := NewCustomSnowFlake(1, 16) | ||
require.NoError(t, err) | ||
testcases := []struct { | ||
name string | ||
appid uint | ||
wantErr require.ErrorAssertionFunc | ||
}{ | ||
{ | ||
name: "appId没找到", | ||
appid: 16, | ||
wantErr: func(t require.TestingT, err error, i ...interface{}) { | ||
require.ErrorIs(t, err, ErrUnknownApp) | ||
}, | ||
}, | ||
{ | ||
name: "appid 为1", | ||
appid: 1, | ||
wantErr: require.NoError, | ||
}, | ||
} | ||
for _, tc := range testcases { | ||
t.Run(tc.name, func(t *testing.T) { | ||
id, err := idmaker.Generate(tc.appid) | ||
tc.wantErr(t, err) | ||
if err != nil { | ||
return | ||
} | ||
app := id.AppID() | ||
assert.Equal(t, tc.appid, app) | ||
}) | ||
} | ||
} |
Oops, something went wrong.