Skip to content

Commit

Permalink
ipscanner: slim down scanner and use ctx everywhere
Browse files Browse the repository at this point in the history
Signed-off-by: Mark Pashmfouroush <[email protected]>
  • Loading branch information
markpash committed Aug 3, 2024
1 parent c02b99a commit 6c98f7a
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 66 deletions.
54 changes: 25 additions & 29 deletions ipscanner/internal/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package engine

import (
"context"
"errors"
"log/slog"
"net/netip"
"time"

"github.com/bepass-org/warp-plus/ipscanner/internal/iterator"
"github.com/bepass-org/warp-plus/ipscanner/internal/ping"
Expand All @@ -14,7 +14,7 @@ import (
type Engine struct {
generator *iterator.IpGenerator
ipQueue *IPQueue
ping func(netip.Addr) (statute.IPInfo, error)
ping func(context.Context, netip.Addr) (statute.IPInfo, error)
log *slog.Logger
}

Expand All @@ -28,7 +28,7 @@ func NewScannerEngine(opts *statute.ScannerOptions) *Engine {
ipQueue: queue,
ping: p.DoPing,
generator: iterator.NewIterator(opts),
log: opts.Logger.With(slog.String("subsystem", "scanner/engine")),
log: opts.Logger,
}
}

Expand All @@ -40,37 +40,33 @@ func (e *Engine) GetAvailableIPs(desc bool) []statute.IPInfo {
}

func (e *Engine) Run(ctx context.Context) {
for {
select {
case <-ctx.Done():
e.ipQueue.Init()

select {
case <-ctx.Done():
return
case <-e.ipQueue.available:
e.log.Debug("Started new scanning round")
batch, err := e.generator.NextBatch()
if err != nil {
e.log.Error("Error while generating IP: %v", err)
return
case <-e.ipQueue.available:
e.log.Debug("Started new scanning round")
batch, err := e.generator.NextBatch()
if err != nil {
e.log.Error("Error while generating IP: %v", err)
// in case of disastrous error, to prevent resource draining wait for 2 seconds and try again
time.Sleep(2 * time.Second)
continue
}
for _, ip := range batch {
select {
case <-ctx.Done():
return
default:
e.log.Debug("pinging IP", "addr", ip)
if ipInfo, err := e.ping(ip); err == nil {
e.log.Debug("ping success", "addr", ipInfo.AddrPort, "rtt", ipInfo.RTT)
e.ipQueue.Enqueue(ipInfo)
} else {
}
for _, ip := range batch {
select {
case <-ctx.Done():
return
default:
ipInfo, err := e.ping(ctx, ip)
if err != nil {
if !errors.Is(err, context.Canceled) {
e.log.Error("ping error", "addr", ip, "error", err)
}
continue
}
e.log.Debug("ping success", "addr", ipInfo.AddrPort, "rtt", ipInfo.RTT)
e.ipQueue.Enqueue(ipInfo)
}
default:
e.log.Debug("calling expire")
e.ipQueue.Expire()
time.Sleep(200 * time.Millisecond)
}
}
}
10 changes: 7 additions & 3 deletions ipscanner/internal/engine/queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func NewIPQueue(opts *statute.ScannerOptions) *IPQueue {
maxTTL: opts.IPQueueTTL,
rttThreshold: opts.MaxDesirableRTT,
available: make(chan struct{}, opts.IPQueueSize),
log: opts.Logger.With(slog.String("subsystem", "engine/queue")),
log: opts.Logger,
reserved: reserved,
}
}
Expand Down Expand Up @@ -122,15 +122,19 @@ func (q *IPQueue) Dequeue() (statute.IPInfo, bool) {
return info, true
}

func (q *IPQueue) Expire() {
func (q *IPQueue) Init() {
q.mu.Lock()
defer q.mu.Unlock()

if !q.inIdealMode {
q.log.Debug("Expire: Not in ideal mode")
q.available <- struct{}{}
return
}
}

func (q *IPQueue) Expire() {
q.mu.Lock()
defer q.mu.Unlock()

q.log.Debug("Expire: In ideal mode")
defer func() {
Expand Down
36 changes: 19 additions & 17 deletions ipscanner/internal/ping/ping.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ping

import (
"context"
"errors"
"fmt"
"net/netip"
Expand All @@ -13,41 +14,41 @@ type Ping struct {
}

// DoPing performs a ping on the given IP address.
func (p *Ping) DoPing(ip netip.Addr) (statute.IPInfo, error) {
func (p *Ping) DoPing(ctx context.Context, ip netip.Addr) (statute.IPInfo, error) {
if p.Options.SelectedOps&statute.HTTPPing > 0 {
res, err := p.httpPing(ip)
res, err := p.httpPing(ctx, ip)
if err != nil {
return statute.IPInfo{}, err
}

return res, nil
}
if p.Options.SelectedOps&statute.TLSPing > 0 {
res, err := p.tlsPing(ip)
res, err := p.tlsPing(ctx, ip)
if err != nil {
return statute.IPInfo{}, err
}

return res, nil
}
if p.Options.SelectedOps&statute.TCPPing > 0 {
res, err := p.tcpPing(ip)
res, err := p.tcpPing(ctx, ip)
if err != nil {
return statute.IPInfo{}, err
}

return res, nil
}
if p.Options.SelectedOps&statute.QUICPing > 0 {
res, err := p.quicPing(ip)
res, err := p.quicPing(ctx, ip)
if err != nil {
return statute.IPInfo{}, err
}

return res, nil
}
if p.Options.SelectedOps&statute.WARPPing > 0 {
res, err := p.warpPing(ip)
res, err := p.warpPing(ctx, ip)
if err != nil {
return statute.IPInfo{}, err
}
Expand All @@ -58,8 +59,9 @@ func (p *Ping) DoPing(ip netip.Addr) (statute.IPInfo, error) {
return statute.IPInfo{}, errors.New("no ping operation selected")
}

func (p *Ping) httpPing(ip netip.Addr) (statute.IPInfo, error) {
func (p *Ping) httpPing(ctx context.Context, ip netip.Addr) (statute.IPInfo, error) {
return p.calc(
ctx,
NewHttpPing(
ip,
"GET",
Expand All @@ -74,30 +76,30 @@ func (p *Ping) httpPing(ip netip.Addr) (statute.IPInfo, error) {
)
}

func (p *Ping) warpPing(ip netip.Addr) (statute.IPInfo, error) {
return p.calc(NewWarpPing(ip, p.Options))
func (p *Ping) warpPing(ctx context.Context, ip netip.Addr) (statute.IPInfo, error) {
return p.calc(ctx, NewWarpPing(ip, p.Options))
}

func (p *Ping) tlsPing(ip netip.Addr) (statute.IPInfo, error) {
return p.calc(
func (p *Ping) tlsPing(ctx context.Context, ip netip.Addr) (statute.IPInfo, error) {
return p.calc(ctx,
NewTlsPing(ip, p.Options.Hostname, p.Options.Port, p.Options),
)
}

func (p *Ping) tcpPing(ip netip.Addr) (statute.IPInfo, error) {
return p.calc(
func (p *Ping) tcpPing(ctx context.Context, ip netip.Addr) (statute.IPInfo, error) {
return p.calc(ctx,
NewTcpPing(ip, p.Options.Hostname, p.Options.Port, p.Options),
)
}

func (p *Ping) quicPing(ip netip.Addr) (statute.IPInfo, error) {
return p.calc(
func (p *Ping) quicPing(ctx context.Context, ip netip.Addr) (statute.IPInfo, error) {
return p.calc(ctx,
NewQuicPing(ip, p.Options.Hostname, p.Options.Port, p.Options),
)
}

func (p *Ping) calc(tp statute.IPing) (statute.IPInfo, error) {
pr := tp.Ping()
func (p *Ping) calc(ctx context.Context, tp statute.IPing) (statute.IPInfo, error) {
pr := tp.PingContext(ctx)
err := pr.Error()
if err != nil {
return statute.IPInfo{}, err
Expand Down
46 changes: 29 additions & 17 deletions ipscanner/internal/ping/warp.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,10 @@ func (h *WarpPing) Ping() statute.IPingResult {
return h.PingContext(context.Background())
}

func (h *WarpPing) PingContext(_ context.Context) statute.IPingResult {
func (h *WarpPing) PingContext(ctx context.Context) statute.IPingResult {
addr := netip.AddrPortFrom(h.IP, warp.RandomWarpPort())
rtt, err := initiateHandshake(
ctx,
addr,
h.PrivateKey,
h.PeerPublicKey,
Expand Down Expand Up @@ -117,15 +118,21 @@ func ephemeralKeypair() (noise.DHKey, error) {
}, nil
}

func randomInt(min, max int) int {
nBig, err := rand.Int(rand.Reader, big.NewInt(int64(max-min+1)))
func randomInt(min, max uint64) uint64 {
rangee := max - min
if rangee < 1 {
return 0
}

n, err := rand.Int(rand.Reader, big.NewInt(int64(rangee)))
if err != nil {
panic(err)
}
return int(nBig.Int64()) + min

return min + n.Uint64()
}

func initiateHandshake(serverAddr netip.AddrPort, privateKeyBase64, peerPublicKeyBase64, presharedKeyBase64 string) (time.Duration, error) {
func initiateHandshake(ctx context.Context, serverAddr netip.AddrPort, privateKeyBase64, peerPublicKeyBase64, presharedKeyBase64 string) (time.Duration, error) {
staticKeyPair, err := staticKeypair(privateKeyBase64)
if err != nil {
return 0, err
Expand Down Expand Up @@ -209,19 +216,24 @@ func initiateHandshake(serverAddr netip.AddrPort, privateKeyBase64, peerPublicKe

numPackets := randomInt(8, 15)
randomPacket := make([]byte, 100)
for i := 0; i < numPackets; i++ {
packetSize := randomInt(40, 100)
_, err := rand.Read(randomPacket[:packetSize])
if err != nil {
return 0, fmt.Errorf("error generating random packet: %w", err)
for i := uint64(0); i < numPackets; i++ {
select {
case <-ctx.Done():
return 0, ctx.Err()
default:
packetSize := randomInt(40, 100)
_, err := rand.Read(randomPacket[:packetSize])
if err != nil {
return 0, fmt.Errorf("error generating random packet: %w", err)
}

_, err = conn.Write(randomPacket[:packetSize])
if err != nil {
return 0, fmt.Errorf("error sending random packet: %w", err)
}

time.Sleep(time.Duration(randomInt(20, 250)) * time.Millisecond)
}

_, err = conn.Write(randomPacket[:packetSize])
if err != nil {
return 0, fmt.Errorf("error sending random packet: %w", err)
}

time.Sleep(time.Duration(randomInt(20, 250)) * time.Millisecond)
}

_, err = initiationPacket.WriteTo(conn)
Expand Down

0 comments on commit 6c98f7a

Please sign in to comment.