From 1d2f68a589c7ea47d9b77c8c8dddfe07c088cb4c Mon Sep 17 00:00:00 2001 From: Josh Bode Date: Mon, 6 Nov 2023 14:51:58 +1100 Subject: [PATCH] Use keyword_pattern for is_symbol check --- lua/cmp/entry.lua | 8 +++++--- lua/cmp/matcher.lua | 18 ++++++++++++++---- lua/cmp/matcher_spec.lua | 15 +++++++++++++++ lua/cmp/source.lua | 2 +- 4 files changed, 35 insertions(+), 8 deletions(-) diff --git a/lua/cmp/entry.lua b/lua/cmp/entry.lua index cc89acb06..62b9aa611 100644 --- a/lua/cmp/entry.lua +++ b/lua/cmp/entry.lua @@ -389,8 +389,9 @@ end ---Match line. ---@param input string ---@param matching_config cmp.MatchingConfig +---@param keyword_pattern string ---@return { score: integer, matches: table[] } -entry.match = function(self, input, matching_config) +entry.match = function(self, input, matching_config, keyword_pattern) -- https://www.lua.org/pil/11.6.html -- do not use '..' to allocate multiple strings local cache_key = string.format('%s:%d:%d:%d:%d:%d:%d', input, self.resolved_completion_item and 1 or 0, matching_config.disallow_fuzzy_matching and 1 or 0, matching_config.disallow_partial_matching and 1 or 0, matching_config.disallow_prefix_unmatching and 1 or 0, matching_config.disallow_partial_fuzzy_matching and 1 or 0, matching_config.disallow_symbol_nonprefix_matching and 1 or 0) @@ -403,13 +404,13 @@ entry.match = function(self, input, matching_config) end return matched end - matched = self:_match(input, matching_config) + matched = self:_match(input, matching_config, keyword_pattern) self.match_cache:set(cache_key, matched) return matched end ---@package -entry._match = function(self, input, matching_config) +entry._match = function(self, input, matching_config, keyword_pattern) local completion_item = self.completion_item local option = { disallow_fuzzy_matching = matching_config.disallow_fuzzy_matching, @@ -421,6 +422,7 @@ entry._match = function(self, input, matching_config) self.word, self.completion_item.label, }, + keyword_pattern = keyword_pattern, } local score, matches, filter_text diff --git a/lua/cmp/matcher.lua b/lua/cmp/matcher.lua index b5cbf4ca8..cc85cea05 100644 --- a/lua/cmp/matcher.lua +++ b/lua/cmp/matcher.lua @@ -1,4 +1,5 @@ local char = require('cmp.utils.char') +local pattern = require('cmp.utils.pattern') local matcher = {} @@ -81,7 +82,7 @@ end ---Match entry ---@param input string ---@param word string ----@param option { synonyms: string[], disallow_fullfuzzy_matching: boolean, disallow_fuzzy_matching: boolean, disallow_partial_fuzzy_matching: boolean, disallow_partial_matching: boolean, disallow_prefix_unmatching: boolean, disallow_symbol_nonprefix_matching: boolean } +---@param option { synonyms: string[], disallow_fullfuzzy_matching: boolean, disallow_fuzzy_matching: boolean, disallow_partial_fuzzy_matching: boolean, disallow_partial_matching: boolean, disallow_prefix_unmatching: boolean, disallow_symbol_nonprefix_matching: boolean, keyword_pattern: string|nil } ---@return integer, table matcher.match = function(input, word, option) option = option or {} @@ -103,6 +104,15 @@ matcher.match = function(input, word, option) end end + local is_symbol + if option.keyword_pattern ~= nil then + is_symbol = function(byte) + return pattern.matchstr(option.keyword_pattern, string.char(byte)) == nil + end + else + is_symbol = char.is_symbol + end + -- Gather matched regions local matches = {} local input_start_index = 1 @@ -111,7 +121,7 @@ matcher.match = function(input, word, option) local word_bound_index = 1 local no_symbol_match = false while input_end_index <= #input and word_index <= #word do - local m = matcher.find_match_region(input, input_start_index, input_end_index, word, word_index) + local m = matcher.find_match_region(input, input_start_index, input_end_index, word, word_index, is_symbol) if m and input_end_index <= m.input_match_end then m.index = word_bound_index input_start_index = m.input_match_start + 1 @@ -277,7 +287,7 @@ matcher.fuzzy = function(input, word, matches, option) end --- find_match_region -matcher.find_match_region = function(input, input_start_index, input_end_index, word, word_index) +matcher.find_match_region = function(input, input_start_index, input_end_index, word, word_index, is_symbol) -- determine input position ( woroff -> word_offset ) while input_start_index < input_end_index do if char.match(string.byte(input, input_end_index), string.byte(word, word_index)) then @@ -309,7 +319,7 @@ matcher.find_match_region = function(input, input_start_index, input_end_index, strict_count = strict_count + (c1 == c2 and 1 or 0) match_count = match_count + 1 word_offset = word_offset + 1 - no_symbol_match = no_symbol_match or char.is_symbol(c1) + no_symbol_match = no_symbol_match or is_symbol(c1) else -- Match end (partial region) if input_match_start ~= -1 then diff --git a/lua/cmp/matcher_spec.lua b/lua/cmp/matcher_spec.lua index e79715d48..6c9593cf6 100644 --- a/lua/cmp/matcher_spec.lua +++ b/lua/cmp/matcher_spec.lua @@ -22,6 +22,21 @@ describe('matcher', function() assert.is.truthy(matcher.match('fmodify', 'fnamemodify', config.matching) >= 1) assert.is.truthy(matcher.match('candlesingle', 'candle#accept#single', config.matching) >= 1) + local options = { + keyword_pattern = [[ \w\+ ]], + } + assert.is.truthy(matcher.match('ab', 'a_b_c', options) > matcher.match('ac', 'a_b_c', options)) + assert.is.truthy(matcher.match('a_b', 'a_b_c', options) > matcher.match('ab', 'a_b_c', options)) + assert.is.truthy(matcher.match('a_b/c', 'a_b/c', options) > matcher.match('a/c', 'a_b/c', options)) + + assert.is.truthy(matcher.match('bora', 'border-radius') >= 1) + assert.is.truthy(matcher.match('woroff', 'word_offset') >= 1) + assert.is.truthy(matcher.match('call', 'call') > matcher.match('call', 'condition_all')) + assert.is.truthy(matcher.match('Buffer', 'Buffer') > matcher.match('Buffer', 'buffer')) + assert.is.truthy(matcher.match('luacon', 'lua_context') > matcher.match('luacon', 'LuaContext')) + assert.is.truthy(matcher.match('fmodify', 'fnamemodify') >= 1) + assert.is.truthy(matcher.match('candlesingle', 'candle#accept#single') >= 1) + assert.is.truthy(matcher.match('vi', 'void#', config.matching) >= 1) assert.is.truthy(matcher.match('vo', 'void#', config.matching) >= 1) assert.is.truthy(matcher.match('var_', 'var_dump', config.matching) >= 1) diff --git a/lua/cmp/source.lua b/lua/cmp/source.lua index 74bd3d092..0f0819e46 100644 --- a/lua/cmp/source.lua +++ b/lua/cmp/source.lua @@ -120,7 +120,7 @@ source.get_entries = function(self, ctx) inputs[o] = string.sub(ctx.cursor_before_line, o) end - local match = e:match(inputs[o], matching_config) + local match = e:match(inputs[o], matching_config, self:get_keyword_pattern()) e.score = match.score e.exact = false if e.score >= 1 then