diff --git a/cmd/main.go b/cmd/main.go index de92abca..4f00acf0 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -10,6 +10,7 @@ import ( "log" "net" "net/http" + "os" "strings" "time" @@ -71,29 +72,33 @@ func main() { flag.Parse() + sh := utils.NewShutdownHandler(2 * time.Second) + // Create KV store for persistence options := gomap.DefaultOptions options.Codec = utils.ProtoCodec{} // TODO: we can change to redis or badger at any given time store := gomap.NewStore(options) - defer func(store gokv.Store) { - err := store.Close() - if err != nil { - log.Panic(err) - } - }(store) + sh.AddGokvStore(store) - go runGatewayServer(grpcPort, httpPort) - runGrpcServer(grpcPort, useKvm, store, spdkAddress, qmpAddress, ctrlrDir, busesStr, tlsFiles) + runGrpcServer(grpcPort, useKvm, store, spdkAddress, qmpAddress, ctrlrDir, busesStr, tlsFiles, sh) + runGatewayServer(grpcPort, httpPort, sh) + + if err := sh.RunAndWait(); err != nil { + log.Printf("Bridge error: %v", err) + os.Exit(-1) + } + log.Print("Bridge successfully stopped") } -func runGrpcServer(grpcPort int, useKvm bool, store gokv.Store, spdkAddress, qmpAddress, ctrlrDir, busesStr, tlsFiles string) { +func runGrpcServer( + grpcPort int, + useKvm bool, + store gokv.Store, + spdkAddress, qmpAddress, ctrlrDir, busesStr, tlsFiles string, + sh *utils.ShutdownHandler) { tp := utils.InitTracerProvider("opi-spdk-bridge") - defer func() { - if err := tp.Shutdown(context.Background()); err != nil { - log.Panicf("Tracer Provider Shutdown: %v", err) - } - }() + sh.AddTraceProvider(tp) buses := splitBusesBySeparator(busesStr) @@ -171,16 +176,17 @@ func runGrpcServer(grpcPort int, useKvm bool, store gokv.Store, spdkAddress, qmp reflection.Register(s) - log.Printf("gRPC server listening at %v", lis.Addr()) - if err := s.Serve(lis); err != nil { - log.Panicf("failed to serve: %v", err) - } + sh.AddGrpcServer(s, lis) } -func runGatewayServer(grpcPort int, httpPort int) { +func runGatewayServer(grpcPort int, httpPort int, sh *utils.ShutdownHandler) { ctx := context.Background() ctx, cancel := context.WithCancel(ctx) - defer cancel() + sh.AddShutdown(func(_ context.Context) error { + log.Println("Canceling context to close HTTP gateway endpoint to gRPC server") + cancel() + return nil + }) // Register gRPC server endpoint // Note: Make sure the gRPC server is running properly and accessible @@ -192,15 +198,11 @@ func runGatewayServer(grpcPort int, httpPort int) { } // Start HTTP server (and proxy calls to gRPC server endpoint) - log.Printf("HTTP Server listening at %v", httpPort) server := &http.Server{ Addr: fmt.Sprintf(":%d", httpPort), Handler: mux, ReadTimeout: 5 * time.Second, WriteTimeout: 10 * time.Second, } - err = server.ListenAndServe() - if err != nil { - log.Panic("cannot start HTTP gateway server") - } + sh.AddHTTPServer(server) } diff --git a/pkg/utils/shutdown.go b/pkg/utils/shutdown.go new file mode 100644 index 00000000..031c280a --- /dev/null +++ b/pkg/utils/shutdown.go @@ -0,0 +1,217 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (C) 2023 Intel Corporation + +// Package utils contains utility functions +package utils + +import ( + "context" + "errors" + "fmt" + "log" + "net" + "net/http" + "os" + "os/signal" + "sync" + "syscall" + "time" + + "github.com/philippgille/gokv" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" +) + +// ServeFunc function to run service job +type ServeFunc func() error + +// ShutdownFunc function to perform shutdown of a service +type ShutdownFunc func(ctx context.Context) error + +// ShutdownHandler is responsible for running services and perform their shutdowns +// on service error or signals +type ShutdownHandler struct { + waitSignal chan os.Signal + timeoutPerShutdown time.Duration + + mu sync.Mutex + serves []ServeFunc + shutdowns []ShutdownFunc + + eg *errgroup.Group + egCtx context.Context +} + +// NewShutdownHandler creates an instance of ShutdownHandler +func NewShutdownHandler( + timeoutPerShutdown time.Duration, +) *ShutdownHandler { + eg, egCtx := errgroup.WithContext(context.Background()) + + return &ShutdownHandler{ + waitSignal: make(chan os.Signal, 1), + timeoutPerShutdown: timeoutPerShutdown, + + mu: sync.Mutex{}, + serves: []ServeFunc{}, + shutdowns: []ShutdownFunc{}, + + eg: eg, + egCtx: egCtx, + } +} + +// AddServe adds a service to run ant corresponding shutdown +func (s *ShutdownHandler) AddServe(serve ServeFunc, shutdown ShutdownFunc) { + s.mu.Lock() + defer s.mu.Unlock() + s.serves = append(s.serves, serve) + s.shutdowns = append(s.shutdowns, shutdown) +} + +// AddShutdown add a shutdown procedure to execute. +// Shutdowns are executed in backward order +func (s *ShutdownHandler) AddShutdown(shutdown ShutdownFunc) { + s.mu.Lock() + defer s.mu.Unlock() + s.shutdowns = append(s.shutdowns, shutdown) +} + +// AddGrpcServer adds serve and shutdown procedures for provided gRPC server +func (s *ShutdownHandler) AddGrpcServer(server *grpc.Server, lis net.Listener) { + s.AddServe( + func() error { + log.Printf("gRPC Server listening at %v", lis.Addr()) + return server.Serve(lis) + }, + func(ctx context.Context) error { + log.Println("Stopping gRPC Server") + return runWithCtx(ctx, func() error { + server.GracefulStop() + return nil + }) + }, + ) +} + +// AddHTTPServer adds serve and shutdown procedures for provided HTTP server +func (s *ShutdownHandler) AddHTTPServer(server *http.Server) { + s.AddServe( + func() error { + log.Printf("HTTP Server listening at %v", server.Addr) + err := server.ListenAndServe() + if errors.Is(err, http.ErrServerClosed) { + return nil + } + + return err + }, + func(ctx context.Context) error { + log.Println("Stopping HTTP Server") + err := server.Shutdown(ctx) + if err != nil { + cerr := server.Close() + log.Println("HTTP server close error:", cerr) + } + return err + }, + ) +} + +// AddGokvStore adds gokv shutdown procedure +func (s *ShutdownHandler) AddGokvStore(store gokv.Store) { + s.AddShutdown(func(ctx context.Context) error { + log.Println("Stopping gokv storage") + return runWithCtx(ctx, func() error { + return store.Close() + }) + }) +} + +// AddTraceProvider adds trace provider shutdown procedure +func (s *ShutdownHandler) AddTraceProvider(tp *sdktrace.TracerProvider) { + s.AddShutdown(func(ctx context.Context) error { + log.Println("Stopping tracer") + return tp.Shutdown(ctx) + }) +} + +// RunAndWait runs all services and execute shutdowns on a signal received +func (s *ShutdownHandler) RunAndWait() error { + for i := range s.serves { + fn := s.serves[i] + s.eg.Go(func() error { + return wrapServeFuncPanic(fn)() + }) + } + + s.eg.Go(func() error { + signal.Notify(s.waitSignal, syscall.SIGINT, syscall.SIGTERM) + select { + case sig := <-s.waitSignal: + log.Printf("Got signal: %v", sig) + case <-s.egCtx.Done(): + // can be reached if any Serve returned an error. Thus, initiating shutdown + log.Println("A process from errgroup exited with error:", s.egCtx.Err()) + } + log.Printf("Start graceful shutdown with timeout per shutdown call: %v", s.timeoutPerShutdown) + + s.mu.Lock() + defer s.mu.Unlock() + + var err error + for i := len(s.shutdowns) - 1; i >= 0; i-- { + timeoutCtx, cancel := context.WithTimeout(context.Background(), s.timeoutPerShutdown) + defer cancel() + shutdownFn := wrapShutdownFuncPanic(s.shutdowns[i]) + err = errors.Join(err, shutdownFn(timeoutCtx)) + } + + return err + }) + + return s.eg.Wait() +} + +func wrapServeFuncPanic(fn ServeFunc) ServeFunc { + return func() (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("was panic for serve function, recovered value: %v", r) + } + }() + err = fn() + return err + } +} + +func wrapShutdownFuncPanic(fn ShutdownFunc) ShutdownFunc { + return func(ctx context.Context) (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("was panic for shutdown function, recovered value: %v", r) + } + }() + err = fn(ctx) + return err + } +} + +func runWithCtx(ctx context.Context, fn func() error) error { + var err error + + stopped := make(chan struct{}, 1) + go func() { + err = fn() + stopped <- struct{}{} + }() + + select { + case <-ctx.Done(): + err = ctx.Err() + case <-stopped: + } + + return err +} diff --git a/pkg/utils/shutdown_test.go b/pkg/utils/shutdown_test.go new file mode 100644 index 00000000..599fed4e --- /dev/null +++ b/pkg/utils/shutdown_test.go @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (C) 2023 Intel Corporation + +// Package utils contains utility functions +package utils + +import ( + "context" + "errors" + "log" + "os" + "sync" + "testing" + "time" + + "golang.org/x/exp/slices" +) + +type serveShutdownPair struct { + shutdownTrigger chan struct{} + serve ServeFunc + shutdown ShutdownFunc +} + +func newServeShutdownPair( + fnID int, + serveIDs, shutdownIDs *[]int, + mu *sync.Mutex, + serveErr, shutdownErr error, + servePanic bool, shutdownPanic bool, +) *serveShutdownPair { + shutdownTrigger := make(chan struct{}, 1) + s := &serveShutdownPair{ + shutdownTrigger: shutdownTrigger, + serve: func() error { + mu.Lock() + *serveIDs = append(*serveIDs, fnID) + mu.Unlock() + + if servePanic { + log.Panic("Panic!") + } + + if serveErr == nil { + <-shutdownTrigger + } + + return serveErr + }, + shutdown: func(ctx context.Context) error { + mu.Lock() + *shutdownIDs = append(*shutdownIDs, fnID) + mu.Unlock() + + shutdownTrigger <- struct{}{} + + if shutdownPanic { + log.Panic("Panic!") + } + + return shutdownErr + }, + } + + return s +} + +func errString(err error) string { + if err == nil { + return "" + } + + return err.Error() +} + +func TestRunAndWait(t *testing.T) { + stubErr := errors.New("stub error") + tests := map[string]struct { + giveServeErr error + giveServePanic bool + giveShutdownErr error + giveShutdownPanic bool + stoppedByInterrupt bool + wantErr string + }{ + "all services successfully completed": { + giveServeErr: nil, + giveServePanic: false, + giveShutdownErr: nil, + giveShutdownPanic: false, + stoppedByInterrupt: true, + wantErr: "", + }, + "serve failed": { + giveServeErr: stubErr, + giveServePanic: false, + giveShutdownErr: nil, + stoppedByInterrupt: false, + giveShutdownPanic: false, + wantErr: stubErr.Error(), + }, + "shutdown failed": { + giveServeErr: nil, + giveServePanic: false, + giveShutdownErr: stubErr, + giveShutdownPanic: false, + stoppedByInterrupt: true, + wantErr: stubErr.Error(), + }, + "serve panic": { + giveServeErr: nil, + giveServePanic: true, + giveShutdownErr: nil, + giveShutdownPanic: false, + stoppedByInterrupt: false, + wantErr: "was panic for serve function, recovered value: Panic!", + }, + "shutdown panic": { + giveServeErr: nil, + giveServePanic: false, + giveShutdownErr: nil, + giveShutdownPanic: true, + stoppedByInterrupt: true, + wantErr: "was panic for shutdown function, recovered value: Panic!", + }, + } + for testName, tt := range tests { + t.Run(testName, func(t *testing.T) { + sh := NewShutdownHandler(1 * time.Millisecond) + + serveFnIDs := &[]int{} + shutdownFnIDs := &[]int{} + mu := sync.Mutex{} + s0 := newServeShutdownPair(0, serveFnIDs, shutdownFnIDs, &mu, nil, nil, false, false) + s1 := newServeShutdownPair(1, serveFnIDs, shutdownFnIDs, &mu, + tt.giveServeErr, tt.giveShutdownErr, + tt.giveServePanic, tt.giveShutdownPanic, + ) + s2 := newServeShutdownPair(2, serveFnIDs, shutdownFnIDs, &mu, nil, nil, false, false) + + sh.AddServe(s0.serve, s0.shutdown) + sh.AddServe(s1.serve, s1.shutdown) + sh.AddServe(s2.serve, s2.shutdown) + + if tt.stoppedByInterrupt { + sh.waitSignal <- os.Interrupt + } + + err := sh.RunAndWait() + + if errString(err) != tt.wantErr { + t.Errorf("Expected error: %v, received: %v", tt.wantErr, err) + } + + if !slices.Equal(*shutdownFnIDs, []int{2, 1, 0}) { + t.Errorf("Expected shutdown functions are called in order, instead %v", shutdownFnIDs) + } + + slices.Sort(*serveFnIDs) + if !slices.Equal(*serveFnIDs, []int{0, 1, 2}) { + t.Errorf("Expected all serve functions are called, instead %v", serveFnIDs) + } + }) + } +}