diff --git a/main.go b/main.go index 8390607..f06be0c 100644 --- a/main.go +++ b/main.go @@ -8,6 +8,7 @@ import ( "path/filepath" "strings" + "context" "os" "os/signal" "syscall" @@ -15,10 +16,11 @@ import ( ) var ( - listen string - quotaDir string - users string - timeout time.Duration + listen string + quotaDir string + users string + timeout time.Duration + gracefulPeriod time.Duration startTime = time.Now() lastReloadTime = time.Now() @@ -74,6 +76,7 @@ func init() { flag.StringVar("aDir, "quotaDir", "quota", "quota directory") flag.StringVar(&users, "users", ".htpasswd", "htpasswd auth file path") flag.DurationVar(&timeout, "timeout", 300*time.Second, "session creation timeout in time.Duration format, e.g. 300s or 500ms") + flag.DurationVar(&gracefulPeriod, "graceful-period", 300*time.Second, "graceful shutdown period in time.Duration format, e.g. 300s or 500ms") flag.BoolVar(&version, "version", false, "show version and exit") flag.Parse() if version { @@ -95,5 +98,20 @@ func init() { } func main() { - log.Fatal(http.ListenAndServe(listen, mux())) + stop := make(chan os.Signal) + signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM) + + server := &http.Server{ + Addr: listen, + Handler: mux(), + } + go server.ListenAndServe() + + <-stop + + ctx, cancel := context.WithTimeout(context.Background(), gracefulPeriod) + defer cancel() + if err := server.Shutdown(ctx); err != nil { + log.Fatalf("graceful shutdown: %v\n", err) + } }