diff --git a/mteb/evaluation/evaluators/RetrievalEvaluator.py b/mteb/evaluation/evaluators/RetrievalEvaluator.py
index 70f26e2236..ed3a50d71f 100644
--- a/mteb/evaluation/evaluators/RetrievalEvaluator.py
+++ b/mteb/evaluation/evaluators/RetrievalEvaluator.py
@@ -307,15 +307,17 @@ def search_cross_encoder(
             assert (
                 len(queries_in_pair) == len(corpus_in_pair) == len(instructions_in_pair)
             )
+            corpus_in_pair = corpus_to_str(list(corpus_in_pair))
 
             if hasattr(self.model, "model") and isinstance(
                 self.model.model, CrossEncoder
             ):
                 # can't take instructions, so add them here
-                queries_in_pair = [
-                    f"{q} {i}".strip()
-                    for i, q in zip(instructions_in_pair, queries_in_pair)
-                ]
+                if instructions_in_pair[0] is not None:
+                    queries_in_pair = [
+                        f"{q} {i}".strip()
+                        for i, q in zip(instructions_in_pair, queries_in_pair)
+                    ]
                 scores = self.model.predict(list(zip(queries_in_pair, corpus_in_pair)))  # type: ignore
             else:
                 # may use the instructions in a unique way, so give them also
diff --git a/mteb/models/rerankers_custom.py b/mteb/models/rerankers_custom.py
index 40977f1e04..e8bb483a3d 100644
--- a/mteb/models/rerankers_custom.py
+++ b/mteb/models/rerankers_custom.py
@@ -22,6 +22,7 @@ def __init__(
         batch_size: int = 4,
         fp_options: bool = None,
         silent: bool = False,
+        **kwargs,
     ):
         self.model_name_or_path = model_name_or_path
         self.batch_size = batch_size
@@ -34,7 +35,7 @@ def __init__(
             self.fp_options = torch.float32
         elif self.fp_options == "bfloat16":
             self.fp_options = torch.bfloat16
-        print(f"Using fp_options of {self.fp_options}")
+        logger.info(f"Using fp_options of {self.fp_options}")
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.silent = silent
         self.first_print = True  # for debugging
@@ -70,7 +71,12 @@ def __init__(
 
     @torch.inference_mode()
     def predict(self, input_to_rerank, **kwargs):
-        queries, passages, instructions = list(zip(*input_to_rerank))
+        inputs = list(zip(*input_to_rerank))
+        if len(input_to_rerank[0]) == 2:
+            queries, passages = inputs
+            instructions = None
+        else:
+            queries, passages, instructions = inputs
         if instructions is not None and instructions[0] is not None:
             assert len(instructions) == len(queries)
             queries = [f"{q} {i}".strip() for i, q in zip(instructions, queries)]
@@ -112,7 +118,13 @@ def __init__(
 
     @torch.inference_mode()
     def predict(self, input_to_rerank, **kwargs):
-        queries, passages, instructions = list(zip(*input_to_rerank))
+        inputs = list(zip(*input_to_rerank))
+        if len(input_to_rerank[0]) == 2:
+            queries, passages = inputs
+            instructions = None
+        else:
+            queries, passages, instructions = inputs
+
         if instructions is not None and instructions[0] is not None:
             queries = [f"{q} {i}".strip() for i, q in zip(instructions, queries)]
 
@@ -152,7 +164,13 @@ def __init__(
         )
 
     def predict(self, input_to_rerank, **kwargs):
-        queries, passages, instructions = list(zip(*input_to_rerank))
+        inputs = list(zip(*input_to_rerank))
+        if len(input_to_rerank[0]) == 2:
+            queries, passages = inputs
+            instructions = None
+        else:
+            queries, passages, instructions = inputs
+
         if instructions is not None and instructions[0] is not None:
             queries = [f"{q} {i}".strip() for i, q in zip(instructions, queries)]
 
@@ -179,7 +197,7 @@ def loader_inner(**kwargs: Any) -> Encoder:
         _loader,
         wrapper=MonoBERTReranker,
         model_name_or_path="castorini/monobert-large-msmarco",
-        fp_options="float1616",
+        fp_options="float16",
     ),
     name="castorini/monobert-large-msmarco",
     languages=["eng_Latn"],
@@ -194,7 +212,7 @@ def loader_inner(**kwargs: Any) -> Encoder:
         _loader,
         wrapper=JinaReranker,
         model_name_or_path="jinaai/jina-reranker-v2-base-multilingual",
-        fp_options="float1616",
+        fp_options="float16",
     ),
     name="jinaai/jina-reranker-v2-base-multilingual",
     languages=["eng_Latn"],
@@ -208,7 +226,7 @@ def loader_inner(**kwargs: Any) -> Encoder:
         _loader,
         wrapper=BGEReranker,
         model_name_or_path="BAAI/bge-reranker-v2-m3",
-        fp_options="float1616",
+        fp_options="float16",
     ),
     name="BAAI/bge-reranker-v2-m3",
     languages=[
diff --git a/mteb/models/rerankers_monot5_based.py b/mteb/models/rerankers_monot5_based.py
index 7ece40e3cf..d72a893406 100644
--- a/mteb/models/rerankers_monot5_based.py
+++ b/mteb/models/rerankers_monot5_based.py
@@ -105,7 +105,12 @@ def get_prediction_tokens(
 
     @torch.inference_mode()
     def predict(self, input_to_rerank, **kwargs):
-        queries, passages, instructions = list(zip(*input_to_rerank))
+        inputs = list(zip(*input_to_rerank))
+        if len(input_to_rerank[0]) == 2:
+            queries, passages = inputs
+            instructions = None
+        else:
+            queries, passages, instructions = inputs
         if instructions is not None and instructions[0] is not None:
             queries = [f"{q} {i}".strip() for i, q in zip(instructions, queries)]
 
@@ -194,7 +199,13 @@ def __init__(
 
     @torch.inference_mode()
     def predict(self, input_to_rerank, **kwargs):
-        queries, passages, instructions = list(zip(*input_to_rerank))
+        inputs = list(zip(*input_to_rerank))
+        if len(input_to_rerank[0]) == 2:
+            queries, passages = inputs
+            instructions = None
+        else:
+            queries, passages, instructions = inputs
+
         if instructions is not None and instructions[0] is not None:
             # logger.info(f"Adding instructions to LLAMA queries")
             queries = [