diff --git a/tests/test_bcftools_validation.py b/tests/test_bcftools_validation.py index 9ad3f8b..d8813d6 100644 --- a/tests/test_bcftools_validation.py +++ b/tests/test_bcftools_validation.py @@ -131,10 +131,10 @@ def test_vcf_output_with_output_option(tmp_path, args, vcf_file): (r"query -f '%POS\n' -e 'POS=112'", "sample.vcf.gz"), (r"query -f '[%CHROM\t]\n'", "sample.vcf.gz"), (r"query -f '[%CHROM\t]\n' -i 'POS=112'", "sample.vcf.gz"), - (r"query -f '%CHROM\t%POS\t%REF\t%ALT[\t%GT]\n'", "sample.vcf.gz"), + (r"query -f '%CHROM\t%POS\t%REF\t%ALT[\t%SAMPLE=%GT]\n'", "sample.vcf.gz"), (r"query -f 'GQ:[ %GQ] \t GT:[ %GT]\n'", "sample.vcf.gz"), - (r"query -f '[%CHROM:%POS %GT\n]'", "sample.vcf.gz"), - (r"query -f '[%GT %DP\n]'", "sample.vcf.gz"), + (r"query -f '[%CHROM:%POS %SAMPLE %GT\n]'", "sample.vcf.gz"), + (r"query -f '[%SAMPLE %GT %DP\n]'", "sample.vcf.gz"), ], ) def test_output(tmp_path, args, vcf_name): diff --git a/vcztools/query.py b/vcztools/query.py index 9bf3613..fd33c4a 100644 --- a/vcztools/query.py +++ b/vcztools/query.py @@ -120,6 +120,14 @@ def stringify(gt_and_phase: tuple): return generate + def _compose_sample_ids_generator(self) -> Callable: + def generate(root): + variant_count = root["variant_position"].shape[0] + sample_ids = root["sample_id"][:].tolist() + yield from itertools.repeat(sample_ids, variant_count) + + return generate + def _compose_tag_generator( self, tag: str, *, subfield=False, sample_loop=False ) -> Callable: @@ -129,6 +137,9 @@ def _compose_tag_generator( if tag == "GT": return self._compose_gt_generator() + if tag == "SAMPLE": + return self._compose_sample_ids_generator() + def generate(root): vcz_names = set(name for name, _zarray in root.items()) vcz_name = vcf_name_to_vcz_name(vcz_names, tag)