diff --git a/cmd/root.go b/cmd/root.go index fc6b742..6fbc025 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -32,5 +32,5 @@ func init() { logrus.SetLevel(logrus.DebugLevel) } } - rootCmd.PersistentFlags().StringVarP(&config, "config", "c", "/etc/wg-quick-op.yaml", "config file path") + rootCmd.PersistentFlags().StringVarP(&config, "config", "c", "/etc/wg-quick-op.toml", "config file path") } diff --git a/conf/config-sample.toml b/conf/config-sample.toml new file mode 100644 index 0000000..9eadfe2 --- /dev/null +++ b/conf/config-sample.toml @@ -0,0 +1,30 @@ +[start_on_boot] +enabled = true +# choose between skip and only, if both skipp and only are empty, all interfaces will be started +# if only_ifaces is not empty, skip_ifaces will be ignored +skip_ifaces = [] +#only_ifaces = [] + +[enhanced_dns.direct_resolver] +# resolve dns from direct NS server +enabled = true +# fetch ROA, config for direct_resolver +roa_finder = "223.5.5.5" + +[ddns] +enabled = true +# ddns check interval +interval = 60 +# when last handshake time is handshake_max seconds before now, treat it as offline +handshake_max = 150 +skip_ifaces = [] +#only_ifaces = [] + +# following configs are not implemented yet +#[openwrt] +#uci_iface = true +#namemap.tuntun = "tun00" +# +#[openwrt.firewall] +#default = 'dn11' +#fwmap.if0 = 'dn22' diff --git a/conf/config-sample.yaml b/conf/config-sample.yaml deleted file mode 100644 index 7664410..0000000 --- a/conf/config-sample.yaml +++ /dev/null @@ -1,14 +0,0 @@ -# up with system, be careful about routing table -enabled: -# - aaa -# - bbb - -# resolve hostname every interval -ddns: - # check every interval - interval: 60 - # when last handshake time is max-last-handshake before now, update interface - max_last_handshake: 150 - iface: -# - ccc -# - ddd diff --git a/conf/config.go b/conf/config.go index cdc8f09..4323849 100644 --- a/conf/config.go +++ b/conf/config.go @@ -8,16 +8,28 @@ import ( "time" ) -//go:embed config-sample.yaml +//go:embed config-sample.toml var configSample []byte var DDNS struct { - Interval time.Duration - Iface []string - MaxLastHandleShake time.Duration + Interval time.Duration + IfaceOnly []string + IfaceSkip []string + HandleShakeMax time.Duration } -var Enabled []string +var StartOnBoot struct { + Enabled bool + IfaceOnly []string + IfaceSkip []string +} + +var EnhancedDNS struct { + DirectResolver struct { + Enabled bool + ROAFinder string + } +} func Init(file string) { if _, err := os.Stat(file); err != nil { @@ -38,7 +50,7 @@ func Init(file string) { err := viper.ReadInConfig() viper.SetDefault("ddns.interval", 60) - viper.SetDefault("ddns.max_last_handshake", 150) + viper.SetDefault("ddns.handshake_max", 150) update() if err != nil { @@ -48,7 +60,14 @@ func Init(file string) { func update() { DDNS.Interval = time.Duration(viper.GetInt("ddns.interval")) * time.Second - DDNS.MaxLastHandleShake = time.Duration(viper.GetInt("ddns.max_last_handshake")) * time.Second - DDNS.Iface = viper.GetStringSlice("ddns.iface") - Enabled = viper.GetStringSlice("enabled") + DDNS.HandleShakeMax = time.Duration(viper.GetInt("ddns.handshake_max")) * time.Second + DDNS.IfaceOnly = viper.GetStringSlice("ddns.iface") + DDNS.IfaceSkip = viper.GetStringSlice("ddns.skip") + + StartOnBoot.Enabled = viper.GetBool("start_on_boot.enabled") + StartOnBoot.IfaceOnly = viper.GetStringSlice("start_on_boot.only_ifaces") + StartOnBoot.IfaceSkip = viper.GetStringSlice("start_on_boot.skip_ifaces") + + EnhancedDNS.DirectResolver.Enabled = viper.GetBool("enhanced_dns.direct_resolver.enabled") + EnhancedDNS.DirectResolver.ROAFinder = viper.GetString("enhanced_dns.direct_resolver.roa_finder") } diff --git a/conf/config_test.go b/conf/config_test.go new file mode 100644 index 0000000..8d6f72e --- /dev/null +++ b/conf/config_test.go @@ -0,0 +1,7 @@ +package conf + +import "testing" + +func TestParseConfig(t *testing.T) { + Init("config-sample.toml") +} diff --git a/daemon/service.go b/daemon/service.go index 75bee2e..b7b9050 100644 --- a/daemon/service.go +++ b/daemon/service.go @@ -3,7 +3,7 @@ package daemon import ( _ "embed" "errors" - "net" + "github.com/hdu-dn11/wg-quick-op/lib/dns" "os" "os/exec" "time" @@ -21,37 +21,13 @@ const ServicePath = "/etc/init.d/wg-quick-op" var ServiceFile []byte func Serve() { - for _, iface := range conf.Enabled { - iface := iface - cfg, err := quick.GetConfig(iface) - if err != nil { - logrus.WithField("iface", iface).WithError(err).Error("failed to get config") - continue - } - go func() { - if err := <-utils.Retry(10, func() error { - err := quick.Up(cfg, iface, logrus.WithField("iface", iface)) - if err == nil { - return nil - } - if errors.Is(err, os.ErrExist) { - logrus.WithField("iface", iface).Infoln("interface already up") - return nil - } - return err - }); err != nil { - logrus.WithField("iface", iface).WithError(err).Error("failed to up interface") - return - } - logrus.Infof("interface %s up", iface) - }() + if conf.StartOnBoot.Enabled { + startOnBoot() } - logrus.Infoln("all interface up") - // prepare config var cfgs []*ddns - for _, iface := range conf.DDNS.Iface { + for _, iface := range conf.DDNS.IfaceOnly { d, err := newDDNS(iface) if err != nil { logrus.WithField("iface", iface).WithError(err).Error("failed to init ddns config") @@ -76,7 +52,7 @@ func Serve() { logrus.WithField("iface", iface.name).WithField("peer", peer.PublicKey).Debugln("peer endpoint is nil, skip it") continue } - if time.Since(peer.LastHandshakeTime) < conf.DDNS.MaxLastHandleShake { + if time.Since(peer.LastHandshakeTime) < conf.DDNS.HandleShakeMax { logrus.WithField("iface", iface.name).WithField("peer", peer.PublicKey).Debugln("peer ok") continue } @@ -85,7 +61,7 @@ func Serve() { if !ok { continue } - addr, err := net.ResolveUDPAddr("", endpoint) + addr, err := dns.ResolveUDPAddr("", endpoint) if err != nil { logrus.WithField("iface", iface).WithField("peer", peer.PublicKey).WithError(err).Error("failed to resolve endpoint") continue @@ -122,6 +98,36 @@ func Serve() { } } +func startOnBoot() { + for _, iface := range utils.FindIface(conf.StartOnBoot.IfaceOnly, conf.StartOnBoot.IfaceSkip) { + iface := iface + cfg, err := quick.GetConfig(iface) + if err != nil { + logrus.WithField("iface", iface).WithError(err).Error("failed to get config") + continue + } + go func() { + if err := <-utils.Retry(5, func() error { + err := quick.Up(cfg, iface, logrus.WithField("iface", iface)) + if err == nil { + return nil + } + if errors.Is(err, os.ErrExist) { + logrus.WithField("iface", iface).Infoln("interface already up") + return nil + } + return err + }); err != nil { + logrus.WithField("iface", iface).WithError(err).Error("failed to up interface") + return + } + logrus.Infof("interface %s up", iface) + }() + } + + logrus.Infoln("all interface up") +} + func AddService() { _, err := exec.LookPath("wg-quick-op") if err != nil { diff --git a/go.mod b/go.mod index 11ab515..aa504f9 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/spf13/viper v1.18.1 github.com/stretchr/testify v1.8.4 github.com/vishvananda/netlink v1.1.0 - golang.org/x/sys v0.15.0 + golang.org/x/sys v0.18.0 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 ) @@ -23,6 +23,7 @@ require ( github.com/mdlayher/genetlink v1.3.2 // indirect github.com/mdlayher/netlink v1.7.2 // indirect github.com/mdlayher/socket v0.4.1 // indirect + github.com/miekg/dns v1.1.59 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pelletier/go-toml/v2 v2.1.0 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect @@ -36,11 +37,13 @@ require ( github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect - golang.org/x/crypto v0.16.0 // indirect + golang.org/x/crypto v0.21.0 // indirect golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect - golang.org/x/net v0.19.0 // indirect - golang.org/x/sync v0.5.0 // indirect + golang.org/x/mod v0.16.0 // indirect + golang.org/x/net v0.22.0 // indirect + golang.org/x/sync v0.6.0 // indirect golang.org/x/text v0.14.0 // indirect + golang.org/x/tools v0.19.0 // indirect golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 468a1a8..1745690 100644 --- a/go.sum +++ b/go.sum @@ -24,6 +24,8 @@ github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/ github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U= github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA= +github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs= +github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= @@ -73,18 +75,28 @@ go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY= golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= +golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= +golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic= +golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= +golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc= +golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw= +golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo= golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE= diff --git a/lib/dns/dns.go b/lib/dns/dns.go new file mode 100644 index 0000000..d433a7f --- /dev/null +++ b/lib/dns/dns.go @@ -0,0 +1,116 @@ +package dns + +import ( + "fmt" + "github.com/hdu-dn11/wg-quick-op/conf" + "github.com/miekg/dns" + "net" + "net/netip" + "strconv" +) + +var RoaFinder string + +func Init() { + if !conf.EnhancedDNS.DirectResolver.Enabled { + ResolveUDPAddr = net.ResolveUDPAddr + return + } + RoaFinder = conf.EnhancedDNS.DirectResolver.ROAFinder + if RoaFinder == "" { + config, err := dns.ClientConfigFromFile("/etc/resolv.conf") + if err == nil && len(config.Servers) > 0 { + RoaFinder = config.Servers[0] + } else { + RoaFinder = "223.5.5.5" + } + } + if _, err := netip.ParseAddr(RoaFinder); err != nil { + RoaFinder = net.JoinHostPort(RoaFinder, "53") + } +} + +var ResolveUDPAddr = func(network string, addr string) (*net.UDPAddr, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("split host port failed: %w", err) + } + + numport, err := strconv.Atoi(port) + if err != nil { + return nil, fmt.Errorf("parse port failed: %w", err) + } + + ip, err := resolveIPAddr(host) + if err != nil { + return nil, fmt.Errorf("resolve ip addr failed: %w", err) + } + return &net.UDPAddr{IP: ip, Port: numport}, nil +} + +func resolveIPAddr(addr string) (net.IP, error) { + parsedAddr, err := netip.ParseAddr(addr) + if err == nil { + return net.IP(parsedAddr.AsSlice()).To16(), nil + } + + return directDNS(addr) +} + +func directDNS(addr string) (net.IP, error) { + msg := new(dns.Msg) + msg.SetQuestion(dns.Fqdn(addr), dns.TypeSOA) + + c := new(dns.Client) + rec, _, err := c.Exchange(msg, RoaFinder) + if err != nil { + return nil, fmt.Errorf("write msg failed: %w", err) + } + + reply := append(rec.Answer, rec.Ns...) + + if len(reply) == 0 { + return nil, fmt.Errorf("no SOA record found") + } + + var NsServer string + for _, ans := range reply { + if a, ok := ans.(*dns.SOA); ok { + NsServer = a.Ns + break + } + } + if NsServer == "" { + return nil, fmt.Errorf("no SOA record found") + } + + for _, ans := range rec.Answer { + if a, ok := ans.(*dns.CNAME); ok { + addr = a.Target + break + } + } + + nsAddr := net.JoinHostPort(NsServer, "53") + msg.SetQuestion(dns.Fqdn(addr), dns.TypeA) + rec, _, err = c.Exchange(msg, nsAddr) + if err == nil { + for _, ans := range rec.Answer { + if a, ok := ans.(*dns.A); ok { + return a.A, nil + } + } + } + + msg.SetQuestion(dns.Fqdn(NsServer), dns.TypeAAAA) + rec, _, err = c.Exchange(msg, nsAddr) + if err == nil { + for _, ans := range rec.Answer { + if a, ok := ans.(*dns.AAAA); ok { + return a.AAAA, nil + } + } + } + + return nil, fmt.Errorf("no record found") +} diff --git a/lib/dns/dns_test.go b/lib/dns/dns_test.go new file mode 100644 index 0000000..8f3aba1 --- /dev/null +++ b/lib/dns/dns_test.go @@ -0,0 +1,29 @@ +package dns + +import "testing" + +func TestDirectDNS(t *testing.T) { + RoaFinder = "223.5.5.5:53" + testcases := []string{ + "www.baidu.com", + } + for _, testcase := range testcases { + ip, err := directDNS(testcase) + if err != nil { + t.Errorf("directDNS error:%v", err) + return + } + t.Log(ip) + + } +} + +func TestResolveUDP(t *testing.T) { + RoaFinder = "223.5.5.5:53" + addr, err := ResolveUDPAddr("", "baidu.com:12345") + if err != nil { + t.Errorf("ResolveUDPAddr error:%v", err) + return + } + t.Log(addr) +} diff --git a/quick/config.go b/quick/config.go index ce3f6c2..4841d74 100644 --- a/quick/config.go +++ b/quick/config.go @@ -5,6 +5,7 @@ import ( "encoding" "encoding/base64" "fmt" + "github.com/hdu-dn11/wg-quick-op/lib/dns" "github.com/sirupsen/logrus" "net" "os" @@ -393,7 +394,7 @@ func parsePeerLine(peerCfg *wgtypes.PeerConfig, lhs string, rhs string) error { peerCfg.AllowedIPs = append(peerCfg.AllowedIPs, net.IPNet{IP: ip, Mask: cidr.Mask}) } case "Endpoint": - addr, err := net.ResolveUDPAddr("", rhs) + addr, err := dns.ResolveUDPAddr("", rhs) if err != nil { return err } diff --git a/utils/utils.go b/utils/utils.go index 8d42783..8434df6 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -1,6 +1,12 @@ package utils -import "time" +import ( + "github.com/sirupsen/logrus" + "os" + "slices" + "strings" + "time" +) func Retry(times int, f func() error) <-chan error { done := make(chan error) @@ -23,3 +29,24 @@ func Retry(times int, f func() error) <-chan error { }() return done } + +func FindIface(only []string, skip []string) []string { + if only != nil { + return only + } + + var ifaceList []string + entry, err := os.ReadDir("/etc/wireguard") + if err != nil { + logrus.WithError(err).Errorln("read dir /etc/wireguard failed when find iface") + return nil + } + for _, v := range entry { + name := strings.TrimSuffix(v.Name(), ".conf") + if slices.Index(skip, name) != -1 { + continue + } + ifaceList = append(ifaceList, name) + } + return ifaceList +}