Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature/filter-with-rules #49

Open
wants to merge 20 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
483 changes: 483 additions & 0 deletions filter/acls_parser.go

Large diffs are not rendered by default.

160 changes: 160 additions & 0 deletions filter/acls_parser_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package filter

import (
"strings"
"testing"
)

func TestParseIpv4(t *testing.T) {
ip, valid := parseIpv4("1.2.3.4")
if !valid {
t.Error("parseIpv4 says 1.2.3.4 is invalid")
} else {
if ip != toIpv4(1, 2, 3, 4) {
t.Error("parseIpv4 did not parse 1.2.3.4 properly")
}
}

ip, valid = parseIpv4("0.0.0.0")
if !valid {
t.Error("parseIpv4 says 0.0.0.0 is invalid")
} else {
if ip != toIpv4(0, 0, 0, 0) {
t.Error("parseIpv4 did not parse 0.0.0.0 properly")
}
}

}

func TestGetSource(t *testing.T) {
net, mask, tokens, ok := getSource([]string{"host", "192.168.69.3"})
if !ok {
t.Fatal("getSource failed")
}

if net != toIpv4(192, 168, 69, 3) {
t.Fatal("getSource network is invalid")
}

if mask != toIpv4(255, 255, 255, 255) {
t.Fatal("getSource mask is invalid")
}

if len(tokens) != 0 {
t.Fatal("getSource tokens is invalid")
}

}

func TestGetProtocol(t *testing.T) {
protocol, tokens, ok := getProtocol([]string{"ip", "any", "any"})
if !ok {
t.Fatal("getProtocol failed")
}

if protocol != allProtocols {
t.Fatal("getProtocol did not parse protocol")
}

if len(tokens) != 2 {
t.Fatal("getProtocol did not parse tokens")
}

protocol, tokens, ok = getProtocol([]string{"54", "any", "any"})
if !ok {
t.Fatal("getProtocol failed")
}

if protocol != 54 {
t.Fatal("getProtocol did not parse protocol")
}

if len(tokens) != 2 {
t.Fatal("getProtocol did not parse tokens")
}

}

type getPortExpected struct {
op int
tokens []string
startPort uint16
endPort uint16
ok bool
}

func TestGetPort(t *testing.T) {
tests := []struct {
protocol int16
input string
expected getPortExpected
}{
{6, "eq 123 any", getPortExpected{portOpEq, []string{"any"}, 123, 123, true}},
{6, "neq 123 any", getPortExpected{portOpNeq, []string{"any"}, 123, 123, true}},
{6, "gt 123 any", getPortExpected{portOpGt, []string{"any"}, 123, 123, true}},
{6, "lt 123 any", getPortExpected{portOpLt, []string{"any"}, 123, 123, true}},
{6, "range 123 125 any", getPortExpected{portOpRange, []string{"any"}, 123, 125, true}},
{6, "range 123 any", getPortExpected{invalidOp, []string{"range", "123", "any"}, 0, 0, false}},
{6, "eq any", getPortExpected{invalidOp, []string{"eq", "any"}, 0, 0, false}},
}

for _, test := range tests {
op, startPort, endPort, tokens, ok := getPort(test.protocol, strings.Fields(test.input))
e := test.expected

if ok != e.ok {
t.Fatalf("Expected value for ok failed for '%s'", test.input)
}

if op != e.op {
t.Fatalf("Expected value for op failed for '%s' expected '%d' got '%d'", test.input, e.op, op)
}

if startPort != e.startPort {
t.Fatalf("Expected value for startPort failed for '%s' expected '%d' got '%d'", test.input, e.startPort, startPort)
}

if endPort != e.endPort {
t.Fatalf("Expected value for endPort failed for '%s' expected '%d' got '%d'", test.input, e.endPort, endPort)
}

if len(tokens) != len(e.tokens) {
t.Fatalf("Expected value for tokens failed for '%s' expected '%d' got '%d'", test.input, len(e.tokens), len(tokens))
}

}

}

func TestGetIcmpRule(t *testing.T) {
tests := []struct {
acl string
rule icmpRule
tokens []string
ok bool
}{
{"echo", icmpRule{8, -1}, []string{}, true},
{"echo 5", icmpRule{8, 5}, []string{}, true},
{"bob", icmpRule{-1, -1}, []string{"bob"}, false},
{"echo bob", icmpRule{8, -1}, []string{"bob"}, true},
}
for _, test := range tests {
acl := test.acl
icmpRule, tokens, ok := getIcmpRule(strings.Fields(acl))
if ok != test.ok {
t.Fatalf("Expected value for ok failed for '%s'", acl)
}

if icmpRule.iType != test.rule.iType {
t.Fatalf("Expected value for iType failed for '%s' expected '%d' got '%d'", acl, test.rule.iType, icmpRule.iType)
}

if icmpRule.code != test.rule.code {
t.Fatalf("Expected value for code failed for '%s' expected '%d' got '%d'", acl, test.rule.code, icmpRule.code)
}

if len(tokens) != len(test.tokens) {
t.Fatalf("Expected value for len tokens failed for '%s' expected '%d' got '%d'", acl, len(test.tokens), len(tokens))
}
}
}
7 changes: 7 additions & 0 deletions filter/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package filter

import (
"errors"
)

var ErrDenyAll = errors.New("Deny All")
3 changes: 2 additions & 1 deletion filter/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ package filter

import (
"errors"
"github.com/inverse-inc/wireguard-go/services"
"strconv"
"strings"

"github.com/inverse-inc/wireguard-go/services"
)

const (
Expand Down
97 changes: 97 additions & 0 deletions filter/icmp_packets_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package filter

var icmpPacket1 = []byte{
0x45, 0x00,
0x00, 0x3c, 0xd7, 0x43, 0x00, 0x00, 0x80, 0x01,
0x2b, 0x73, 0xc0, 0xa8, 0x9e, 0x8b, 0xae, 0x89,
0x2a, 0x4d, 0x08, 0x00, 0x2a, 0x5c, 0x02, 0x00,
0x21, 0x00, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66,
0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e,
0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76,
0x77, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67,
0x68, 0x69,
}

var icmpPacket2 = []byte{
0x45, 0x00,
0x00, 0x3c, 0x76, 0xe1, 0x00, 0x00, 0x80, 0x01,
0x8b, 0xd5, 0xae, 0x89, 0x2a, 0x4d, 0xc0, 0xa8,
0x9e, 0x8b, 0x00, 0x00, 0x32, 0x5c, 0x02, 0x00,
0x21, 0x00, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66,
0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e,
0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76,
0x77, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67,
0x68, 0x69,
}

var icmpPacket3 = []byte{
0x45, 0x00,
0x00, 0x3c, 0xd7, 0x46, 0x00, 0x00, 0x80, 0x01,
0x2b, 0x70, 0xc0, 0xa8, 0x9e, 0x8b, 0xae, 0x89,
0x2a, 0x4d, 0x08, 0x00, 0x29, 0x5c, 0x02, 0x00,
0x22, 0x00, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66,
0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e,
0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76,
0x77, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67,
0x68, 0x69,
}

var icmpPacket4 = []byte{
0x45, 0x00,
0x00, 0x3c, 0x76, 0xe4, 0x00, 0x00, 0x80, 0x01,
0x8b, 0xd2, 0xae, 0x89, 0x2a, 0x4d, 0xc0, 0xa8,
0x9e, 0x8b, 0x00, 0x00, 0x31, 0x5c, 0x02, 0x00,
0x22, 0x00, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66,
0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e,
0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76,
0x77, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67,
0x68, 0x69,
}

var icmpPacket5 = []byte{
0x45, 0x00,
0x00, 0x3c, 0xd7, 0x49, 0x00, 0x00, 0x80, 0x01,
0x2b, 0x6d, 0xc0, 0xa8, 0x9e, 0x8b, 0xae, 0x89,
0x2a, 0x4d, 0x08, 0x00, 0x28, 0x5c, 0x02, 0x00,
0x23, 0x00, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66,
0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e,
0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76,
0x77, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67,
0x68, 0x69,
}

var icmpPacket6 = []byte{
0x45, 0x00,
0x00, 0x3c, 0x76, 0xf0, 0x00, 0x00, 0x80, 0x01,
0x8b, 0xc6, 0xae, 0x89, 0x2a, 0x4d, 0xc0, 0xa8,
0x9e, 0x8b, 0x00, 0x00, 0x30, 0x5c, 0x02, 0x00,
0x23, 0x00, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66,
0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e,
0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76,
0x77, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67,
0x68, 0x69,
}

var icmpPacket7 = []byte{
0x45, 0x00,
0x00, 0x3c, 0xd7, 0x4e, 0x00, 0x00, 0x80, 0x01,
0x2b, 0x68, 0xc0, 0xa8, 0x9e, 0x8b, 0xae, 0x89,
0x2a, 0x4d, 0x08, 0x00, 0x27, 0x5c, 0x02, 0x00,
0x24, 0x00, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66,
0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e,
0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76,
0x77, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67,
0x68, 0x69,
}

var icmpPacket8 = []byte{
0x45, 0x00,
0x00, 0x3c, 0x76, 0xf5, 0x00, 0x00, 0x80, 0x01,
0x8b, 0xc1, 0xae, 0x89, 0x2a, 0x4d, 0xc0, 0xa8,
0x9e, 0x8b, 0x00, 0x00, 0x2f, 0x5c, 0x02, 0x00,
0x24, 0x00, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66,
0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e,
0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76,
0x77, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67,
0x68, 0x69,
}
64 changes: 64 additions & 0 deletions filter/rbac_filter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package filter

import (
"context"
"fmt"
"time"

"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/inverse-inc/packetfence/go/sharedutils"
"github.com/inverse-inc/packetfence/go/unifiedapiclient"
"github.com/inverse-inc/wireguard-go/device"
"github.com/inverse-inc/wireguard-go/ztn"
"github.com/patrickmn/go-cache"
)

var rbacAllowCache *cache.Cache

func init() {
var cacheTime time.Duration
defaultCacheTime := 30 * time.Second
cacheTimeEnv := sharedutils.EnvOrDefault(ztn.EnvRBACIPFilteringCacheTime, defaultCacheTime.String())
if cacheTimeParsed, err := time.ParseDuration(cacheTimeEnv); err == nil {
cacheTime = cacheTimeParsed
} else {
fmt.Println("Unable to parse", ztn.EnvRBACIPFilteringCacheTime, err)
cacheTime = defaultCacheTime
}
rbacAllowCache = cache.New(cacheTime, 1*time.Minute)
}

func BuildRBACFilter(apiClientCtx context.Context, apiClient *unifiedapiclient.Client, logger *device.Logger) RuleFunc {
return func(data []byte) RuleCmd {
packet := gopacket.NewPacket(data, layers.LayerTypeIPv4, gopacket.Default)
if ip4Layer := packet.Layer(layers.LayerTypeIPv4); ip4Layer != nil {
ip4 := ip4Layer.(*layers.IPv4)
k := fmt.Sprintf("%s->%s", ip4.SrcIP, ip4.DstIP)
if cacheRes, found := rbacAllowCache.Get(k); found {
return cacheRes.(RuleCmd)
} else {
apiRes := struct {
Permit bool `json:"permit"`
Reason string `json:"reason"`
}{}
err := apiClient.Call(apiClientCtx, "GET", fmt.Sprintf("/api/v1/remote_clients/allowed_ip_communication?src_ip=%s&dst_ip=%s", ip4.SrcIP, ip4.DstIP), &apiRes)
var res RuleCmd
if err != nil {
logger.Info.Println("(API ERR) Denying access from", ip4.SrcIP, "to", ip4.DstIP)
res = Deny
} else if !apiRes.Permit {
logger.Info.Println("(Access Denied) Denying access from", ip4.SrcIP, "to", ip4.DstIP, apiRes.Reason)
res = Deny
} else {
logger.Info.Println("Allowing access from", ip4.SrcIP, "to", ip4.DstIP)
res = Permit
}
rbacAllowCache.SetDefault(k, res)
return res
}
} else {
return Permit
}
}
}
Loading