Skip to content

Commit

Permalink
fix: cuda utilization
Browse files Browse the repository at this point in the history
  • Loading branch information
sgalkina authored and jakobnissen committed Sep 23, 2024
1 parent bdd14d1 commit 852ee3c
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 16 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/cli_vamb.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ jobs:
cat outdir_vamb/log.txt
- name: Run TaxVAMB
run: |
vamb bin taxvamb --outdir outdir_taxvamb --fasta catalogue_mock.fna.gz --abundance abundance_mock.npz --taxonomy taxonomy_mock.tsv -pe 10 -pt 10 -e 10 -q -pq -t 10 -o C --minfasta 200000
vamb bin taxvamb --outdir outdir_taxvamb --fasta catalogue_mock.fna.gz --abundance abundance_mock.npz --taxonomy taxonomy_mock.tsv -pe 10 -pt 10 -e 10 -q -t 10 -o C --minfasta 200000
ls -la outdir_taxvamb
cat outdir_taxvamb/log.txt
vamb bin taxvamb --outdir outdir_taxvamb_no_predict --no_predictor --fasta catalogue_mock.fna.gz --abundance abundance_mock.npz --taxonomy taxonomy_mock.tsv -e 10 -q -t 10 -o C --minfasta 200000
Expand All @@ -53,10 +53,10 @@ jobs:
cat outdir_taxvamb_preds/log.txt
- name: Run Taxometer
run: |
vamb taxometer --outdir outdir_taxometer --fasta catalogue_mock.fna.gz --abundance abundance_mock.npz --taxonomy taxonomy_mock.tsv -pe 10 -pq -pt 10
vamb taxometer --outdir outdir_taxometer --fasta catalogue_mock.fna.gz --abundance abundance_mock.npz --taxonomy taxonomy_mock.tsv -pe 10 -pt 10
ls -la outdir_taxometer
cat outdir_taxometer/log.txt
vamb taxometer --outdir outdir_taxometer_pred --fasta catalogue_mock.fna.gz --abundance abundance_mock.npz --taxonomy outdir_taxometer/results_taxometer.tsv -pe 10 -pq -pt 10
vamb taxometer --outdir outdir_taxometer_pred --fasta catalogue_mock.fna.gz --abundance abundance_mock.npz --taxonomy outdir_taxometer/results_taxometer.tsv -pe 10 -pt 10
ls -la outdir_taxometer_pred
cat outdir_taxometer/log.txt
- name: Run k-means reclustering
Expand Down
11 changes: 1 addition & 10 deletions vamb/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def from_args_taxometer(cls, args: argparse.Namespace):
return cls(
typeasserted(args.pred_nepochs, int),
typeasserted(args.pred_batchsize, int),
typeasserted(args.pred_batchsteps, list),
[],
)

def __init__(
Expand Down Expand Up @@ -1857,15 +1857,6 @@ def add_predictor_arguments(subparser):
default=1024,
help=argparse.SUPPRESS,
)
pred_trainos.add_argument(
"-pq",
dest="pred_batchsteps",
metavar="",
type=int,
nargs="*",
default=[],
help=argparse.SUPPRESS,
)
pred_trainos.add_argument(
"-pthr",
dest="pred_softmax_threshold",
Expand Down
6 changes: 3 additions & 3 deletions vamb/taxvamb_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,9 +824,6 @@ def __init__(
self.tree = _hloss.Hierarchy(table_parent)
self.n_tree_nodes = nlabels

if cuda:
self.cuda()

self.nodes = nodes
self.table_parent = table_parent
self.hierloss = init_hier_loss(hier_loss, self.tree)
Expand Down Expand Up @@ -855,6 +852,9 @@ def __init__(
self.relu = _nn.LeakyReLU()
self.dropoutlayer = _nn.Dropout(p=self.dropout)

if cuda:
self.cuda()

def _predict(self, tensor: Tensor) -> tuple[Tensor, Tensor]:
tensors: list[_torch.Tensor] = list()

Expand Down

0 comments on commit 852ee3c

Please sign in to comment.