diff --git a/src/lib.rs b/src/lib.rs index 0203acf..a343dcb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -164,11 +164,11 @@ const MAX_NUM_THREADS: usize = 128; struct CoreBPE { encoder: HashMap<Vec<u8>, Rank>, special_tokens_encoder: HashMap<String, Rank>, - decoder: HashMap<Rank, Vec<u8>>, + decoder: HashMap<Rank, &'static [u8]>, special_tokens_decoder: HashMap<Rank, Vec<u8>>, regex_tls: Vec<Regex>, special_regex_tls: Vec<Regex>, - sorted_token_bytes: Vec<Vec<u8>>, + sorted_token_bytes: Vec<&'static [u8]>, } impl CoreBPE { @@ -191,6 +191,7 @@ impl CoreBPE { None => self .special_tokens_decoder .get(&token) + .map(|v| v.as_slice()) .ok_or(DecodeKeyError { token })?, }; ret.extend(token_bytes); @@ -341,12 +342,12 @@ impl CoreBPE { // Separating this from the loop below helps with performance in a common case. let mut point = self .sorted_token_bytes - .partition_point(|x| x.as_slice() < unstable_bytes.as_slice()); + .partition_point(|x| *x < unstable_bytes.as_slice()); while point < self.sorted_token_bytes.len() && self.sorted_token_bytes[point].starts_with(&unstable_bytes) { completions.insert(vec![ - self.encoder[self.sorted_token_bytes[point].as_slice()], + self.encoder[self.sorted_token_bytes[point]], ]); point += 1; } @@ -359,12 +360,12 @@ impl CoreBPE { let suffix = &unstable_bytes[i..]; let mut point = self .sorted_token_bytes - .partition_point(|x| x.as_slice() < suffix); + .partition_point(|x| *x < suffix); // TODO: Perf optimisation if suffix starts with " "? while point < self.sorted_token_bytes.len() && self.sorted_token_bytes[point].starts_with(suffix) { - let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat(); + let possibility = [prefix, self.sorted_token_bytes[point]].concat(); let encoded = match std::str::from_utf8(&possibility) { // Morally, this is byte_pair_encode(&possibility, &self.encoder) // But we might have introduced a regex split which would prevent merges. @@ -447,8 +448,14 @@ impl CoreBPE { .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))? 
}; - let decoder: HashMap<Rank, Vec<u8>> = - encoder.iter().map(|(k, v)| (*v, k.clone())).collect(); + let decoder: HashMap<Rank, &'static [u8]> = encoder + .iter() + .map(|(k, v)| { + let bytes: &[u8] = k.as_slice(); + let static_bytes: &'static [u8] = unsafe { std::mem::transmute(bytes) }; + (*v, static_bytes) + }) + .collect(); assert!( encoder.len() == decoder.len(), @@ -460,8 +467,14 @@ impl CoreBPE { .map(|(k, v)| (*v, k.as_bytes().to_vec())) .collect(); - // Clone because I don't know how to tell Rust I'm not going to change the map - let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect(); + let mut sorted_token_bytes: Vec<&'static [u8]> = encoder + .keys() + .map(|k| { + let bytes: &[u8] = k.as_slice(); + let static_bytes: &'static [u8] = unsafe { std::mem::transmute(bytes) }; + static_bytes + }) + .collect(); sorted_token_bytes.sort(); Ok(CoreBPE {