From 3761c048bd17137be5921399bff3b4baeae58724 Mon Sep 17 00:00:00 2001 From: Sultan <37975139+AitakattaSora@users.noreply.github.com> Date: Fri, 20 Dec 2024 20:02:26 +0500 Subject: [PATCH] Cosmetic rules (#168) * Add cosmetic rules store with tests * Add support for cosmetic hiding rules * Refactor cosmetic rule handling, remove cosmetic rules from ignoreLineRegex * Rewrite cosmetic rule store to triestore for better wildcards support, change ignoreLineRegexp to match hosts comments * Optimize CSS injection by batching selectors for improved performance * Use htmlrewrite package to replace head contents * Sanitize css selecter when adding cosmetic rules * Refactor CSS selector sanitization and improve error handling --- internal/app/app.go | 10 +- internal/cosmetic/addrule.go | 55 ++++++ internal/cosmetic/injector.go | 77 ++++++++ internal/cosmetic/sanitizer.go | 168 ++++++++++++++++++ internal/cosmetic/sanitizer_test.go | 97 ++++++++++ internal/cosmetic/triestore/triestore.go | 93 ++++++++++ internal/cosmetic/triestore/triestore_test.go | 108 +++++++++++ internal/filter/filter.go | 50 ++++-- 8 files changed, 644 insertions(+), 14 deletions(-) create mode 100644 internal/cosmetic/addrule.go create mode 100644 internal/cosmetic/injector.go create mode 100644 internal/cosmetic/sanitizer.go create mode 100644 internal/cosmetic/sanitizer_test.go create mode 100644 internal/cosmetic/triestore/triestore.go create mode 100644 internal/cosmetic/triestore/triestore_test.go diff --git a/internal/app/app.go b/internal/app/app.go index b0704b90..aa9984ee 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -14,6 +14,8 @@ import ( "github.com/anfragment/zen/internal/certgen" "github.com/anfragment/zen/internal/certstore" "github.com/anfragment/zen/internal/cfg" + "github.com/anfragment/zen/internal/cosmetic" + cosmeticTrieStore "github.com/anfragment/zen/internal/cosmetic/triestore" "github.com/anfragment/zen/internal/filter" "github.com/anfragment/zen/internal/jsrule" "github.com/anfragment/zen/internal/logger" @@ -172,9 +174,15 @@ func (a *App) StartProxy() (err error) { return fmt.Errorf("create scriptlets injector: %v", err) } + cosmeticRulesStore := cosmeticTrieStore.NewTrieStore() + cosmeticRulesInjector, err := cosmetic.NewInjector(cosmeticRulesStore) + if err != nil { + return fmt.Errorf("create cosmetic rules injector: %v", err) + } + jsRuleInjector := jsrule.NewInjector() - filter, err := filter.NewFilter(a.config, ruleMatcher, exceptionRuleMatcher, scriptletInjector, jsRuleInjector, a.eventsHandler) + filter, err := filter.NewFilter(a.config, ruleMatcher, exceptionRuleMatcher, scriptletInjector, cosmeticRulesInjector, jsRuleInjector, a.eventsHandler) if err != nil { return fmt.Errorf("create filter: %v", err) } diff --git a/internal/cosmetic/addrule.go b/internal/cosmetic/addrule.go new file mode 100644 index 00000000..a2ccf276 --- /dev/null +++ b/internal/cosmetic/addrule.go @@ -0,0 +1,55 @@ +package cosmetic + +import ( + "errors" + "fmt" + "net" + "regexp" + "strings" +) + +var ( + // RuleRegex matches cosmetic rules. + RuleRegex = regexp.MustCompile(`^(?:([^#$]+?)##|##)(.+)$`) + + errUnsupportedSyntax = errors.New("unsupported syntax") +) + +func (inj *Injector) AddRule(rule string) error { + + var rawHostnames string + var selector string + + if match := RuleRegex.FindStringSubmatch(rule); match != nil { + rawHostnames = match[1] + selector = match[2] + } else { + return errUnsupportedSyntax + } + + sanitizedSelector, err := sanitizeCSSSelector(selector) + if err != nil { + return fmt.Errorf("failed to sanitize selector: %w", err) + } + + if len(rawHostnames) == 0 { + inj.store.Add(nil, sanitizedSelector) + return nil + } + + hostnames := strings.Split(rawHostnames, ",") + subdomainHostnames := make([]string, 0, len(hostnames)) + for _, hostname := range hostnames { + if len(hostname) == 0 { + return errors.New("empty hostnames are not allowed") + } + + if net.ParseIP(hostname) == nil && !strings.HasPrefix(hostname, "*.") { + subdomainHostnames = append(subdomainHostnames, "*."+hostname) + } + } + inj.store.Add(hostnames, sanitizedSelector) + inj.store.Add(subdomainHostnames, sanitizedSelector) + + return nil +} diff --git a/internal/cosmetic/injector.go b/internal/cosmetic/injector.go new file mode 100644 index 00000000..b485fd5e --- /dev/null +++ b/internal/cosmetic/injector.go @@ -0,0 +1,77 @@ +package cosmetic + +import ( + "bytes" + "errors" + "fmt" + "log" + "net/http" + "strings" + + "github.com/anfragment/zen/internal/htmlrewrite" + "github.com/anfragment/zen/internal/logger" +) + +var ( + styleOpeningTag = []byte("") +) + +type Injector struct { + // store stores and retrieves css by hostname. + store Store +} + +type Store interface { + Add(hostnames []string, selector string) + Get(hostname string) []string +} + +func NewInjector(store Store) (*Injector, error) { + if store == nil { + return nil, errors.New("store is nil") + } + + return &Injector{ + store: store, + }, nil +} + +func (inj *Injector) Inject(req *http.Request, res *http.Response) error { + hostname := req.URL.Hostname() + selectors := inj.store.Get(hostname) + log.Printf("got %d cosmetic rules for %q", len(selectors), logger.Redacted(hostname)) + if len(selectors) == 0 { + return nil + } + + var ruleInjection bytes.Buffer + ruleInjection.Write(styleOpeningTag) + css := generateBatchedCSS(selectors) + ruleInjection.WriteString(css) + ruleInjection.Write(styleClosingTag) + + htmlrewrite.ReplaceHeadContents(res, func(match []byte) []byte { + return bytes.Join([][]byte{match, ruleInjection.Bytes()}, nil) + }) + + return nil +} + +func generateBatchedCSS(selectors []string) string { + const batchSize = 100 + + var builder strings.Builder + for i := 0; i < len(selectors); i += batchSize { + end := i + batchSize + if end > len(selectors) { + end = len(selectors) + } + batch := selectors[i:end] + + joinedSelectors := strings.Join(batch, ",") + builder.WriteString(fmt.Sprintf("%s{display:none!important;}", joinedSelectors)) + } + + return builder.String() +} diff --git a/internal/cosmetic/sanitizer.go b/internal/cosmetic/sanitizer.go new file mode 100644 index 00000000..345640e9 --- /dev/null +++ b/internal/cosmetic/sanitizer.go @@ -0,0 +1,168 @@ +package cosmetic + +import ( + "errors" + "fmt" + "regexp" + "strconv" + "strings" +) + +// sanitizeCSSSelector validates and sanitizes a CSS selector. +func sanitizeCSSSelector(selectorInput string) (string, error) { + if strings.Contains(selectorInput, "") { + return "", errors.New("selector contains '' which is not allowed") + } + + selector := decodeUnicodeEscapes(selectorInput) + if !hasBalancedQuotesAndBrackets(selector) { + return "", errors.New("selector has unbalanced quotes or brackets") + } + + if err := validateSelector(selector); err != nil { + return "", fmt.Errorf("sanitize selector: %w", err) + } + + return selector, nil +} + +// decodeUnicodeEscapes replaces CSS Unicode escapes with their actual characters. +func decodeUnicodeEscapes(s string) string { + re := regexp.MustCompile(`\\([0-9A-Fa-f]{1,6})(\s)?`) + return re.ReplaceAllStringFunc(s, func(match string) string { + submatches := re.FindStringSubmatch(match) + if len(submatches) < 2 { + return match + } + hexDigits := submatches[1] + r, err := strconv.ParseInt(hexDigits, 16, 32) + if err != nil { + return match + } + return string(rune(r)) + }) +} + +// hasBalancedQuotesAndBrackets checks for balanced quotes and brackets in the selector. +func hasBalancedQuotesAndBrackets(s string) bool { + var stack []rune + inSingleQuote := false + inDoubleQuote := false + escaped := false + + for _, c := range s { + if escaped { + escaped = false + continue + } + + if c == '\\' { + escaped = true + continue + } + + if inSingleQuote { + if c == '\'' { + inSingleQuote = false + } + continue + } + + if inDoubleQuote { + if c == '"' { + inDoubleQuote = false + } + continue + } + + if c == '\'' { + inSingleQuote = true + continue + } + + if c == '"' { + inDoubleQuote = true + continue + } + + if c == '(' || c == '[' || c == '{' { + stack = append(stack, c) + } else if c == ')' || c == ']' || c == '}' { + if len(stack) == 0 { + return false + } + last := stack[len(stack)-1] + if (c == ')' && last != '(') || + (c == ']' && last != '[') || + (c == '}' && last != '{') { + return false + } + stack = stack[:len(stack)-1] + } + } + + return !inSingleQuote && !inDoubleQuote && len(stack) == 0 && !escaped +} + +// validateSelector checks for dangerous sequences in the selector. +func validateSelector(s string) error { + inSingleQuote := false + inDoubleQuote := false + escaped := false + runes := []rune(s) + + for i := 0; i < len(runes); i++ { + c := runes[i] + + if escaped { + escaped = false + continue + } + + if c == '\\' { + escaped = true + continue + } + + if inSingleQuote { + if c == '\'' { + inSingleQuote = false + } + continue + } + + if inDoubleQuote { + if c == '"' { + inDoubleQuote = false + } + continue + } + + if c == '\'' { + inSingleQuote = true + continue + } + + if c == '"' { + inDoubleQuote = true + continue + } + + if !inSingleQuote && !inDoubleQuote { + // Check for dangerous sequences. + if c == '/' && i+1 < len(runes) && runes[i+1] == '*' { + return errors.New("found '/*' outside of quotes") + } + + if c == '*' && i+1 < len(runes) && runes[i+1] == '/' { + return errors.New("found '*/' outside of quotes") + } + + if c == '{' || c == '}' || c == ';' || c == '@' { + return fmt.Errorf("found dangerous character '%c' outside of quotes", c) + } + } + } + + return nil +} diff --git a/internal/cosmetic/sanitizer_test.go b/internal/cosmetic/sanitizer_test.go new file mode 100644 index 00000000..8072c754 --- /dev/null +++ b/internal/cosmetic/sanitizer_test.go @@ -0,0 +1,97 @@ +package cosmetic + +import ( + "testing" +) + +func TestSanitizer(t *testing.T) { + t.Parallel() + + t.Run("simple selector is not sanitized", func(t *testing.T) { + t.Parallel() + + selector := "body" + sanitized, err := sanitizeCSSSelector(selector) + if err != nil { + t.Fatal(err) + } + + if sanitized != "body" { + t.Errorf("expected %q, got %q", selector, sanitized) + } + }) + + t.Run("Valid Complex Selectors", func(t *testing.T) { + selector := `body > div[id^="ai-adb-"][style^="position: fixed; top:"][style*="z-index: 9999"]:hover` + if _, err := sanitizeCSSSelector(selector); err != nil { + t.Errorf("Expected valid selector, got error: %v", err) + } + + selector = `a[href^="https://"][data-info="some:info"][class~="button active"]` + if _, err := sanitizeCSSSelector(selector); err != nil { + t.Errorf("Expected valid selector, got error: %v", err) + } + + selector = `ul.menu > li[class*="dropdown"] ul li a[href*="contact"]` + if _, err := sanitizeCSSSelector(selector); err != nil { + t.Errorf("Expected valid selector, got error: %v", err) + } + }) + + t.Run("Dangerous Characters Outside Quotes", func(t *testing.T) { + selector := `div } body { color: red; }` + if _, err := sanitizeCSSSelector(selector); err == nil { + t.Error("Expected error for selector with '}' outside quotes, got none") + } + + selector = `span; @import 'evil.css';` + if _, err := sanitizeCSSSelector(selector); err == nil { + t.Error("Expected error for selector with ';' and '@import' outside quotes, got none") + } + + selector = `div /* comment */ span` + if _, err := sanitizeCSSSelector(selector); err == nil { + t.Error("Expected error for '/*' outside quotes, got none") + } + }) + + t.Run("Balanced vs Unbalanced Quotes/Brackets", func(t *testing.T) { + selector := `div[class^="header"][data-role='main']` + if _, err := sanitizeCSSSelector(selector); err != nil { + t.Errorf("Expected valid selector, got error: %v", err) + } + + selector = `a[href^="https://]` + if _, err := sanitizeCSSSelector(selector); err == nil { + t.Error("Expected error for unbalanced quotes, got none") + } + + selector = `div[class^="header"` + if _, err := sanitizeCSSSelector(selector); err == nil { + t.Error("Expected error for unbalanced brackets, got none") + } + }) + + t.Run("Unicode Escapes", func(t *testing.T) { + selector := `div[class^="\0061"]` + sanitized, err := sanitizeCSSSelector(selector) + if err != nil { + t.Errorf("Expected valid selector, got error: %v", err) + } else if sanitized != `div[class^="a"]` { + t.Errorf("Expected decoded Unicode escape, got %v", sanitized) + } + + selector = `span[data-test="\0062\0063"]` + sanitized, err = sanitizeCSSSelector(selector) + if err != nil { + t.Errorf("Expected valid selector, got error: %v", err) + } else if sanitized != `span[data-test="bc"]` { + t.Errorf("Expected 'bc' after decoding, got %v", sanitized) + } + + selector = `p.note` + if _, err := sanitizeCSSSelector(selector); err != nil { + t.Errorf("Expected valid selector without Unicode escapes, got error: %v", err) + } + }) +} diff --git a/internal/cosmetic/triestore/triestore.go b/internal/cosmetic/triestore/triestore.go new file mode 100644 index 00000000..4973ec6e --- /dev/null +++ b/internal/cosmetic/triestore/triestore.go @@ -0,0 +1,93 @@ +package triestore + +import ( + "strings" + "sync" +) + +type node struct { + children map[string]*node + selectors []string +} + +func newNode() *node { + return &node{ + children: make(map[string]*node), + } +} + +func (n *node) findOrAddChild(segment string) *node { + child := n.children[segment] + if child != nil { + return child + } + + child = newNode() + n.children[segment] = child + return child +} + +// getMatchingSelectors traverses the trie and returns matching selectors. +func (n *node) getMatchingSelectors(segments []string, isWildcard bool) []string { + if len(segments) == 0 { + return n.selectors + } + + var selectors []string + if isWildcard { + // Wildcards can consume as many segments as possible. + selectors = append(selectors, n.getMatchingSelectors(segments[1:], true)...) + } + wildcardChild := n.children["*"] + if wildcardChild != nil { + selectors = append(selectors, wildcardChild.getMatchingSelectors(segments[1:], true)...) + } + exactChild := n.children[segments[0]] + if exactChild != nil { + selectors = append(selectors, exactChild.getMatchingSelectors(segments[1:], false)...) + } + + return selectors +} + +type TrieStore struct { + mu sync.RWMutex + universalSelectors []string + root *node +} + +func NewTrieStore() *TrieStore { + return &TrieStore{ + root: newNode(), + } +} + +func (ts *TrieStore) Add(hostnames []string, selector string) { + ts.mu.Lock() + defer ts.mu.Unlock() + + if len(hostnames) == 0 { + ts.universalSelectors = append(ts.universalSelectors, selector) + return + } + + for _, hostname := range hostnames { + segments := strings.Split(hostname, ".") + + node := ts.root + for _, segment := range segments { + node = node.findOrAddChild(segment) + } + node.selectors = append(node.selectors, selector) + } +} + +func (ts *TrieStore) Get(hostname string) []string { + ts.mu.RLock() + defer ts.mu.RUnlock() + + segments := strings.Split(hostname, ".") + selectors := ts.root.getMatchingSelectors(segments, false) + selectors = append(selectors, ts.universalSelectors...) + return selectors +} diff --git a/internal/cosmetic/triestore/triestore_test.go b/internal/cosmetic/triestore/triestore_test.go new file mode 100644 index 00000000..5ddd2c43 --- /dev/null +++ b/internal/cosmetic/triestore/triestore_test.go @@ -0,0 +1,108 @@ +package triestore + +import ( + "testing" +) + +func TestStore(t *testing.T) { + t.Parallel() + + t.Run("rule is added for all hostnames", func(t *testing.T) { + t.Parallel() + + s := NewTrieStore() + + s.Add([]string{"example.org", "ex.org"}, ".rule") + + if len(s.root.children) != 2 { + t.Errorf("expected 2 children, got %d", len(s.root.children)) + } + }) + + t.Run("multiple rules on same hostname", func(t *testing.T) { + t.Parallel() + + s := NewTrieStore() + + s.Add([]string{"example.org"}, ".rule") + s.Add([]string{"example.org"}, ".rule2") + + if len(s.root.children) != 1 { + t.Errorf("expected 1 child, got %d", len(s.root.children)) + } + }) + + t.Run("add global rule", func(t *testing.T) { + t.Parallel() + + s := NewTrieStore() + s.Add([]string{}, "div.ticket") + s.Add(nil, "div.container") + + rules := s.Get("example.org") + if len(rules) != 2 { + t.Errorf("expected 2 rules, got %d", len(rules)) + } + + rules = s.Get("example.com") + if len(rules) != 2 { + t.Errorf("expected 2 rules, got %d", len(rules)) + } + }) + + t.Run("get rule for hostname", func(t *testing.T) { + t.Parallel() + + s := NewTrieStore() + s.Add([]string{"example.org"}, ".rule") + + rules := s.Get("example.org") + + if len(rules) != 1 { + t.Errorf("expected 1 rule, got %d", len(rules)) + } + }) + + t.Run("match wildcard hostname", func(t *testing.T) { + t.Parallel() + + s := NewTrieStore() + s.Add([]string{"*.example.org"}, ".rule") + + rules := s.Get("mail.example.org") + + if len(rules) != 1 { + t.Error("expected 1 rule") + } + }) + + t.Run("match top level domain wildcard", func(t *testing.T) { + t.Parallel() + + s := NewTrieStore() + s.Add([]string{"example.*"}, ".rule") + + rules := s.Get("example.org") + if len(rules) != 1 { + t.Error("expected 1 rule") + } + + rules = s.Get("example.co.uk") + if len(rules) != 1 { + t.Error("expected 1 rule") + } + }) + + t.Run("match multiple rules", func(t *testing.T) { + t.Parallel() + + s := NewTrieStore() + s.Add([]string{"example.*", "*.co.uk"}, ".rule") + s.Add([]string{"mail.example.co.uk"}, ".rule2") + + rules := s.Get("mail.example.co.uk") + if len(rules) != 2 { + t.Errorf("expected 2 rules, got %d", len(rules)) + } + }) +} diff --git a/internal/filter/filter.go b/internal/filter/filter.go index 6cc99d21..797f5fc7 100644 --- a/internal/filter/filter.go +++ b/internal/filter/filter.go @@ -15,6 +15,7 @@ import ( "time" "github.com/anfragment/zen/internal/cfg" + "github.com/anfragment/zen/internal/cosmetic" "github.com/anfragment/zen/internal/jsrule" "github.com/anfragment/zen/internal/logger" "github.com/anfragment/zen/internal/rule" @@ -46,6 +47,11 @@ type scriptletsInjector interface { AddRule(string, bool) error } +type cosmeticRulesInjector interface { + Inject(*http.Request, *http.Response) error + AddRule(string) error +} + type jsRuleInjector interface { AddRule(rule string) error Inject(*http.Request, *http.Response) error @@ -55,12 +61,13 @@ type jsRuleInjector interface { // // Safe for concurrent use. type Filter struct { - config config - ruleMatcher ruleMatcher - exceptionRuleMatcher ruleMatcher - scriptletsInjector scriptletsInjector - jsRuleInjector jsRuleInjector - eventsEmitter filterEventsEmitter + config config + ruleMatcher ruleMatcher + exceptionRuleMatcher ruleMatcher + scriptletsInjector scriptletsInjector + cosmeticRulesInjector cosmeticRulesInjector + jsRuleInjector jsRuleInjector + eventsEmitter filterEventsEmitter } var ( @@ -70,10 +77,12 @@ var ( exceptionRegex = regexp.MustCompile(`^@@`) // scriptletRegex matches scriptlet rules. scriptletRegex = regexp.MustCompile(`(?:#%#\/\/scriptlet)|(?:##\+js)`) + // cosmeticRuleRegex matches cosmetic rules. + cosmeticRuleRegex = cosmetic.RuleRegex ) // NewFilter creates and initializes a new filter. -func NewFilter(config config, ruleMatcher ruleMatcher, exceptionRuleMatcher ruleMatcher, scriptletsInjector scriptletsInjector, jsRuleInjector jsRuleInjector, eventsEmitter filterEventsEmitter) (*Filter, error) { +func NewFilter(config config, ruleMatcher ruleMatcher, exceptionRuleMatcher ruleMatcher, scriptletsInjector scriptletsInjector, cosmeticRulesInjector cosmeticRulesInjector, jsRuleInjector jsRuleInjector, eventsEmitter filterEventsEmitter) (*Filter, error) { if config == nil { return nil, errors.New("config is nil") } @@ -86,6 +95,9 @@ func NewFilter(config config, ruleMatcher ruleMatcher, exceptionRuleMatcher rule if scriptletsInjector == nil { return nil, errors.New("scriptletsInjector is nil") } + if cosmeticRulesInjector == nil { + return nil, errors.New("cosmeticRulesInjector is nil") + } if jsRuleInjector == nil { return nil, errors.New("jsRuleInjector is nil") } @@ -94,12 +106,13 @@ func NewFilter(config config, ruleMatcher ruleMatcher, exceptionRuleMatcher rule } f := &Filter{ - config: config, - ruleMatcher: ruleMatcher, - exceptionRuleMatcher: exceptionRuleMatcher, - scriptletsInjector: scriptletsInjector, - jsRuleInjector: jsRuleInjector, - eventsEmitter: eventsEmitter, + config: config, + ruleMatcher: ruleMatcher, + exceptionRuleMatcher: exceptionRuleMatcher, + scriptletsInjector: scriptletsInjector, + cosmeticRulesInjector: cosmeticRulesInjector, + jsRuleInjector: jsRuleInjector, + eventsEmitter: eventsEmitter, } f.init() @@ -177,6 +190,13 @@ func (f *Filter) AddRule(rule string, filterListName *string, filterListTrusted } return false, nil } + + if cosmeticRuleRegex.MatchString(rule) { + if err := f.cosmeticRulesInjector.AddRule(rule); err != nil { + return false, fmt.Errorf("add cosmetic rule: %w", err) + } + } + if filterListTrusted && jsrule.RuleRegex.MatchString(rule) { if err := f.jsRuleInjector.AddRule(rule); err != nil { return false, fmt.Errorf("add js rule: %w", err) @@ -246,6 +266,10 @@ func (f *Filter) HandleResponse(req *http.Request, res *http.Response) error { // The error is recoverable, so we log it and continue processing the response. log.Printf("error injecting scriptlets for %q: %v", logger.Redacted(req.URL), err) } + + if err := f.cosmeticRulesInjector.Inject(req, res); err != nil { + log.Printf("error injecting cosmetic rules for %q: %v", logger.Redacted(req.URL), err) + } if err := f.jsRuleInjector.Inject(req, res); err != nil { // The error is recoverable, so we log it and continue processing the response. log.Printf("error injecting js rules for %q: %v", logger.Redacted(req.URL), err)