Skip to content

Commit

Permalink
Merge pull request #123 from apernet/wip-lookup
Browse files Browse the repository at this point in the history
feat: dns lookup function
  • Loading branch information
tobyxdd authored Apr 8, 2024
2 parents d7737e9 + 9c0893c commit 393c29b
Show file tree
Hide file tree
Showing 10 changed files with 181 additions and 110 deletions.
16 changes: 6 additions & 10 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ func (c *cliConfig) fillIO(config *engine.Config) error {
if err != nil {
return configError{Field: "io", Err: err}
}
config.IOs = []io.PacketIO{nfio}
config.IO = nfio
return nil
}

Expand Down Expand Up @@ -247,22 +247,18 @@ func runMain(cmd *cobra.Command, args []string) {
if err != nil {
logger.Fatal("failed to parse config", zap.Error(err))
}
defer func() {
// Make sure to close all IOs on exit
for _, i := range engineConfig.IOs {
_ = i.Close()
}
}()
defer engineConfig.IO.Close() // Make sure to close IO on exit

// Ruleset
rawRs, err := ruleset.ExprRulesFromYAML(args[0])
if err != nil {
logger.Fatal("failed to load rules", zap.Error(err))
}
rsConfig := &ruleset.BuiltinConfig{
Logger: &rulesetLogger{},
GeoSiteFilename: config.Ruleset.GeoSite,
GeoIpFilename: config.Ruleset.GeoIp,
Logger: &rulesetLogger{},
GeoSiteFilename: config.Ruleset.GeoSite,
GeoIpFilename: config.Ruleset.GeoIp,
ProtectedDialContext: engineConfig.IO.ProtectedDialContext,
}
rs, err := ruleset.CompileExprRules(rawRs, analyzers, modifiers, rsConfig)
if err != nil {
Expand Down
34 changes: 15 additions & 19 deletions engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ var _ Engine = (*engine)(nil)

type engine struct {
logger Logger
ioList []io.PacketIO
io io.PacketIO
workers []*worker
}

Expand All @@ -42,7 +42,7 @@ func NewEngine(config Config) (Engine, error) {
}
return &engine{
logger: config.Logger,
ioList: config.IOs,
io: config.IO,
workers: workers,
}, nil
}
Expand All @@ -58,27 +58,24 @@ func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {

func (e *engine) Run(ctx context.Context) error {
ioCtx, ioCancel := context.WithCancel(ctx)
defer ioCancel() // Stop workers & IOs
defer ioCancel() // Stop workers & IO

// Start workers
for _, w := range e.workers {
go w.Run(ioCtx)
}

// Register callbacks
errChan := make(chan error, len(e.ioList))
for _, i := range e.ioList {
ioEntry := i // Make sure dispatch() uses the correct ioEntry
err := ioEntry.Register(ioCtx, func(p io.Packet, err error) bool {
if err != nil {
errChan <- err
return false
}
return e.dispatch(ioEntry, p)
})
// Register IO callback
errChan := make(chan error, 1)
err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
if err != nil {
return err
errChan <- err
return false
}
return e.dispatch(p)
})
if err != nil {
return err
}

// Block until IO errors or context is cancelled
Expand All @@ -91,8 +88,7 @@ func (e *engine) Run(ctx context.Context) error {
}

// dispatch dispatches a packet to a worker.
// This must be safe for concurrent use, as it may be called from multiple IOs.
func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool {
func (e *engine) dispatch(p io.Packet) bool {
data := p.Data()
ipVersion := data[0] >> 4
var layerType gopacket.LayerType
Expand All @@ -102,7 +98,7 @@ func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool {
layerType = layers.LayerTypeIPv6
} else {
// Unsupported network layer
_ = ioEntry.SetVerdict(p, io.VerdictAcceptStream, nil)
_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
return true
}
// Load balance by stream ID
Expand All @@ -112,7 +108,7 @@ func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool {
StreamID: p.StreamID(),
Packet: packet,
SetVerdict: func(v io.Verdict, b []byte) error {
return ioEntry.SetVerdict(p, v, b)
return e.io.SetVerdict(p, v, b)
},
})
return true
Expand Down
2 changes: 1 addition & 1 deletion engine/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type Engine interface {
// Config is the configuration for the engine.
type Config struct {
Logger Logger
IOs []io.PacketIO
IO io.PacketIO
Ruleset ruleset.Ruleset

Workers int // Number of workers. Zero or negative means auto (number of CPU cores).
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ go 1.21
require (
github.com/bwmarrin/snowflake v0.3.0
github.com/coreos/go-iptables v0.7.0
github.com/expr-lang/expr v1.15.7
github.com/expr-lang/expr v1.16.3
github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf
github.com/google/gopacket v1.1.20-0.20220810144506-32ee38206866
github.com/hashicorp/golang-lru/v2 v2.0.7
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/expr-lang/expr v1.15.7 h1:BK0JcWUkoW6nrbLBo6xCKhz4BvH5DSOOu1Gx5lucyZo=
github.com/expr-lang/expr v1.15.7/go.mod h1:uCkhfG+x7fcZ5A5sXHKuQ07jGZRl6J0FCAaf2k4PtVQ=
github.com/expr-lang/expr v1.16.3 h1:NLldf786GffptcXNxxJx5dQ+FzeWDKChBDqOOwyK8to=
github.com/expr-lang/expr v1.16.3/go.mod h1:uCkhfG+x7fcZ5A5sXHKuQ07jGZRl6J0FCAaf2k4PtVQ=
github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf h1:NqGS3vTHzVENbIfd87cXZwdpO6MB2R1PjHMJLi4Z3ow=
github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf/go.mod h1:eSnAor2YCfMCVYrVNEhkLGN/r1L+J4uDjc0EUy0tfq4=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
Expand Down
6 changes: 5 additions & 1 deletion io/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package io

import (
"context"
"net"
)

type Verdict int
Expand Down Expand Up @@ -29,7 +30,6 @@ type Packet interface {

// PacketCallback is called for each packet received.
// Return false to "unregister" and stop receiving packets.
// It must be safe for concurrent use.
type PacketCallback func(Packet, error) bool

type PacketIO interface {
Expand All @@ -39,6 +39,10 @@ type PacketIO interface {
Register(context.Context, PacketCallback) error
// SetVerdict sets the verdict for a packet.
SetVerdict(Packet, Verdict, []byte) error
// ProtectedDialContext is like net.DialContext, but the connection is "protected"
// in the sense that the packets sent/received through the connection must bypass
// the packet IO and not be processed by the callback.
ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error)
// Close closes the packet IO.
Close() error
}
Expand Down
23 changes: 23 additions & 0 deletions io/nfqueue.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ import (
"encoding/binary"
"errors"
"fmt"
"net"
"os/exec"
"strconv"
"strings"
"syscall"

"github.com/coreos/go-iptables/iptables"
"github.com/florianl/go-nfqueue"
Expand Down Expand Up @@ -50,6 +52,7 @@ func generateNftRules(local, rst bool) (*nftTableSpec, error) {
}
for i := range table.Chains {
c := &table.Chains[i]
c.Rules = append(c.Rules, "meta mark $ACCEPT_CTMARK ct mark set $ACCEPT_CTMARK") // Bypass protected connections
c.Rules = append(c.Rules, "ct mark $ACCEPT_CTMARK counter accept")
if rst {
c.Rules = append(c.Rules, "ip protocol tcp ct mark $DROP_CTMARK counter reject with tcp reset")
Expand All @@ -72,6 +75,8 @@ func generateIptRules(local, rst bool) ([]iptRule, error) {
}
rules := make([]iptRule, 0, 4*len(chains))
for _, chain := range chains {
// Bypass protected connections
rules = append(rules, iptRule{"filter", chain, []string{"-m", "mark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "CONNMARK", "--set-mark", strconv.Itoa(nfqueueConnMarkAccept)}})
rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}})
if rst {
rules = append(rules, iptRule{"filter", chain, []string{"-p", "tcp", "-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "REJECT", "--reject-with", "tcp-reset"}})
Expand All @@ -96,6 +101,8 @@ type nfqueuePacketIO struct {
// iptables not nil = use iptables instead of nftables
ipt4 *iptables.IPTables
ipt6 *iptables.IPTables

protectedDialer *net.Dialer
}

type NFQueuePacketIOConfig struct {
Expand Down Expand Up @@ -153,6 +160,18 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
rst: config.RST,
ipt4: ipt4,
ipt6: ipt6,
protectedDialer: &net.Dialer{
Control: func(network, address string, c syscall.RawConn) error {
var err error
cErr := c.Control(func(fd uintptr) {
err = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, nfqueueConnMarkAccept)
})
if cErr != nil {
return cErr
}
return err
},
},
}, nil
}

Expand Down Expand Up @@ -239,6 +258,10 @@ func (n *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) erro
}
}

func (n *nfqueuePacketIO) ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error) {
return n.protectedDialer.DialContext(ctx, network, address)
}

func (n *nfqueuePacketIO) Close() error {
if n.rSet {
if n.ipt4 != nil {
Expand Down
8 changes: 3 additions & 5 deletions ruleset/builtins/geo/geo_matcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,12 @@ type GeoMatcher struct {
ipMatcherLock sync.Mutex
}

func NewGeoMatcher(geoSiteFilename, geoIpFilename string) (*GeoMatcher, error) {
geoLoader := NewDefaultGeoLoader(geoSiteFilename, geoIpFilename)

func NewGeoMatcher(geoSiteFilename, geoIpFilename string) *GeoMatcher {
return &GeoMatcher{
geoLoader: geoLoader,
geoLoader: NewDefaultGeoLoader(geoSiteFilename, geoIpFilename),
geoSiteMatcher: make(map[string]hostMatcher),
geoIpMatcher: make(map[string]hostMatcher),
}, nil
}
}

func (g *GeoMatcher) MatchGeoIp(ip, condition string) bool {
Expand Down
Loading

0 comments on commit 393c29b

Please sign in to comment.