Patch summary (free text ahead of the first "diff --git" header):

pysam/libcbcf.pyx
    bcf_genotype_count() no longer calls bcf_get_genotypes() and frees the
    returned array; it obtains the ploidy from the new bcf_get_ploidy()
    helper and passes it to bcf_geno_combinations().
    bcf_get_ploidy() unpacks the record, returns 0 when the sample index is
    out of range, the record has no FORMAT fields, or the first FORMAT field
    is not GT / is empty, and otherwise counts the values stored for the
    sample up to the first vector-end sentinel of the matching integer width.

tests/VariantRecordPL_bench.py
    New pytest benchmarks that read the PL field of every sample from a
    one-record VCF written with 10, 100, 1000, 10000 and 100000 samples.

diff --git a/pysam/libcbcf.pyx b/pysam/libcbcf.pyx
index 8ecfe5f3..c84ef90d 100644
--- a/pysam/libcbcf.pyx
+++ b/pysam/libcbcf.pyx
@@ -239,29 +239,57 @@ cdef inline int is_gt_fmt(bcf_hdr_t *hdr, int fmt_id):
 
 
 cdef inline int bcf_genotype_count(bcf_hdr_t *hdr, bcf1_t *rec, int sample) except -1:
-
     if sample < 0:
         raise ValueError('genotype is only valid as a format field')
 
-    cdef int32_t *gt_arr = NULL
-    cdef int ngt = 0
-    ngt = bcf_get_genotypes(hdr, rec, &gt_arr, &ngt)
+    cdef int ploidy = bcf_get_ploidy(hdr, rec, sample)
+    return bcf_geno_combinations(ploidy, rec.n_allele)
+
+cdef inline int bcf_get_ploidy(bcf_hdr_t *hdr, bcf1_t *rec, int sample) except -1:
+    # compute sample ploidy by counting number of values in GT field
+    cdef int32_t n = rec.n_sample
+
+    if bcf_unpack(rec, BCF_UN_ALL) < 0:
+        raise ValueError('Error unpacking VariantRecord')
 
-    if ngt <= 0 or not gt_arr:
+    if sample < 0 or sample >= n or not rec.n_fmt:
         return 0
 
-    assert ngt % rec.n_sample == 0
-    cdef int max_ploidy = ngt // rec.n_sample
-    cdef int32_t *gt = gt_arr + sample * max_ploidy
-    cdef int ploidy = 0
+    cdef bcf_fmt_t *fmt0 = rec.d.fmt
+    cdef int gt0 = is_gt_fmt(hdr, fmt0.id)
 
-    while ploidy < max_ploidy and gt[0] != bcf_int32_vector_end:
-        gt += 1
-        ploidy += 1
+    if not gt0 or not fmt0.n:
+        return 0
 
-    free(gt_arr)
+    cdef int ploidy = 0
+    cdef int8_t *data8
+    cdef int16_t *data16
+    cdef int32_t *data32
+    cdef int32_t a, nalleles = rec.n_allele
 
-    return bcf_geno_combinations(ploidy, rec.n_allele)
+    if fmt0.type == BCF_BT_INT8:
+        data8 = <int8_t *>(fmt0.p + sample * fmt0.size)
+        for i in range(fmt0.n):
+            if data8[i] == bcf_int8_vector_end:
+                break
+            else:
+                ploidy += 1
+    elif fmt0.type == BCF_BT_INT16:
+        data16 = <int16_t *>(fmt0.p + sample * fmt0.size)
+        for i in range(fmt0.n):
+            if data16[i] == bcf_int16_vector_end:
+                break
+            else:
+                ploidy += 1
+    elif fmt0.type == BCF_BT_INT32:
+        data32 = <int32_t *>(fmt0.p + sample * fmt0.size)
+        for i in range(fmt0.n):
+            if data32[i] == bcf_int32_vector_end:
+                break
+            else:
+                ploidy += 1
+
+    return ploidy
 
 
 cdef tuple char_array_to_tuple(const char **a, ssize_t n, int free_after=0):
diff --git a/tests/VariantRecordPL_bench.py b/tests/VariantRecordPL_bench.py
new file mode 100644
index 00000000..65188daa
--- /dev/null
+++ b/tests/VariantRecordPL_bench.py
@@ -0,0 +1,74 @@
+import pytest
+
+from pysam import VariantFile, VariantHeader
+
+
+@pytest.mark.benchmark(min_rounds=1)
+def test_access_pl_values_10_samples(benchmark, vcf_with_n_samples):
+    vcf = vcf_with_n_samples(10)
+    result = benchmark(access_pl_values, vcf)
+    assert result == (0, 1, 2)
+
+
+@pytest.mark.benchmark(min_rounds=1)
+def test_access_pl_values_100_samples(benchmark, vcf_with_n_samples):
+    vcf = vcf_with_n_samples(100)
+    result = benchmark(access_pl_values, vcf)
+    assert result == (0, 1, 2)
+
+
+@pytest.mark.benchmark(min_rounds=1)
+def test_access_pl_values_1000_samples(benchmark, vcf_with_n_samples):
+    vcf = vcf_with_n_samples(1000)
+    result = benchmark(access_pl_values, vcf)
+    assert result == (0, 1, 2)
+
+
+@pytest.mark.benchmark(min_rounds=1)
+def test_access_pl_values_10000_samples(benchmark, vcf_with_n_samples):
+    vcf = vcf_with_n_samples(10000)
+    result = benchmark(access_pl_values, vcf)
+    assert result == (0, 1, 2)
+
+
+@pytest.mark.benchmark(min_rounds=1)
+def test_access_pl_values_100000_samples(benchmark, vcf_with_n_samples):
+    vcf = vcf_with_n_samples(100000)
+    result = benchmark(access_pl_values, vcf)
+    assert result == (0, 1, 2)
+
+
+def access_pl_values(data):
+    pl = None
+    with VariantFile(data) as vcf:
+        for record in vcf:
+            for sample in record.samples.values():
+                pl = sample["PL"]
+    return pl
+
+
+@pytest.fixture()
+def vcf_with_n_samples(tmp_path):
+    def vcf_with_n_samples(n_samples):
+        vcfh = VariantHeader()
+        for s in (f"s{i}" for i in range(n_samples)):
+            vcfh.add_sample(s)
+        vcfh.add_meta("contig", items=[("ID", "chr20")])
+        vcfh.add_meta(
+            "FORMAT",
+            items=dict(ID="PL", Number="G", Type="Integer", Description="Phred Scaled Likelihood").items(),
+        )
+        vcfh.add_meta(
+            "FORMAT",
+            items=dict(ID="GT", Number="1", Type="String", Description="True Genotype").items(),
+        )
+
+        vcf = tmp_path / "large.vcf.gz"
+        with VariantFile(vcf, "w", header=vcfh) as vcf_out:
+            r = vcf_out.new_record(contig="chr20", start=1, stop=1, alleles=["C", "T"])
+            for n, samp in enumerate(r.samples.values()):
+                samp["GT"] = (0, 0)
+                samp["PL"] = [0, 1, 2]
+            vcf_out.write(r)
+        return vcf
+    return vcf_with_n_samples