diff --git a/trace/icmp_ipv4.go b/trace/icmp_ipv4.go index d6d61262..0d796886 100644 --- a/trace/icmp_ipv4.go +++ b/trace/icmp_ipv4.go @@ -10,27 +10,17 @@ import ( "golang.org/x/net/context" "golang.org/x/net/icmp" "golang.org/x/net/ipv4" - "golang.org/x/sync/semaphore" ) type ICMPTracer struct { Config - wg sync.WaitGroup - res Result - ctx context.Context - inflightRequest map[int]chan Hop - inflightRequestLock sync.Mutex - icmpListen net.PacketConn - workFork workFork - final int - finalLock sync.Mutex - - sem *semaphore.Weighted -} - -type workFork struct { - ttl int - num int + wg sync.WaitGroup + res Result + ctx context.Context + resCh chan Hop + icmpListen net.PacketConn + final int + finalLock sync.Mutex } func (t *ICMPTracer) Execute() (*Result, error) { @@ -49,24 +39,24 @@ func (t *ICMPTracer) Execute() (*Result, error) { var cancel context.CancelFunc t.ctx, cancel = context.WithCancel(context.Background()) defer cancel() - t.inflightRequest = make(map[int]chan Hop) + t.resCh = make(chan Hop) t.final = -1 go t.listenICMP() - t.sem = semaphore.NewWeighted(int64(t.ParallelRequests)) - - for t.workFork.ttl = 1; t.workFork.ttl <= t.MaxHops; t.workFork.ttl++ { + for ttl := 1; ttl <= t.MaxHops; ttl++ { + if t.final != -1 && ttl > t.final { + break + } for i := 0; i < t.NumMeasurements; i++ { t.wg.Add(1) - go t.send(workFork{t.workFork.ttl, i}) + go t.send(ttl) } // 一组TTL全部退出(收到应答或者超时终止)以后,再进行下一个TTL的包发送 t.wg.Wait() if t.RealtimePrinter != nil { - t.RealtimePrinter(&t.res, t.workFork.ttl-1) + t.RealtimePrinter(&t.res, ttl-1) } - t.workFork.num = 0 } t.res.reduce(t.final) @@ -103,29 +93,15 @@ func (t *ICMPTracer) listenICMP() { } func (t *ICMPTracer) handleICMPMessage(msg ReceivedMessage, icmpType int8, data []byte) { - - t.inflightRequestLock.Lock() - defer t.inflightRequestLock.Unlock() - ch, ok := t.inflightRequest[t.workFork.num] - t.workFork.num += 1 - if !ok { - return - } - ch <- Hop{ + t.resCh <- Hop{ Success: true, Address: msg.Peer, } } -func (t *ICMPTracer) send(fork workFork) error { - err := t.sem.Acquire(context.Background(), 1) - if err != nil { - return err - } - defer t.sem.Release(1) - +func (t *ICMPTracer) send(ttl int) error { defer t.wg.Done() - if t.final != -1 && fork.ttl > t.final { + if t.final != -1 && ttl > t.final { return nil } @@ -137,7 +113,7 @@ func (t *ICMPTracer) send(fork workFork) error { }, } - ipv4.NewPacketConn(t.icmpListen).SetTTL(fork.ttl) + ipv4.NewPacketConn(t.icmpListen).SetTTL(ttl) wb, err := icmpHeader.Marshal(nil) if err != nil { @@ -151,41 +127,30 @@ func (t *ICMPTracer) send(fork workFork) error { if err := t.icmpListen.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil { log.Fatal(err) } - t.inflightRequestLock.Lock() - hopCh := make(chan Hop) - t.inflightRequest[fork.num] = hopCh - t.inflightRequestLock.Unlock() - - // defer func() { - // t.inflightRequestLock.Lock() - // close(hopCh) - // delete(t.inflightRequest, fork.ttl) - // t.inflightRequestLock.Unlock() - // }() select { case <-t.ctx.Done(): return nil - case h := <-hopCh: + case h := <-t.resCh: rtt := time.Since(start) - if t.final != -1 && fork.ttl > t.final { + if t.final != -1 && ttl > t.final { return nil } if addr, ok := h.Address.(*net.IPAddr); ok && addr.IP.Equal(t.DestIP) { t.finalLock.Lock() - if t.final == -1 || fork.ttl < t.final { - t.final = fork.ttl + if t.final == -1 || ttl < t.final { + t.final = ttl } t.finalLock.Unlock() } else if addr, ok := h.Address.(*net.TCPAddr); ok && addr.IP.Equal(t.DestIP) { t.finalLock.Lock() - if t.final == -1 || fork.ttl < t.final { - t.final = fork.ttl + if t.final == -1 || ttl < t.final { + t.final = ttl } t.finalLock.Unlock() } - h.TTL = fork.ttl + h.TTL = ttl h.RTT = rtt h.fetchIPData(t.Config) @@ -193,14 +158,14 @@ func (t *ICMPTracer) send(fork workFork) error { t.res.add(h) case <-time.After(t.Timeout): - if t.final != -1 && fork.ttl > t.final { + if t.final != -1 && ttl > t.final { return nil } t.res.add(Hop{ Success: false, Address: nil, - TTL: fork.ttl, + TTL: ttl, RTT: 0, Error: ErrHopLimitTimeout, }) diff --git a/trace/icmp_ipv6.go b/trace/icmp_ipv6.go index 0f365777..d6849d44 100644 --- a/trace/icmp_ipv6.go +++ b/trace/icmp_ipv6.go @@ -10,22 +10,17 @@ import ( "golang.org/x/net/context" "golang.org/x/net/icmp" "golang.org/x/net/ipv6" - "golang.org/x/sync/semaphore" ) type ICMPTracerv6 struct { Config - wg sync.WaitGroup - res Result - ctx context.Context - inflightRequest map[int]chan Hop - inflightRequestLock sync.Mutex - icmpListen net.PacketConn - workFork workFork - final int - finalLock sync.Mutex - - sem *semaphore.Weighted + wg sync.WaitGroup + res Result + ctx context.Context + resCh chan Hop + icmpListen net.PacketConn + final int + finalLock sync.Mutex } func (t *ICMPTracerv6) Execute() (*Result, error) { @@ -44,21 +39,24 @@ func (t *ICMPTracerv6) Execute() (*Result, error) { var cancel context.CancelFunc t.ctx, cancel = context.WithCancel(context.Background()) defer cancel() - t.inflightRequest = make(map[int]chan Hop) + t.resCh = make(chan Hop) t.final = -1 go t.listenICMP() - t.sem = semaphore.NewWeighted(int64(t.ParallelRequests)) - - for t.workFork.ttl = 1; t.workFork.ttl <= t.MaxHops; t.workFork.ttl++ { + for ttl := 1; ttl <= t.MaxHops; ttl++ { + if t.final != -1 && ttl > t.final { + break + } for i := 0; i < t.NumMeasurements; i++ { t.wg.Add(1) - go t.send(workFork{t.workFork.ttl, i}) + go t.send(ttl) } // 一组TTL全部退出(收到应答或者超时终止)以后,再进行下一个TTL的包发送 t.wg.Wait() - t.workFork.num = 0 + if t.RealtimePrinter != nil { + t.RealtimePrinter(&t.res, ttl-1) + } } t.res.reduce(t.final) @@ -96,29 +94,15 @@ func (t *ICMPTracerv6) listenICMP() { } func (t *ICMPTracerv6) handleICMPMessage(msg ReceivedMessage, icmpType int8, data []byte) { - t.inflightRequestLock.Lock() - defer t.inflightRequestLock.Unlock() - ch, ok := t.inflightRequest[t.workFork.num] - t.workFork.num += 1 - if !ok { - return - } - ch <- Hop{ + t.resCh <- Hop{ Success: true, Address: msg.Peer, } } -func (t *ICMPTracerv6) send(fork workFork) error { - err := t.sem.Acquire(context.Background(), 1) - if err != nil { - return err - } - - defer t.sem.Release(1) - +func (t *ICMPTracerv6) send(ttl int) error { defer t.wg.Done() - if t.final != -1 && fork.ttl > t.final { + if t.final != -1 && ttl > t.final { return nil } @@ -132,8 +116,8 @@ func (t *ICMPTracerv6) send(fork workFork) error { p := ipv6.NewPacketConn(t.icmpListen) - icmpHeader.Body.(*icmp.Echo).Seq = fork.ttl - p.SetHopLimit(fork.ttl) + icmpHeader.Body.(*icmp.Echo).Seq = ttl + p.SetHopLimit(ttl) wb, err := icmpHeader.Marshal(nil) if err != nil { @@ -147,41 +131,30 @@ func (t *ICMPTracerv6) send(fork workFork) error { if err := t.icmpListen.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil { log.Fatal(err) } - t.inflightRequestLock.Lock() - hopCh := make(chan Hop) - t.inflightRequest[fork.num] = hopCh - t.inflightRequestLock.Unlock() - - // defer func() { - // t.inflightRequestLock.Lock() - // close(hopCh) - // delete(t.inflightRequest, fork.ttl) - // t.inflightRequestLock.Unlock() - // }() select { case <-t.ctx.Done(): return nil - case h := <-hopCh: + case h := <-t.resCh: rtt := time.Since(start) - if t.final != -1 && fork.ttl > t.final { + if t.final != -1 && ttl > t.final { return nil } if addr, ok := h.Address.(*net.IPAddr); ok && addr.IP.Equal(t.DestIP) { t.finalLock.Lock() - if t.final == -1 || fork.ttl < t.final { - t.final = fork.ttl + if t.final == -1 || ttl < t.final { + t.final = ttl } t.finalLock.Unlock() } else if addr, ok := h.Address.(*net.TCPAddr); ok && addr.IP.Equal(t.DestIP) { t.finalLock.Lock() - if t.final == -1 || fork.ttl < t.final { - t.final = fork.ttl + if t.final == -1 || ttl < t.final { + t.final = ttl } t.finalLock.Unlock() } - h.TTL = fork.ttl + h.TTL = ttl h.RTT = rtt h.fetchIPData(t.Config) @@ -189,14 +162,14 @@ func (t *ICMPTracerv6) send(fork workFork) error { t.res.add(h) case <-time.After(t.Timeout): - if t.final != -1 && fork.ttl > t.final { + if t.final != -1 && ttl > t.final { return nil } t.res.add(Hop{ Success: false, Address: nil, - TTL: fork.ttl, + TTL: ttl, RTT: 0, Error: ErrHopLimitTimeout, })