diff --git a/vamb/__main__.py b/vamb/__main__.py index cf513460..eba75b01 100755 --- a/vamb/__main__.py +++ b/vamb/__main__.py @@ -18,7 +18,6 @@ from torch.utils.data import DataLoader from functools import partial from loguru import logger -from array import array _ncpu = os.cpu_count() DEFAULT_THREADS = 8 if _ncpu is None else min(_ncpu, 8) @@ -1413,7 +1412,7 @@ def predict_taxonomy( predicted_vector[i] > taxonomy_options.softmax_threshold ) ranks = list(nodes_ar[threshold_mask][1:]) - probs = array("f", predicted_vector[i][threshold_mask][1:]) + probs = predicted_vector[i][threshold_mask][1:] tax = vamb.taxonomy.PredictedContigTaxonomy( vamb.taxonomy.ContigTaxonomy(ranks), probs ) diff --git a/vamb/taxonomy.py b/vamb/taxonomy.py index 55cba22b..7e88919d 100644 --- a/vamb/taxonomy.py +++ b/vamb/taxonomy.py @@ -1,7 +1,7 @@ from typing import Optional, IO from pathlib import Path from vamb.parsecontigs import CompositionMetaData -import array +import numpy as np class ContigTaxonomy: @@ -121,7 +121,7 @@ def parse_tax_file( class PredictedContigTaxonomy: slots = ["contig_taxonomy", "probs"] - def __init__(self, tax: ContigTaxonomy, probs: array.array[float]): + def __init__(self, tax: ContigTaxonomy, probs: np.ndarray): if len(probs) != len(tax.ranks): raise ValueError("The length of probs must equal that of ranks") self.tax = tax