Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support the ability to sense client disconnection #972

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/app/server/binding/internal/decoder/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func GetReqDecoder(rt reflect.Type, byTag string, config *DecodeConfig) (Decoder
}, needValidate, nil
}

func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, parentJSONName string, byTag string, config *DecodeConfig) ([]fieldDecoder, bool, error) {
func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, parentJSONName, byTag string, config *DecodeConfig) ([]fieldDecoder, bool, error) {
for field.Type.Kind() == reflect.Ptr {
field.Type = field.Type.Elem()
}
Expand Down
18 changes: 18 additions & 0 deletions pkg/app/server/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,24 @@ func WithDisablePrintRoute(b bool) config.Option {
}}
}

// WithSenseClientDisconnection sets the ability to sense client disconnections.
// If we don't set it, it will default to false.
// There are three issues to note when using this option:
// 1. It only applies to netpoll.
// 2. It needs to be used in conjunction with WithOnAccept.
// Examples:
// server.Default(
// server.WithSenseClientDisconnection(true),
// server.WithOnConnect(func(ctx context.Context, conn network.Conn) context.Context {
// return ctx
// }))
// 3. The cost is high after opening, please choose carefully.
func WithSenseClientDisconnection(b bool) config.Option {
return config.Option{F: func(o *config.Options) {
o.SenseClientDisconnection = b
}}
}

// WithOnAccept sets the callback function when a new connection is accepted but cannot
// receive data in netpoll. In go net, it will be called before converting tls connection
func WithOnAccept(fn func(conn net.Conn) context.Context) config.Option {
Expand Down
3 changes: 3 additions & 0 deletions pkg/app/server/option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ func TestOptions(t *testing.T) {
WithBasePath("/"),
WithMaxRequestBodySize(2),
WithDisablePrintRoute(true),
WithSenseClientDisconnection(true),
WithNetwork("unix"),
WithExitWaitTime(time.Second),
WithMaxKeepBodySize(500),
Expand Down Expand Up @@ -93,6 +94,7 @@ func TestOptions(t *testing.T) {
assert.DeepEqual(t, opt.BasePath, "/")
assert.DeepEqual(t, opt.MaxRequestBodySize, 2)
assert.DeepEqual(t, opt.DisablePrintRoute, true)
assert.DeepEqual(t, opt.SenseClientDisconnection, true)
assert.DeepEqual(t, opt.Network, "unix")
assert.DeepEqual(t, opt.ExitWaitTimeout, time.Second)
assert.DeepEqual(t, opt.MaxKeepBodySize, 500)
Expand Down Expand Up @@ -130,6 +132,7 @@ func TestDefaultOptions(t *testing.T) {
assert.DeepEqual(t, opt.GetOnly, false)
assert.DeepEqual(t, opt.DisableKeepalive, false)
assert.DeepEqual(t, opt.DisablePrintRoute, false)
assert.DeepEqual(t, opt.SenseClientDisconnection, false)
assert.DeepEqual(t, opt.Network, "tcp")
assert.DeepEqual(t, opt.ExitWaitTimeout, time.Second*5)
assert.DeepEqual(t, opt.MaxKeepBodySize, 4*1024*1024)
Expand Down
4 changes: 4 additions & 0 deletions pkg/common/config/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ type Options struct {
StreamRequestBody bool
NoDefaultServerHeader bool
DisablePrintRoute bool
SenseClientDisconnection bool
Network string
Addr string
BasePath string
Expand Down Expand Up @@ -195,6 +196,9 @@ func NewOptions(opts []Option) *Options {
// Disabled when set to True
DisablePrintRoute: false,

// The ability to sense client disconnection is disabled by default
SenseClientDisconnection: false,

// "tcp", "udp", "unix"(unix domain socket)
Network: defaultNetwork,

Expand Down
1 change: 1 addition & 0 deletions pkg/common/config/option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ func TestDefaultOptions(t *testing.T) {
assert.False(t, options.HandleMethodNotAllowed)
assert.False(t, options.UseRawPath)
assert.False(t, options.RemoveExtraSlash)
assert.False(t, options.SenseClientDisconnection)
assert.True(t, options.UnescapePathValues)
assert.False(t, options.DisablePreParseMultipartForm)
assert.DeepEqual(t, defaultNetwork, options.Network)
Expand Down
50 changes: 30 additions & 20 deletions pkg/network/netpoll/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,31 +38,33 @@ func init() {

type transporter struct {
sync.RWMutex
network string
addr string
keepAliveTimeout time.Duration
readTimeout time.Duration
writeTimeout time.Duration
listener net.Listener
eventLoop netpoll.EventLoop
listenConfig *net.ListenConfig
OnAccept func(conn net.Conn) context.Context
OnConnect func(ctx context.Context, conn network.Conn) context.Context
network string
addr string
senseClientDisconnection bool
keepAliveTimeout time.Duration
readTimeout time.Duration
writeTimeout time.Duration
listener net.Listener
eventLoop netpoll.EventLoop
listenConfig *net.ListenConfig
OnAccept func(conn net.Conn) context.Context
OnConnect func(ctx context.Context, conn network.Conn) context.Context
}

// For transporter switch
func NewTransporter(options *config.Options) network.Transporter {
return &transporter{
network: options.Network,
addr: options.Addr,
keepAliveTimeout: options.KeepAliveTimeout,
readTimeout: options.ReadTimeout,
writeTimeout: options.WriteTimeout,
listener: nil,
eventLoop: nil,
listenConfig: options.ListenConfig,
OnAccept: options.OnAccept,
OnConnect: options.OnConnect,
network: options.Network,
addr: options.Addr,
senseClientDisconnection: options.SenseClientDisconnection,
keepAliveTimeout: options.KeepAliveTimeout,
readTimeout: options.ReadTimeout,
writeTimeout: options.WriteTimeout,
listener: nil,
eventLoop: nil,
listenConfig: options.ListenConfig,
OnAccept: options.OnAccept,
OnConnect: options.OnConnect,
}
}

Expand Down Expand Up @@ -97,6 +99,14 @@ func (t *transporter) ListenAndServe(onReq network.OnData) (err error) {

if t.OnConnect != nil {
opts = append(opts, netpoll.WithOnConnect(func(ctx context.Context, conn netpoll.Connection) context.Context {
if t.senseClientDisconnection {
li-jin-gou marked this conversation as resolved.
Show resolved Hide resolved
ctx, cancel := context.WithCancel(ctx)
conn.AddCloseCallback(func(connection netpoll.Connection) error {
cancel()
return nil
})
return t.OnConnect(ctx, newConn(conn))
}
return t.OnConnect(ctx, newConn(conn))
}))
}
Expand Down
27 changes: 27 additions & 0 deletions pkg/network/netpoll/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,33 @@ func TestTransport(t *testing.T) {
assert.Assert(t, atomic.LoadInt32(&onDataFlag) == 1)
})

t.Run("TestSenseClientDisconnection", func(t *testing.T) {
var onConnFlag int32
transporter := NewTransporter(&config.Options{
Addr: addr,
Network: nw,
SenseClientDisconnection: true,
OnConnect: func(ctx context.Context, conn network.Conn) context.Context {
atomic.StoreInt32(&onConnFlag, 1)
return ctx
},
})
go transporter.ListenAndServe(func(ctx context.Context, conn interface{}) error {
return nil
li-jin-gou marked this conversation as resolved.
Show resolved Hide resolved
})
defer transporter.Close()
time.Sleep(100 * time.Millisecond)

dial := NewDialer()
conn, err := dial.DialConnection(nw, addr, time.Second, nil)
assert.Nil(t, err)
_, err = conn.Write([]byte("456"))
assert.Nil(t, err)
time.Sleep(100 * time.Millisecond)

assert.Assert(t, atomic.LoadInt32(&onConnFlag) == 1)
})

t.Run("TestListenConfig", func(t *testing.T) {
listenCfg := &net.ListenConfig{Control: func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
Expand Down