Skip to content

Commit

Permalink
Merge pull request #49 from domdom82/cleanup-rules
Browse files Browse the repository at this point in the history
Cleanup old egress rules after applying a new blocking mode
  • Loading branch information
axel7born authored Oct 15, 2024
2 parents b9ec455 + 0285344 commit 6b25824
Show file tree
Hide file tree
Showing 6 changed files with 324 additions and 70 deletions.
4 changes: 2 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ COPY . .
ARG TARGETARCH
RUN make build-filter-updater GOARCH=$TARGETARCH

FROM alpine:3.20.3 as builder
FROM alpine:3.20.3 AS builder

WORKDIR /volume

Expand Down Expand Up @@ -50,4 +50,4 @@ FROM scratch

COPY --from=builder /volume /

CMD /filter-updater
CMD ["/filter-updater"]
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ IMAGE_REPOSITORY := $(REGISTRY)/$(NAME)
IMAGE_TAG := $(VERSION)
EFFECTIVE_VERSION := $(VERSION)-$(shell git rev-parse HEAD)
GOARCH := amd64
PLATFORM := linux/amd64
REPO_ROOT := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST))))

TOOLS_DIR := hack/tools
Expand Down Expand Up @@ -53,7 +54,7 @@ build-filter-updater:

.PHONY: docker-images
docker-images:
@docker build -t $(IMAGE_REPOSITORY):$(IMAGE_TAG) -f Dockerfile --rm .
@docker build --platform $(PLATFORM) -t $(IMAGE_REPOSITORY):$(IMAGE_TAG) -f Dockerfile --rm .

.PHONY: release
release: docker-images docker-login docker-push
Expand Down
77 changes: 73 additions & 4 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ func updateFirewall(blockIngress bool, ipv4EgressFilterList, ipv6EgressFilterLis
defaultNetworkDevices := []string{defaultNetworkDeviceV4, defaultNetworkDeviceV6}

for i, v := range []string{"4", "6"} {
_ = netconfig.InitIPSet(v, ipSetNames[i])
err := netconfig.InitIPSet(v, ipSetNames[i])
if err != nil {
return fmt.Errorf("UpdateIPSet failed for %s: %v", ipSetNames[i], err)
}
egressFilterContent, err := os.ReadFile(filterLists[i])
if err != nil {
return fmt.Errorf("error reading egress filter list '%s': %v", filterLists[i], err)
Expand All @@ -67,6 +70,54 @@ func updateFirewall(blockIngress bool, ipv4EgressFilterList, ipv6EgressFilterLis
return nil
}

func cleanupBlackholeRoutes() error {
fmt.Println("Cleaning up blackhole routes...")
for _, v := range []string{"4", "6"} {
routes, err := netconfig.GetBlackholeRoutes(v)
if err != nil {
return err
}
fmt.Printf("cleaning up %d ipv%s routes\n", len(routes), v)
err = netconfig.DeleteRoutes(v, routes)
if err != nil {
return err
}
}

err := netconfig.RemoveDummyDevice()
return err
}

func cleanupFirewall() error {
fmt.Println("Cleaning up iptables rules...")

defaultNetworkDeviceV4, _ := netconfig.GetDefaultNetworkDevice("4")
defaultNetworkDeviceV6, _ := netconfig.GetDefaultNetworkDevice("6")

if defaultNetworkDeviceV4 == "" && defaultNetworkDeviceV6 == "" {
return fmt.Errorf("no default network device found")
} else if defaultNetworkDeviceV4 == "" {
defaultNetworkDeviceV4 = defaultNetworkDeviceV6
} else if defaultNetworkDeviceV6 == "" {
defaultNetworkDeviceV6 = defaultNetworkDeviceV4
}

ipSetNames := []string{ipv4IPSetName, ipv6IPSetName}
defaultNetworkDevices := []string{defaultNetworkDeviceV4, defaultNetworkDeviceV6}
for i, v := range []string{"4", "6"} {
err := netconfig.RemoveIPTablesLoggingRules(v, ipSetNames[i], defaultNetworkDevices[i])
if err != nil {
return fmt.Errorf("RemoveIPTablesLoggingRules failed for %s: %v", ipSetNames[i], err)
}
err = netconfig.RemoveIPSet(ipSetNames[i])
if err != nil {
return fmt.Errorf("RemoveIPSet failed for %s: %w", ipSetNames[i], err)
}
}

return nil
}

func main() {
var blackholing, blockIngress bool
var filterListDir, ipV4List, ipV6List string
Expand All @@ -85,27 +136,45 @@ func main() {
fmt.Printf("blackholing enabled: %v\n", blackholing)
for {
fmt.Println(time.Now())
_ = netconfig.InitLoggingChain("4")
_ = netconfig.InitLoggingChain("6")
err := netconfig.InitLoggingChain("4")
if err != nil {
fmt.Printf("Error initializing ipv4 logging chain: %v\n", err)
os.Exit(1)
}
err = netconfig.InitLoggingChain("6")
if err != nil {
fmt.Printf("Error initializing ipv6 logging chain: %v\n", err)
os.Exit(1)
}
if blackholing {
err := netconfig.InitDummyDevice()
if err != nil {
fmt.Printf("Error initializing dummy device: %v", err)
os.Exit(1)
}
fmt.Printf("Updating blackhole routes...")
fmt.Println("Updating blackhole routes...")
err = updateBlackholeRoutes(ipV4List, ipV6List)
if err != nil {
fmt.Fprintf(os.Stderr, "Error updating blackhole routes: %v\n", err)
os.Exit(1)
}
err = cleanupFirewall()
if err != nil {
fmt.Fprintf(os.Stderr, "Error cleaning up iptables: %v\n", err)
os.Exit(1)
}
} else {
fmt.Println("Updating iptables rules ...")
err := updateFirewall(blockIngress, ipV4List, ipV6List)
if err != nil {
fmt.Fprintf(os.Stderr, "Error updating iptables: %v\n", err)
os.Exit(1)
}
err = cleanupBlackholeRoutes()
if err != nil {
fmt.Fprintf(os.Stderr, "Error cleaning up blackhole routes: %v\n", err)
os.Exit(1)
}
}
fmt.Println(time.Now())
fmt.Printf("Going to sleep for %v...\n", sleepDuration)
Expand Down
112 changes: 84 additions & 28 deletions pkg/netconfig/netconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,24 @@ const (
tmpIPSet = "tmpIPSet"
)

type IPTablesAction string

const (
IPTablesAppend IPTablesAction = "-A"
IPTablesCheck IPTablesAction = "-C"
IPTablesDelete IPTablesAction = "-D"
)

var (
DefaultNetUtilsCommandExecutor NetUtilsCommandExecutor = &OSNetUtilsCommandExecutor{}
)

func GetDefaultNetworkDevice(ipVersion string) (string, error) {
out, err := DefaultNetUtilsCommandExecutor.ExecuteIPRouteCommand(ipVersion, "route", "show", "default")
output, err := DefaultNetUtilsCommandExecutor.ExecuteIPRouteCommand(ipVersion, "route", "show", "default")
if err != nil {
return "", err
}

output := out.String()
fields := strings.Fields(output)
for i, field := range fields {
if field == "dev" {
Expand Down Expand Up @@ -76,15 +83,12 @@ func InitLoggingChain(ipVersion string) error {

// firewaller

func IPTablesLoggingChainRule(ipVersion string, protocol string, ipSet string, device string, check bool, blockIngress bool) error {
action := "-A"
if check {
action = "-C"
}
func IPTablesLoggingChainRule(ipVersion string, protocol string, ipSet string, device string, action IPTablesAction, blockIngress bool) error {
a := string(action)

ipTablesArgs := []string{
"-t", "mangle",
action, "POSTROUTING",
a, "POSTROUTING",
"-o", device,
"-p", protocol,
"-m", "set",
Expand All @@ -102,18 +106,37 @@ func IPTablesLoggingChainRule(ipVersion string, protocol string, ipSet string, d

func AddIPTablesLoggingRules(ipVersion, ipSet, defaultNetworkDevice string, blockIngress bool) error {

if err := IPTablesLoggingChainRule(ipVersion, "tcp", ipSet, defaultNetworkDevice, true, blockIngress); err != nil {
err := IPTablesLoggingChainRule(ipVersion, "tcp", ipSet, defaultNetworkDevice, false, blockIngress)
if err != nil {
return fmt.Errorf("error creating tcp logging chain rules for %s, device %s %v", ipSet, defaultNetworkDevice, err)
for _, proto := range []string{"tcp", "udp"} {
if err := IPTablesLoggingChainRule(ipVersion, proto, ipSet, defaultNetworkDevice, IPTablesCheck, blockIngress); err != nil {
err := IPTablesLoggingChainRule(ipVersion, proto, ipSet, defaultNetworkDevice, IPTablesAppend, blockIngress)
if err != nil {
return fmt.Errorf("error creating %s logging chain rules for %s, device %s %w", proto, ipSet, defaultNetworkDevice, err)
}
}
}
if err := IPTablesLoggingChainRule(ipVersion, "udp", ipSet, defaultNetworkDevice, true, blockIngress); err != nil {
err := IPTablesLoggingChainRule(ipVersion, "udp", ipSet, defaultNetworkDevice, false, blockIngress)
if err != nil {
return fmt.Errorf("error creating udp logging chain rules for %s, device %s %v", ipSet, defaultNetworkDevice, err)
return nil
}

func RemoveIPTablesLoggingRules(ipVersion, ipSet, defaultNetworkDevice string) error {
// delete tcp rules. we don't care if SYN filtering was enabled previously. delete both variants.
for _, blockIngress := range []bool{true, false} {
if err := IPTablesLoggingChainRule(ipVersion, "tcp", ipSet, defaultNetworkDevice, IPTablesCheck, blockIngress); err != nil {
// rule does not exist; continue
continue
}
if err := IPTablesLoggingChainRule(ipVersion, "tcp", ipSet, defaultNetworkDevice, IPTablesDelete, blockIngress); err != nil {
return fmt.Errorf("error deleting tcp logging chain rules for %s, device %s, blockIngress %t, %w", ipSet, defaultNetworkDevice, blockIngress, err)
}
}
// delete udp rules.
if err := IPTablesLoggingChainRule(ipVersion, "udp", ipSet, defaultNetworkDevice, IPTablesCheck, false); err != nil {
// rule does not exist
return nil
}
if err := IPTablesLoggingChainRule(ipVersion, "udp", ipSet, defaultNetworkDevice, IPTablesDelete, false); err != nil {
return fmt.Errorf("error deleting udp logging chain rules for %s, device %s %w", ipSet, defaultNetworkDevice, err)
}
fmt.Printf("Removed iptables v%s rules for ipset %s on device %s\n", ipVersion, ipSet, defaultNetworkDevice)
return nil
}

Expand Down Expand Up @@ -173,7 +196,7 @@ func UpdateIPSet(ipVersion, ipSetName, egressFilterList, defaultNetworkDevice st
}

defer func() {
fmt.Println("Clean-up")
fmt.Println("Clean-up temporary ipset")
err := DefaultNetUtilsCommandExecutor.ExecuteIPSetCommand("destroy", tmpIPSet)
if err != nil {
fmt.Printf("Error cleaning-up temporary ipsets %v\n", err)
Expand Down Expand Up @@ -208,6 +231,18 @@ func UpdateIPSet(ipVersion, ipSetName, egressFilterList, defaultNetworkDevice st
return nil
}

func RemoveIPSet(ipSetName string) error {
if err := DefaultNetUtilsCommandExecutor.ExecuteIPSetCommand("list", ipSetName); err != nil {
return nil
}
if err := DefaultNetUtilsCommandExecutor.ExecuteIPSetCommand("destroy", ipSetName); err != nil {
return fmt.Errorf("error cleaning-up ipset %s: %w\n", ipSetName, err)
}

fmt.Printf("Removed ipset %s\n", ipSetName)
return nil
}

// blackholer

func diff(new, old []string) (added, removed []string) {
Expand Down Expand Up @@ -239,7 +274,7 @@ func diff(new, old []string) (added, removed []string) {

func InitDummyDevice() error {
out, _ := DefaultNetUtilsCommandExecutor.ExecuteIPRouteCommand("4", "link", "show")
if !strings.Contains(out.String(), " "+dummyDeviceName+": ") {
if !strings.Contains(out, " "+dummyDeviceName+": ") {
_, err := DefaultNetUtilsCommandExecutor.ExecuteIPRouteCommand("4", "link", "add", dummyDeviceName, "type", "dummy")
if err != nil {
return fmt.Errorf("error creating dummy device: %v", err)
Expand All @@ -251,20 +286,42 @@ func InitDummyDevice() error {
fmt.Println("Added dummy device.")
}

if err := DefaultNetUtilsCommandExecutor.ExecuteIPTablesCommand("4", "-t", "mangle", "-C", "POSTROUTING", "-o", dummyDeviceName, "-j", ipTablesLoggingChain); err != nil {
err = DefaultNetUtilsCommandExecutor.ExecuteIPTablesCommand("4", "-t", "mangle", "-A", "POSTROUTING", "-o", dummyDeviceName, "-j", ipTablesLoggingChain)
if err != nil {
return fmt.Errorf("error creating ip%stables rule for logging packets to dummy device: %v", "", err)
for _, ipv := range []string{"4", "6"} {
if err := DefaultNetUtilsCommandExecutor.ExecuteIPTablesCommand(ipv, "-t", "mangle", "-C", "POSTROUTING", "-o", dummyDeviceName, "-j", ipTablesLoggingChain); err != nil {
err = DefaultNetUtilsCommandExecutor.ExecuteIPTablesCommand(ipv, "-t", "mangle", "-A", "POSTROUTING", "-o", dummyDeviceName, "-j", ipTablesLoggingChain)
if err != nil {
return fmt.Errorf("error creating ip%stables rule for logging packets to dummy device: %v", "", err)
}
}
}

fmt.Println("Created iptables rules for logging packets to dummy device.")
return nil
}

func RemoveDummyDevice() error {
for _, ipv := range []string{"4", "6"} {
if err := DefaultNetUtilsCommandExecutor.ExecuteIPTablesCommand(ipv, "-t", "mangle", "-C", "POSTROUTING", "-o", dummyDeviceName, "-j", ipTablesLoggingChain); err != nil {
continue
}
if err := DefaultNetUtilsCommandExecutor.ExecuteIPTablesCommand(ipv, "-t", "mangle", "-D", "POSTROUTING", "-o", dummyDeviceName, "-j", ipTablesLoggingChain); err != nil {
return fmt.Errorf("error deleting ip%stables rule for logging packets to dummy device: %w", ipv, err)
}
}

if err := DefaultNetUtilsCommandExecutor.ExecuteIPTablesCommand("6", "-t", "mangle", "-C", "POSTROUTING", "-o", dummyDeviceName, "-j", ipTablesLoggingChain); err != nil {
err = DefaultNetUtilsCommandExecutor.ExecuteIPTablesCommand("6", "-t", "mangle", "-A", "POSTROUTING", "-o", dummyDeviceName, "-j", ipTablesLoggingChain)
out, _ := DefaultNetUtilsCommandExecutor.ExecuteIPRouteCommand("4", "link", "show")
if strings.Contains(out, " "+dummyDeviceName+": ") {
_, err := DefaultNetUtilsCommandExecutor.ExecuteIPRouteCommand("4", "link", "set", dummyDeviceName, "down")
if err != nil {
return fmt.Errorf("error creating ip%stables rule for logging packets to dummy device: %v", "6", err)
return fmt.Errorf("error bringing down dummy device: %w", err)
}
_, err = DefaultNetUtilsCommandExecutor.ExecuteIPRouteCommand("4", "link", "del", dummyDeviceName)
if err != nil {
return fmt.Errorf("error deleting dummy device: %w", err)
}
fmt.Println("Removed dummy device.")
}
fmt.Println("Created iptables rules for logging packets to dummy device.")

return nil
}

Expand All @@ -277,7 +334,7 @@ func GetBlackholeRoutes(ipVersion string) ([]string, error) {
return blackholeRoutes, err
}

lines := strings.Split(ipOut.String(), "\n")
lines := strings.Split(ipOut, "\n")
for _, line := range lines {
if strings.Contains(line, dummyDeviceName) && !strings.Contains(line, "fe80::/64") {
fields := strings.Fields(line)
Expand Down Expand Up @@ -331,7 +388,6 @@ func UpdateRoutes(ipVersion string, egressFilterList string) error {
if err != nil {
return err
}

fmt.Printf("Currently applied filter list contains %d entries\n", len(currentAddrs))

addAddr, delAddr := diff(newAddrs, currentAddrs)
Expand Down
Loading

0 comments on commit 6b25824

Please sign in to comment.