Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Saghen committed Dec 6, 2024
1 parent 765ee4c commit 3aebeb6
Show file tree
Hide file tree
Showing 21 changed files with 492 additions and 366 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -512,8 +512,9 @@ MiniDeps.add({
-- proximity bonus boosts the score of items matching nearby words
use_proximity = true,
max_items = 200,
-- controls which sorts to use and in which order, these three are currently the only allowed options
sorts = { 'label', 'kind', 'score' },
-- controls which sorts to use and in which order, falling back to the next sort if the first one returns nil
-- you may pass a function instead of a string to customize the sorting
sorts = { 'score', 'kind', 'label' },

prebuilt_binaries = {
-- Whether or not to automatically download a prebuilt binary from github. If this is set to `false`
Expand Down
19 changes: 11 additions & 8 deletions lua/blink/cmp/completion/list.lua
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
--- @field select_emitter blink.cmp.EventEmitter<blink.cmp.CompletionListSelectEvent>
--- @field accept_emitter blink.cmp.EventEmitter<blink.cmp.CompletionListAcceptEvent>
---
--- @field show fun(context: blink.cmp.Context, items?: blink.cmp.CompletionItem[])
--- @field fuzzy fun(context: blink.cmp.Context, items: blink.cmp.CompletionItem[]): blink.cmp.CompletionItem[]
--- @field show fun(context: blink.cmp.Context, items: table<string, blink.cmp.CompletionItem[]>)
--- @field fuzzy fun(context: blink.cmp.Context, items: table<string, blink.cmp.CompletionItem[]>): blink.cmp.CompletionItem[]
--- @field hide fun()
---
--- @field get_selected_item fun(): blink.cmp.CompletionItem?
Expand Down Expand Up @@ -59,13 +59,13 @@ local list = {

---------- State ----------

function list.show(context, items)
function list.show(context, items_by_source)
-- reset state for new context
local is_new_context = not list.context or list.context.id ~= context.id
if is_new_context then list.preview_undo_text_edit = nil end

list.context = context
list.items = list.fuzzy(context, items or list.items)
list.items = list.fuzzy(context, items_by_source)

if #list.items == 0 then
list.hide_emitter:emit({ context = context })
Expand All @@ -77,12 +77,15 @@ function list.show(context, items)
list.select(list.config.selection == 'preselect' and 1 or nil, { undo_preview = false })
end

function list.fuzzy(context, items)
function list.fuzzy(context, items_by_source)
local fuzzy = require('blink.cmp.fuzzy')
local sources = require('blink.cmp.sources.lib')
local filtered_items = fuzzy.fuzzy(fuzzy.get_query(), items_by_source)

local filtered_items = fuzzy.fuzzy(fuzzy.get_query(), items)
return sources.apply_max_items_for_completions(context, filtered_items)
-- apply the per source max_items
filtered_items = require('blink.cmp.sources.lib').apply_max_items_for_completions(context, filtered_items)

-- apply the global max_items
return require('blink.cmp.lib.utils').slice(filtered_items, 1, list.config.max_items)
end

function list.hide() list.hide_emitter:emit({ context = list.context }) end
Expand Down
6 changes: 4 additions & 2 deletions lua/blink/cmp/config/fuzzy.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,24 @@
--- @field use_typo_resistance boolean When enabled, allows for a number of typos relative to the length of the query. Disabling this matches the behavior of fzf
--- @field use_frecency boolean Tracks the most recently/frequently used items and boosts the score of the item
--- @field use_proximity boolean Boosts the score of items matching nearby words
--- @field sorts ("label" | "kind" | "score")[] Controls which sorts to use and in which order, these three are currently the only allowed options
--- @field sorts ("label" | "kind" | "score" | blink.cmp.SortFunction)[] Controls which sorts to use and in which order, these three are currently the only allowed options
--- @field prebuilt_binaries blink.cmp.PrebuiltBinariesConfig

--- @class (exact) blink.cmp.PrebuiltBinariesConfig
--- @field download boolean Whenther or not to automatically download a prebuilt binary from github. If this is set to `false` you will need to manually build the fuzzy binary dependencies by running `cargo build --release`
--- @field force_version? string When downloading a prebuilt binary, force the downloader to resolve this version. If this is unset then the downloader will attempt to infer the version from the checked out git tag (if any). WARN: Beware that `main` may be incompatible with the version you select
--- @field force_system_triple? string When downloading a prebuilt binary, force the downloader to use this system triple. If this is unset then the downloader will attempt to infer the system triple from `jit.os` and `jit.arch`. Check the latest release for all available system triples. WARN: Beware that `main` may be incompatible with the version you select

--- @alias blink.cmp.SortFunction fun(a: blink.cmp.CompletionItem, b: blink.cmp.CompletionItem): boolean | nil

local validate = require('blink.cmp.config.utils').validate
local fuzzy = {
--- @type blink.cmp.FuzzyConfig
default = {
use_typo_resistance = true,
use_frecency = true,
use_proximity = true,
sorts = { 'label', 'kind', 'score' },
sorts = { 'score', 'kind', 'label' },
prebuilt_binaries = {
download = true,
force_version = nil,
Expand Down
1 change: 1 addition & 0 deletions lua/blink/cmp/config/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
--- @field blocked_filetypes string[]
--- @field keymap blink.cmp.KeymapConfig
--- @field completion blink.cmp.CompletionConfig
--- @field fuzzy blink.cmp.FuzzyConfig
--- @field sources blink.cmp.SourceConfig
--- @field signature blink.cmp.SignatureConfig
--- @field snippets blink.cmp.SnippetsConfig
Expand Down
9 changes: 5 additions & 4 deletions lua/blink/cmp/config/sources.lua
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@
--- @field module? string
--- @field enabled? boolean | fun(ctx?: blink.cmp.Context): boolean Whether or not to enable the provider
--- @field opts? table
--- @field async? boolean Whether blink should wait for the source to return before showing the completions
--- @field transform_items? fun(ctx: blink.cmp.Context, items: blink.cmp.CompletionItem[]): blink.cmp.CompletionItem[] Function to transform the items before they're returned
--- @field should_show_items? boolean | number | fun(ctx: blink.cmp.Context, items: blink.cmp.CompletionItem[]): boolean Whether or not to show the items
--- @field max_items? number | fun(ctx: blink.cmp.Context, enabled_sources: string[], items: blink.cmp.CompletionItem[]): number Maximum number of items to display in the menu
--- @field min_keyword_length? number | fun(ctx: blink.cmp.Context, enabled_sources: string[]): number Minimum number of characters in the keyword to trigger the provider
--- @field fallback_for? string[] | fun(ctx: blink.cmp.Context, enabled_sources: string[]): string[] If any of these providers return 0 items, it will fallback to this provider
--- @field min_keyword_length? number | fun(ctx: blink.cmp.Context): number Minimum number of characters in the keyword to trigger the provider
--- @field fallbacks? string[] | fun(ctx: blink.cmp.Context, enabled_sources: string[]): string[] If this provider returns 0 items, it will fallback to these providers
--- @field score_offset? number | fun(ctx: blink.cmp.Context, enabled_sources: string[]): number Boost/penalize the score of the items
--- @field deduplicate? blink.cmp.DeduplicateConfig TODO: implement
--- @field override? blink.cmp.SourceOverride Override the source's functions
Expand All @@ -45,6 +46,7 @@ local sources = {
lsp = {
name = 'LSP',
module = 'blink.cmp.sources.lsp',
fallbacks = { 'buffer' },
},
path = {
name = 'Path',
Expand All @@ -64,7 +66,6 @@ local sources = {
buffer = {
name = 'Buffer',
module = 'blink.cmp.sources.buffer',
fallback_for = { 'lsp' },
},
},
},
Expand All @@ -88,7 +89,7 @@ function sources.validate(config)
should_show_items = { provider.should_show_items, { 'boolean', 'function' }, true },
max_items = { provider.max_items, { 'number', 'function' }, true },
min_keyword_length = { provider.min_keyword_length, { 'number', 'function' }, true },
fallback_for = { provider.fallback_for, { 'table', 'function' }, true },
fallbacks = { provider.fallback_for, { 'table', 'function' }, true },
score_offset = { provider.score_offset, { 'number', 'function' }, true },
deduplicate = { provider.deduplicate, 'table', true },
override = { provider.override, 'table', true },
Expand Down
60 changes: 13 additions & 47 deletions lua/blink/cmp/fuzzy/fuzzy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use crate::lsp_item::LspItem;
use mlua::prelude::*;
use mlua::FromLua;
use mlua::Lua;
use std::cmp::Reverse;
use std::collections::HashSet;

#[derive(Clone, Hash)]
Expand All @@ -15,8 +14,6 @@ pub struct FuzzyOptions {
use_proximity: bool,
nearby_words: Option<Vec<String>>,
min_score: u16,
max_items: u32,
sorts: Vec<String>,
}

impl FromLua for FuzzyOptions {
Expand All @@ -27,17 +24,13 @@ impl FromLua for FuzzyOptions {
let use_proximity: bool = tab.get("use_proximity").unwrap_or_default();
let nearby_words: Option<Vec<String>> = tab.get("nearby_words").ok();
let min_score: u16 = tab.get("min_score").unwrap_or_default();
let max_items: u32 = tab.get("max_items").unwrap_or_default();
let sorts: Vec<String> = tab.get("sorts").unwrap_or_default();

Ok(FuzzyOptions {
use_typo_resistance,
use_frecency,
use_proximity,
nearby_words,
min_score,
max_items,
sorts,
})
} else {
Err(mlua::Error::FromLuaConversionError {
Expand All @@ -51,10 +44,10 @@ impl FromLua for FuzzyOptions {

pub fn fuzzy(
needle: String,
haystack: Vec<LspItem>,
haystack: &Vec<LspItem>,
frecency: &FrecencyTracker,
opts: FuzzyOptions,
) -> Vec<usize> {
) -> (Vec<i32>, Vec<u32>) {
let nearby_words: HashSet<String> = HashSet::from_iter(opts.nearby_words.unwrap_or_default());
let haystack_labels = haystack
.iter()
Expand Down Expand Up @@ -110,42 +103,15 @@ pub fn fuzzy(
.collect::<Vec<_>>();
}

// Sort matches by sort criteria
for sort in opts.sorts.iter() {
match sort.as_str() {
"kind" => {
matches.sort_by_key(|mtch| haystack[mtch.index_in_haystack].kind);
}
"score" => {
matches.sort_by_cached_key(|mtch| Reverse(match_scores[mtch.index]));
}
"label" => {
matches.sort_by(|a, b| {
let label_a = haystack[a.index_in_haystack]
.sort_text
.as_ref()
.unwrap_or(&haystack[a.index_in_haystack].label);
let label_b = haystack[b.index_in_haystack]
.sort_text
.as_ref()
.unwrap_or(&haystack[b.index_in_haystack].label);

// Put anything with an underscore at the end
match (label_a.starts_with('_'), label_b.starts_with('_')) {
(true, false) => std::cmp::Ordering::Greater,
(false, true) => std::cmp::Ordering::Less,
_ => label_a.cmp(label_b),
}
});
}
_ => {}
}
}

// Grab the top N matches and return the indices
matches
.iter()
.map(|mtch| mtch.index_in_haystack)
.take(opts.max_items as usize)
.collect::<Vec<_>>()
// Return scores and indices
(
matches
.iter()
.map(|mtch| match_scores[mtch.index] as i32)
.collect::<Vec<_>>(),
matches
.iter()
.map(|mtch| mtch.index_in_haystack as u32)
.collect::<Vec<_>>(),
)
}
55 changes: 33 additions & 22 deletions lua/blink/cmp/fuzzy/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ local config = require('blink.cmp.config')

local fuzzy = {
rust = require('blink.cmp.fuzzy.rust'),
haystacks_by_provider_cache = {},
has_init_db = false,
}

Expand All @@ -24,13 +25,19 @@ function fuzzy.get_words(lines) return fuzzy.rust.get_words(lines) end

function fuzzy.fuzzy_matched_indices(needle, haystack) return fuzzy.rust.fuzzy_matched_indices(needle, haystack) end

---@param needle string
---@param haystack blink.cmp.CompletionItem[]?
---@return blink.cmp.CompletionItem[]
function fuzzy.fuzzy(needle, haystack)
--- @param needle string
--- @param haystacks_by_provider table<string, blink.cmp.CompletionItem[]>
--- @return blink.cmp.CompletionItem[]
function fuzzy.fuzzy(needle, haystacks_by_provider)
fuzzy.init_db()

haystack = haystack or {}
for provider_id, haystack in pairs(haystacks_by_provider) do
-- set the provider items once since Lua <-> Rust takes the majority of the time
if fuzzy.haystacks_by_provider_cache[provider_id] ~= haystack then
fuzzy.haystacks_by_provider_cache[provider_id] = haystack
fuzzy.rust.set_provider_items(provider_id, haystack)
end
end

-- get the nearby words
local cursor_row = vim.api.nvim_win_get_cursor(0)[1]
Expand All @@ -39,25 +46,29 @@ function fuzzy.fuzzy(needle, haystack)
local nearby_text = table.concat(vim.api.nvim_buf_get_lines(0, start_row, end_row, false), '\n')
local nearby_words = #nearby_text < 10000 and fuzzy.rust.get_words(nearby_text) or {}

-- perform fuzzy search
local matched_indices = fuzzy.rust.fuzzy(needle, haystack, {
-- each matching char is worth 4 points and it receives a bonus for capitalization, delimiter and prefix
-- so this should generally be good
-- TODO: make this configurable
min_score = config.fuzzy.use_typo_resistance and (6 * needle:len()) or 0,
max_items = config.completion.list.max_items,
use_typo_resistance = config.fuzzy.use_typo_resistance,
use_frecency = config.fuzzy.use_frecency,
use_proximity = config.fuzzy.use_proximity,
sorts = config.fuzzy.sorts,
nearby_words = nearby_words,
})

local filtered_items = {}
for _, idx in ipairs(matched_indices) do
table.insert(filtered_items, haystack[idx + 1])
for provider_id, haystack in pairs(haystacks_by_provider) do
-- perform fuzzy search
local scores, matched_indices = fuzzy.rust.fuzzy(needle, provider_id, {
-- each matching char is worth 4 points and it receives a bonus for capitalization, delimiter and prefix
-- so this should generally be good
-- TODO: make this configurable
min_score = config.fuzzy.use_typo_resistance and (6 * needle:len()) or 0,
use_typo_resistance = config.fuzzy.use_typo_resistance,
use_frecency = config.fuzzy.use_frecency,
use_proximity = config.fuzzy.use_proximity,
sorts = config.fuzzy.sorts,
nearby_words = nearby_words,
})

for idx, item_index in ipairs(matched_indices) do
local item = haystack[item_index + 1]
item.score = scores[idx]
table.insert(filtered_items, item)
end
end
return filtered_items

return require('blink.cmp.fuzzy.sort').sort(filtered_items)
end

--- Gets the text under the cursor to be used for fuzzy matching
Expand Down
38 changes: 31 additions & 7 deletions lua/blink/cmp/fuzzy/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::lsp_item::LspItem;
use lazy_static::lazy_static;
use mlua::prelude::*;
use regex::Regex;
use std::collections::HashSet;
use std::collections::{HashMap, HashSet};
use std::sync::RwLock;

mod frecency;
Expand All @@ -14,6 +14,8 @@ mod lsp_item;
lazy_static! {
static ref REGEX: Regex = Regex::new(r"\p{L}[\p{L}0-9_\\-]{2,32}").unwrap();
static ref FRECENCY: RwLock<Option<FrecencyTracker>> = RwLock::new(None);
static ref HAYSTACKS_BY_PROVIDER: RwLock<HashMap<String, Vec<LspItem>>> =
RwLock::new(HashMap::new());
}

pub fn init_db(_: &Lua, db_path: String) -> LuaResult<bool> {
Expand Down Expand Up @@ -52,21 +54,39 @@ pub fn access(_: &Lua, item: LspItem) -> LuaResult<bool> {
Ok(true)
}

pub fn set_provider_items(
_: &Lua,
(provider_id, items): (String, Vec<LspItem>),
) -> LuaResult<bool> {
let mut items_by_provider = HAYSTACKS_BY_PROVIDER.write().map_err(|_| {
mlua::Error::RuntimeError("Failed to acquire lock for items by provider".to_string())
})?;
items_by_provider.insert(provider_id, items);
Ok(true)
}

pub fn fuzzy(
_lua: &Lua,
(needle, haystack, opts): (String, Vec<LspItem>, FuzzyOptions),
) -> LuaResult<Vec<u32>> {
(needle, provider_id, opts): (String, String, FuzzyOptions),
) -> LuaResult<(Vec<i32>, Vec<u32>)> {
let mut frecency_handle = FRECENCY.write().map_err(|_| {
mlua::Error::RuntimeError("Failed to acquire lock for frecency".to_string())
})?;
let frecency = frecency_handle.as_mut().ok_or_else(|| {
mlua::Error::RuntimeError("Attempted to use frencecy before initialization".to_string())
})?;

Ok(fuzzy::fuzzy(needle, haystack, frecency, opts)
.into_iter()
.map(|i| i as u32)
.collect())
let haystacks_by_provider = HAYSTACKS_BY_PROVIDER.read().map_err(|_| {
mlua::Error::RuntimeError("Failed to acquire lock for items by provider".to_string())
})?;
let haystack = haystacks_by_provider.get(&provider_id).ok_or_else(|| {
mlua::Error::RuntimeError(format!(
"Attempted to fuzzy match for provider {} before setting the provider's items",
provider_id
))
})?;

Ok(fuzzy::fuzzy(needle, haystack, frecency, opts))
}

pub fn fuzzy_matched_indices(
Expand All @@ -93,6 +113,10 @@ pub fn get_words(_: &Lua, text: String) -> LuaResult<Vec<String>> {
#[mlua::lua_module(skip_memory_check)]
fn blink_cmp_fuzzy(lua: &Lua) -> LuaResult<LuaTable> {
let exports = lua.create_table()?;
exports.set(
"set_provider_items",
lua.create_function(set_provider_items)?,
)?;
exports.set("fuzzy", lua.create_function(fuzzy)?)?;
exports.set(
"fuzzy_matched_indices",
Expand Down
Loading

0 comments on commit 3aebeb6

Please sign in to comment.