diff --git a/signal.go b/signal.go index 71d2204a0..1b8456899 100644 --- a/signal.go +++ b/signal.go @@ -46,8 +46,9 @@ func (sig ShutdownSignal) String() string { func newSignalReceivers() signalReceivers { return signalReceivers{ - notify: signal.Notify, - signals: make(chan os.Signal, 1), + notify: signal.Notify, + stopNotify: signal.Stop, + signals: make(chan os.Signal, 1), } } @@ -64,7 +65,8 @@ type signalReceivers struct { finished chan struct{} // this stub allows us to unit test signal relay functionality - notify func(c chan<- os.Signal, sig ...os.Signal) + notify func(c chan<- os.Signal, sig ...os.Signal) + stopNotify func(c chan<- os.Signal) // last will contain a pointer to the last ShutdownSignal received, or // nil if none, if a new channel is created by Wait or Done, this last @@ -118,6 +120,7 @@ func (recv *signalReceivers) Start(ctx context.Context) { func (recv *signalReceivers) Stop(ctx context.Context) error { recv.m.Lock() defer recv.m.Unlock() + recv.stopNotify(recv.signals) // if the relayer is not running; return nil error if !recv.running() { diff --git a/signal_test.go b/signal_test.go index 8e8030ae7..95d6fe458 100644 --- a/signal_test.go +++ b/signal_test.go @@ -100,6 +100,10 @@ func TestSignal(t *testing.T) { } }() } + var stopCalledTimes int + recv.stopNotify = func(ch chan<- os.Signal) { + stopCalledTimes++ + } ctx, cancel := context.WithCancel(context.Background()) defer cancel() recv.Start(ctx) @@ -110,6 +114,7 @@ func TestSignal(t *testing.T) { sig := <-recv.Wait() require.Equal(t, syscall.SIGTERM, sig.Signal) require.NoError(t, recv.Stop(ctx)) + require.Equal(t, 1, stopCalledTimes) close(stub) }) })