diff --git a/ipscanner/internal/engine/engine.go b/ipscanner/internal/engine/engine.go index 3af6ab600..e75911290 100644 --- a/ipscanner/internal/engine/engine.go +++ b/ipscanner/internal/engine/engine.go @@ -2,9 +2,9 @@ package engine import ( "context" + "errors" "log/slog" "net/netip" - "time" "github.com/bepass-org/warp-plus/ipscanner/internal/iterator" "github.com/bepass-org/warp-plus/ipscanner/internal/ping" @@ -14,7 +14,7 @@ import ( type Engine struct { generator *iterator.IpGenerator ipQueue *IPQueue - ping func(netip.Addr) (statute.IPInfo, error) + ping func(context.Context, netip.Addr) (statute.IPInfo, error) log *slog.Logger } @@ -28,7 +28,7 @@ func NewScannerEngine(opts *statute.ScannerOptions) *Engine { ipQueue: queue, ping: p.DoPing, generator: iterator.NewIterator(opts), - log: opts.Logger.With(slog.String("subsystem", "scanner/engine")), + log: opts.Logger, } } @@ -40,37 +40,33 @@ func (e *Engine) GetAvailableIPs(desc bool) []statute.IPInfo { } func (e *Engine) Run(ctx context.Context) { - for { - select { - case <-ctx.Done(): + e.ipQueue.Init() + + select { + case <-ctx.Done(): + return + case <-e.ipQueue.available: + e.log.Debug("Started new scanning round") + batch, err := e.generator.NextBatch() + if err != nil { + e.log.Error("Error while generating IP: %v", err) return - case <-e.ipQueue.available: - e.log.Debug("Started new scanning round") - batch, err := e.generator.NextBatch() - if err != nil { - e.log.Error("Error while generating IP: %v", err) - // in case of disastrous error, to prevent resource draining wait for 2 seconds and try again - time.Sleep(2 * time.Second) - continue - } - for _, ip := range batch { - select { - case <-ctx.Done(): - return - default: - e.log.Debug("pinging IP", "addr", ip) - if ipInfo, err := e.ping(ip); err == nil { - e.log.Debug("ping success", "addr", ipInfo.AddrPort, "rtt", ipInfo.RTT) - e.ipQueue.Enqueue(ipInfo) - } else { + } + for _, ip := range batch { + select { + case <-ctx.Done(): + return + default: + ipInfo, err := e.ping(ctx, ip) + if err != nil { + if !errors.Is(err, context.Canceled) { e.log.Error("ping error", "addr", ip, "error", err) } + continue } + e.log.Debug("ping success", "addr", ipInfo.AddrPort, "rtt", ipInfo.RTT) + e.ipQueue.Enqueue(ipInfo) } - default: - e.log.Debug("calling expire") - e.ipQueue.Expire() - time.Sleep(200 * time.Millisecond) } } } diff --git a/ipscanner/internal/engine/queue.go b/ipscanner/internal/engine/queue.go index ec3a66186..88c918d2d 100644 --- a/ipscanner/internal/engine/queue.go +++ b/ipscanner/internal/engine/queue.go @@ -29,7 +29,7 @@ func NewIPQueue(opts *statute.ScannerOptions) *IPQueue { maxTTL: opts.IPQueueTTL, rttThreshold: opts.MaxDesirableRTT, available: make(chan struct{}, opts.IPQueueSize), - log: opts.Logger.With(slog.String("subsystem", "engine/queue")), + log: opts.Logger, reserved: reserved, } } @@ -122,15 +122,19 @@ func (q *IPQueue) Dequeue() (statute.IPInfo, bool) { return info, true } -func (q *IPQueue) Expire() { +func (q *IPQueue) Init() { q.mu.Lock() defer q.mu.Unlock() if !q.inIdealMode { - q.log.Debug("Expire: Not in ideal mode") q.available <- struct{}{} return } +} + +func (q *IPQueue) Expire() { + q.mu.Lock() + defer q.mu.Unlock() q.log.Debug("Expire: In ideal mode") defer func() { diff --git a/ipscanner/internal/ping/ping.go b/ipscanner/internal/ping/ping.go index 32865d0f0..65fde8475 100644 --- a/ipscanner/internal/ping/ping.go +++ b/ipscanner/internal/ping/ping.go @@ -1,6 +1,7 @@ package ping import ( + "context" "errors" "fmt" "net/netip" @@ -13,9 +14,9 @@ type Ping struct { } // DoPing performs a ping on the given IP address. -func (p *Ping) DoPing(ip netip.Addr) (statute.IPInfo, error) { +func (p *Ping) DoPing(ctx context.Context, ip netip.Addr) (statute.IPInfo, error) { if p.Options.SelectedOps&statute.HTTPPing > 0 { - res, err := p.httpPing(ip) + res, err := p.httpPing(ctx, ip) if err != nil { return statute.IPInfo{}, err } @@ -23,7 +24,7 @@ func (p *Ping) DoPing(ip netip.Addr) (statute.IPInfo, error) { return res, nil } if p.Options.SelectedOps&statute.TLSPing > 0 { - res, err := p.tlsPing(ip) + res, err := p.tlsPing(ctx, ip) if err != nil { return statute.IPInfo{}, err } @@ -31,7 +32,7 @@ func (p *Ping) DoPing(ip netip.Addr) (statute.IPInfo, error) { return res, nil } if p.Options.SelectedOps&statute.TCPPing > 0 { - res, err := p.tcpPing(ip) + res, err := p.tcpPing(ctx, ip) if err != nil { return statute.IPInfo{}, err } @@ -39,7 +40,7 @@ func (p *Ping) DoPing(ip netip.Addr) (statute.IPInfo, error) { return res, nil } if p.Options.SelectedOps&statute.QUICPing > 0 { - res, err := p.quicPing(ip) + res, err := p.quicPing(ctx, ip) if err != nil { return statute.IPInfo{}, err } @@ -47,7 +48,7 @@ func (p *Ping) DoPing(ip netip.Addr) (statute.IPInfo, error) { return res, nil } if p.Options.SelectedOps&statute.WARPPing > 0 { - res, err := p.warpPing(ip) + res, err := p.warpPing(ctx, ip) if err != nil { return statute.IPInfo{}, err } @@ -58,8 +59,9 @@ func (p *Ping) DoPing(ip netip.Addr) (statute.IPInfo, error) { return statute.IPInfo{}, errors.New("no ping operation selected") } -func (p *Ping) httpPing(ip netip.Addr) (statute.IPInfo, error) { +func (p *Ping) httpPing(ctx context.Context, ip netip.Addr) (statute.IPInfo, error) { return p.calc( + ctx, NewHttpPing( ip, "GET", @@ -74,30 +76,30 @@ func (p *Ping) httpPing(ip netip.Addr) (statute.IPInfo, error) { ) } -func (p *Ping) warpPing(ip netip.Addr) (statute.IPInfo, error) { - return p.calc(NewWarpPing(ip, p.Options)) +func (p *Ping) warpPing(ctx context.Context, ip netip.Addr) (statute.IPInfo, error) { + return p.calc(ctx, NewWarpPing(ip, p.Options)) } -func (p *Ping) tlsPing(ip netip.Addr) (statute.IPInfo, error) { - return p.calc( +func (p *Ping) tlsPing(ctx context.Context, ip netip.Addr) (statute.IPInfo, error) { + return p.calc(ctx, NewTlsPing(ip, p.Options.Hostname, p.Options.Port, p.Options), ) } -func (p *Ping) tcpPing(ip netip.Addr) (statute.IPInfo, error) { - return p.calc( +func (p *Ping) tcpPing(ctx context.Context, ip netip.Addr) (statute.IPInfo, error) { + return p.calc(ctx, NewTcpPing(ip, p.Options.Hostname, p.Options.Port, p.Options), ) } -func (p *Ping) quicPing(ip netip.Addr) (statute.IPInfo, error) { - return p.calc( +func (p *Ping) quicPing(ctx context.Context, ip netip.Addr) (statute.IPInfo, error) { + return p.calc(ctx, NewQuicPing(ip, p.Options.Hostname, p.Options.Port, p.Options), ) } -func (p *Ping) calc(tp statute.IPing) (statute.IPInfo, error) { - pr := tp.Ping() +func (p *Ping) calc(ctx context.Context, tp statute.IPing) (statute.IPInfo, error) { + pr := tp.PingContext(ctx) err := pr.Error() if err != nil { return statute.IPInfo{}, err diff --git a/ipscanner/internal/ping/warp.go b/ipscanner/internal/ping/warp.go index a513497de..dc135f3d0 100644 --- a/ipscanner/internal/ping/warp.go +++ b/ipscanner/internal/ping/warp.go @@ -55,9 +55,10 @@ func (h *WarpPing) Ping() statute.IPingResult { return h.PingContext(context.Background()) } -func (h *WarpPing) PingContext(_ context.Context) statute.IPingResult { +func (h *WarpPing) PingContext(ctx context.Context) statute.IPingResult { addr := netip.AddrPortFrom(h.IP, warp.RandomWarpPort()) rtt, err := initiateHandshake( + ctx, addr, h.PrivateKey, h.PeerPublicKey, @@ -117,15 +118,21 @@ func ephemeralKeypair() (noise.DHKey, error) { }, nil } -func randomInt(min, max int) int { - nBig, err := rand.Int(rand.Reader, big.NewInt(int64(max-min+1))) +func randomInt(min, max uint64) uint64 { + rangee := max - min + if rangee < 1 { + return 0 + } + + n, err := rand.Int(rand.Reader, big.NewInt(int64(rangee))) if err != nil { panic(err) } - return int(nBig.Int64()) + min + + return min + n.Uint64() } -func initiateHandshake(serverAddr netip.AddrPort, privateKeyBase64, peerPublicKeyBase64, presharedKeyBase64 string) (time.Duration, error) { +func initiateHandshake(ctx context.Context, serverAddr netip.AddrPort, privateKeyBase64, peerPublicKeyBase64, presharedKeyBase64 string) (time.Duration, error) { staticKeyPair, err := staticKeypair(privateKeyBase64) if err != nil { return 0, err @@ -209,19 +216,24 @@ func initiateHandshake(serverAddr netip.AddrPort, privateKeyBase64, peerPublicKe numPackets := randomInt(8, 15) randomPacket := make([]byte, 100) - for i := 0; i < numPackets; i++ { - packetSize := randomInt(40, 100) - _, err := rand.Read(randomPacket[:packetSize]) - if err != nil { - return 0, fmt.Errorf("error generating random packet: %w", err) + for i := uint64(0); i < numPackets; i++ { + select { + case <-ctx.Done(): + return 0, ctx.Err() + default: + packetSize := randomInt(40, 100) + _, err := rand.Read(randomPacket[:packetSize]) + if err != nil { + return 0, fmt.Errorf("error generating random packet: %w", err) + } + + _, err = conn.Write(randomPacket[:packetSize]) + if err != nil { + return 0, fmt.Errorf("error sending random packet: %w", err) + } + + time.Sleep(time.Duration(randomInt(20, 250)) * time.Millisecond) } - - _, err = conn.Write(randomPacket[:packetSize]) - if err != nil { - return 0, fmt.Errorf("error sending random packet: %w", err) - } - - time.Sleep(time.Duration(randomInt(20, 250)) * time.Millisecond) } _, err = initiationPacket.WriteTo(conn)