From eac31b4797ce4fa9dd546f7b98ec32305527b19e Mon Sep 17 00:00:00 2001 From: kamalsacranie <66880161+kamalsacranie@users.noreply.github.com> Date: Sun, 21 Jan 2024 02:09:27 +0000 Subject: [PATCH] feat: allow rules to be treesitter context aware (#423) (#424) * feat: allow rules to be treesitter context aware When a rule is defined with the `:with_context()` method, has a specified filetype, and is operating in a buffer with a treesitter parser attached, the rule will only execute iff the treesitter language at the cursor matches one of the filetypes specified in the initial rule definition. > If there are no specified filetypes, of there is no parser attached to > the current buffer, the rule executes as normal * Add tests for treesitter context in markdown sample - Add 'ts_context markdown `*` success md_context' - Add 'ts_context codeblock `*` fail js_context' --- lua/nvim-autopairs/ts-conds.lua | 10 ++++++++++ lua/nvim-autopairs/ts-utils.lua | 12 ++++++++++++ tests/endwise/sample.md | 9 +++++++++ tests/test_utils.lua | 4 ++++ tests/treesitter_spec.lua | 32 +++++++++++++++++++++++++++++++- 5 files changed, 66 insertions(+), 1 deletion(-) create mode 100644 tests/endwise/sample.md diff --git a/lua/nvim-autopairs/ts-conds.lua b/lua/nvim-autopairs/ts-conds.lua index dd963c57..d7ef8e0e 100644 --- a/lua/nvim-autopairs/ts-conds.lua +++ b/lua/nvim-autopairs/ts-conds.lua @@ -147,4 +147,14 @@ conds.is_not_ts_node_comment = function() end end +conds.is_not_in_context = function() + return function(opts) + local context = require("nvim-autopairs.ts-utils") + .get_language_tree_at_position({ utils.get_cursor() }) + if not vim.tbl_contains(opts.rule.filetypes, context:lang()) then + return false + end + end +end + return conds diff --git a/lua/nvim-autopairs/ts-utils.lua b/lua/nvim-autopairs/ts-utils.lua index 2a6e784d..daa92623 100644 --- a/lua/nvim-autopairs/ts-utils.lua +++ b/lua/nvim-autopairs/ts-utils.lua @@ -1,6 +1,18 @@ local ts_get_node_text = vim.treesitter.get_node_text or vim.treesitter.query.get_node_text local M = {} +--- Returns the language tree at the given position. +---@return LanguageTree +function M.get_language_tree_at_position(position) + local language_tree = vim.treesitter.get_parser() + language_tree:for_each_tree(function(_, tree) + if tree:contains(vim.tbl_flatten({ position, position })) then + language_tree = tree + end + end) + return language_tree +end + function M.get_tag_name(node) local tag_name = nil if node ~=nil then diff --git a/tests/endwise/sample.md b/tests/endwise/sample.md new file mode 100644 index 00000000..490b6d83 --- /dev/null +++ b/tests/endwise/sample.md @@ -0,0 +1,9 @@ +# Example Markdown File + +```javascript +let; +let; +let; +let; +let; +``` diff --git a/tests/test_utils.lua b/tests/test_utils.lua index 02b9a6ee..e259d99e 100644 --- a/tests/test_utils.lua +++ b/tests/test_utils.lua @@ -133,6 +133,10 @@ _G.Test_withfile = function(test_data, cb) vim.bo.filetype = value.filetype end end + local status, parser = pcall(vim.treesitter.get_parser, 0) + if status then + parser:parse(true) + end vim.api.nvim_buf_set_lines( 0, value.linenr - 1, diff --git a/tests/treesitter_spec.lua b/tests/treesitter_spec.lua index 4c877731..d86dd61a 100644 --- a/tests/treesitter_spec.lua +++ b/tests/treesitter_spec.lua @@ -13,7 +13,7 @@ vim.api.nvim_set_keymap( ) ts.setup({ - ensure_installed = { 'lua', 'javascript', 'rust' }, + ensure_installed = { 'lua', 'javascript', 'rust', 'markdown', 'markdown_inline' }, highlight = { enable = true }, autopairs = { enable = true }, }) @@ -93,6 +93,36 @@ local data = { before = [[pub fn noop(_inp: Vec|) {]], after = [[pub fn noop(_inp: Vec<|>) {]], }, + { + setup_func = function() + npairs.add_rules({ + Rule('*', '*', { 'markdown', 'markdown_inline' }) + :with_pair(ts_conds.is_not_in_context()), + }) + end, + name = 'ts_context markdown `*` success md_context', + filepath = './tests/endwise/sample.md', + linenr = 2, + filetype = 'markdown', + key = '*', + before = [[|]], + after = [[*|*]], + }, + { + setup_func = function() + npairs.add_rules({ + Rule('*', '*', { 'markdown', 'markdown_inline' }) + :with_pair(ts_conds.is_not_in_context()), + }) + end, + name = 'ts_context codeblock `*` fail js_context', + filepath = './tests/endwise/sample.md', + linenr = 6, + filetype = 'markdown', + key = '*', + before = [[let calc = 1 |]], + after = [[let calc = 1 *|]], + }, } local run_data = _G.Test_filter(data)