diff --git a/tests/test_crf.py b/tests/test_crf.py index f61249c..ce0ac6a 100644 --- a/tests/test_crf.py +++ b/tests/test_crf.py @@ -34,18 +34,23 @@ def compute_score(crf, emission, tag): return score -def make_crf(num_tags=5, batch_first=False): - return CRF(num_tags, batch_first=batch_first) +def make_crf(num_tags=5, batch_first=False, device=None): + model = CRF(num_tags, batch_first=batch_first) + if device is not None: + model = model.to(device) + return model -def make_emissions(crf, seq_length=3, batch_size=2): +def make_emissions(crf, seq_length=3, batch_size=2, device=None): em = torch.randn(seq_length, batch_size, crf.num_tags) if crf.batch_first: em = em.transpose(0, 1) + if device is not None: + em = em.to(device) return em -def make_tags(crf, seq_length=3, batch_size=2): +def make_tags(crf, seq_length=3, batch_size=2, device=None): # shape: (seq_length, batch_size) ts = torch.tensor([[random.randrange(crf.num_tags) for b in range(batch_size)] @@ -53,6 +58,8 @@ def make_tags(crf, seq_length=3, batch_size=2): dtype=torch.long) if crf.batch_first: ts = ts.transpose(0, 1) + if device is not None: + ts = ts.to(device) return ts @@ -91,7 +98,7 @@ def test_works_with_mask(self): # shape: (seq_length, batch_size) tags = make_tags(crf, seq_length, batch_size) # mask should have size of (seq_length, batch_size) - mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.uint8).transpose(0, 1) + mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.bool).transpose(0, 1) # shape: () llh = crf(emissions, tags, mask=mask) @@ -128,7 +135,7 @@ def test_works_without_mask(self): llh_no_mask = crf(emissions, tags) # No mask means the mask is all ones - llh_mask = crf(emissions, tags, mask=torch.ones_like(tags).byte()) + llh_mask = crf(emissions, tags, mask=torch.ones_like(tags).bool()) assert_close(llh_no_mask, llh_mask) @@ -229,7 +236,7 @@ def test_reduction_token_mean(self): # shape: (seq_length, batch_size) tags = make_tags(crf, seq_length, batch_size) # mask should 
have size of (seq_length, batch_size) - mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.uint8).transpose(0, 1) + mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.bool).transpose(0, 1) llh = crf(emissions, tags, mask=mask, reduction='token_mean') @@ -314,7 +321,7 @@ def test_emissions_last_dimension_not_equal_to_number_of_tags(self): def test_first_timestep_mask_is_not_all_on(self): emissions = torch.randn(3, 2, 4) tags = torch.empty(3, 2, dtype=torch.long) - mask = torch.tensor([[1, 1, 1], [0, 0, 0]], dtype=torch.uint8).transpose(0, 1) + mask = torch.tensor([[1, 1, 1], [0, 0, 0]], dtype=torch.bool).transpose(0, 1) crf = make_crf(4) with pytest.raises(ValueError) as excinfo: @@ -340,6 +347,314 @@ def test_invalid_reduction(self): assert 'invalid reduction: foo' in str(excinfo.value) +class TestForwardOnCuda: # CRF forward-pass checks run on cuda:0; each test is skipped when CUDA is unavailable + + def test_works_with_mask(self): + if torch.cuda.is_available(): + device = torch.device('cuda:0') + else: + pytest.skip("No GPU with CUDA to test on.") + crf = make_crf(device=device) + seq_length, batch_size = 3, 2 + + # shape: (seq_length, batch_size, num_tags) + emissions = make_emissions(crf, seq_length, batch_size, device=device) + # shape: (seq_length, batch_size) + tags = make_tags(crf, seq_length, batch_size, device=device) + # mask should have size of (seq_length, batch_size) + mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.bool, device=device).transpose(0, 1) + + # shape: () + llh = crf(emissions, tags, mask=mask) + + # shape: (batch_size, seq_length, num_tags) + emissions = emissions.transpose(0, 1) + # shape: (batch_size, seq_length) + tags = tags.transpose(0, 1) + # shape: (batch_size, seq_length) + mask = mask.transpose(0, 1) + + # Compute log likelihood manually + manual_llh = 0. 
+ for emission, tag, mask_ in zip(emissions, tags, mask): + seq_len = mask_.sum() + emission, tag = emission[:seq_len], tag[:seq_len] + numerator = compute_score(crf, emission, tag) + all_scores = [ + compute_score(crf, emission, t) + for t in itertools.product(range(crf.num_tags), repeat=seq_len) + ] + denominator = math.log(sum(math.exp(s) for s in all_scores)) + manual_llh += numerator - denominator + + assert_close(llh, manual_llh) + llh.backward() # ensure gradients can be computed + + def test_works_without_mask(self): + if torch.cuda.is_available(): + device = torch.device('cuda:0') + else: + pytest.skip("No GPU with CUDA to test on.") + crf = make_crf(device=device) + # shape: (seq_length, batch_size, num_tags) + emissions = make_emissions(crf, device=device) + # shape: (seq_length, batch_size) + tags = make_tags(crf, device=device) + + llh_no_mask = crf(emissions, tags) + # No mask means the mask is all ones + llh_mask = crf(emissions, tags, mask=torch.ones_like(tags).bool()) + + assert_close(llh_no_mask, llh_mask) + + def test_batched_loss(self): + if torch.cuda.is_available(): + device = torch.device('cuda:0') + else: + pytest.skip("No GPU with CUDA to test on.") + crf = make_crf(device=device) + batch_size = 10 + + # shape: (seq_length, batch_size, num_tags) + emissions = make_emissions(crf, batch_size=batch_size, device=device) + # shape: (seq_length, batch_size) + tags = make_tags(crf, batch_size=batch_size, device=device) + + llh = crf(emissions, tags) + assert torch.is_tensor(llh) + assert llh.shape == () + + total_llh = 0. 
+ for i in range(batch_size): + # shape: (seq_length, 1, num_tags) + emissions_ = emissions[:, i, :].unsqueeze(1) + # shape: (seq_length, 1) + tags_ = tags[:, i].unsqueeze(1) + # shape: () + total_llh += crf(emissions_, tags_) + + assert_close(llh, total_llh) + + def test_reduction_none(self): + if torch.cuda.is_available(): + device = torch.device('cuda:0') + else: + pytest.skip("No GPU with CUDA to test on.") + crf = make_crf(device=device) + # shape: (seq_length, batch_size, num_tags) + emissions = make_emissions(crf, device=device) + # shape: (seq_length, batch_size) + tags = make_tags(crf, device=device) + + seq_length, batch_size = tags.shape + + llh = crf(emissions, tags, reduction='none') + + assert torch.is_tensor(llh) + assert llh.shape == (batch_size,) + + # shape: (batch_size, seq_length, num_tags) + emissions = emissions.transpose(0, 1) + # shape: (batch_size, seq_length) + tags = tags.transpose(0, 1) + + # Compute log likelihood manually + manual_llh = [] + for emission, tag in zip(emissions, tags): + numerator = compute_score(crf, emission, tag) + all_scores = [ + compute_score(crf, emission, t) + for t in itertools.product(range(crf.num_tags), repeat=seq_length) + ] + denominator = math.log(sum(math.exp(s) for s in all_scores)) + manual_llh.append(numerator - denominator) + + assert_close(llh, torch.tensor(manual_llh, device=device)) + + def test_reduction_mean(self): + if torch.cuda.is_available(): + device = torch.device('cuda:0') + else: + pytest.skip("No GPU with CUDA to test on.") + crf = make_crf(device=device) + # shape: (seq_length, batch_size, num_tags) + emissions = make_emissions(crf, device=device) + # shape: (seq_length, batch_size) + tags = make_tags(crf, device=device) + + seq_length, batch_size = tags.shape + + llh = crf(emissions, tags, reduction='mean') + + assert torch.is_tensor(llh) + assert llh.shape == () + + # shape: (batch_size, seq_length, num_tags) + emissions = emissions.transpose(0, 1) + # shape: 
(batch_size, seq_length) + tags = tags.transpose(0, 1) + + # Compute log likelihood manually + manual_llh = 0 + for emission, tag in zip(emissions, tags): + numerator = compute_score(crf, emission, tag) + all_scores = [ + compute_score(crf, emission, t) + for t in itertools.product(range(crf.num_tags), repeat=seq_length) + ] + denominator = math.log(sum(math.exp(s) for s in all_scores)) + manual_llh += numerator - denominator + + assert_close(llh, manual_llh / batch_size) + + def test_reduction_token_mean(self): + if torch.cuda.is_available(): + device = torch.device('cuda:0') + else: + pytest.skip("No GPU with CUDA to test on.") + crf = make_crf(device=device) + seq_length, batch_size = 3, 2 + + # shape: (seq_length, batch_size, num_tags) + emissions = make_emissions(crf, seq_length, batch_size, device=device) + # shape: (seq_length, batch_size) + tags = make_tags(crf, seq_length, batch_size, device=device) + # mask should have size of (seq_length, batch_size) + mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.bool, device=device).transpose(0, 1) + + llh = crf(emissions, tags, mask=mask, reduction='token_mean') + + assert torch.is_tensor(llh) + assert llh.shape == () + + # shape: (batch_size, seq_length, num_tags) + emissions = emissions.transpose(0, 1) + # shape: (batch_size, seq_length) + tags = tags.transpose(0, 1) + # shape: (batch_size, seq_length) + mask = mask.transpose(0, 1) + + # Compute log likelihood manually + manual_llh, n_tokens = 0, 0 + for emission, tag, mask_ in zip(emissions, tags, mask): + seq_len = mask_.sum() + emission, tag = emission[:seq_len], tag[:seq_len] + numerator = compute_score(crf, emission, tag) + all_scores = [ + compute_score(crf, emission, t) + for t in itertools.product(range(crf.num_tags), repeat=seq_len) + ] + denominator = math.log(sum(math.exp(s) for s in all_scores)) + manual_llh += numerator - denominator + n_tokens += seq_len + + assert_close(llh, manual_llh / n_tokens) + + def test_batch_first(self): + if 
torch.cuda.is_available(): + device = torch.device('cuda:0') + else: + pytest.skip("No GPU with CUDA to test on.") + crf = make_crf(device=device) + # shape: (seq_length, batch_size, num_tags) + emissions = make_emissions(crf, device=device) + # shape: (seq_length, batch_size) + tags = make_tags(crf, device=device) + llh = crf(emissions, tags) + + crf_bf = make_crf(batch_first=True, device=device) + # Copy parameter values from non-batch-first CRF; requires_grad must be False + # to avoid runtime error of in-place operation on a leaf variable + crf_bf.start_transitions.requires_grad_(False).copy_(crf.start_transitions) + crf_bf.end_transitions.requires_grad_(False).copy_(crf.end_transitions) + crf_bf.transitions.requires_grad_(False).copy_(crf.transitions) + + # shape: (batch_size, seq_length, num_tags) + emissions = emissions.transpose(0, 1) + # shape: (batch_size, seq_length) + tags = tags.transpose(0, 1) + llh_bf = crf_bf(emissions, tags) + + assert_close(llh, llh_bf) + + def test_emissions_has_bad_number_of_dimension(self): + if torch.cuda.is_available(): + device = torch.device('cuda:0') + else: + pytest.skip("No GPU with CUDA to test on.") + emissions = torch.randn(1, 2).to(device) + tags = torch.empty(2, 2, dtype=torch.long, device=device) + crf = make_crf(device=device) + + with pytest.raises(ValueError) as excinfo: + crf(emissions, tags) + assert 'emissions must have dimension of 3, got 2' in str(excinfo.value) + + def test_emissions_and_tags_size_mismatch(self): + if torch.cuda.is_available(): + device = torch.device('cuda:0') + else: + pytest.skip("No GPU with CUDA to test on.") + emissions = torch.randn(1, 2, 3).to(device) + tags = torch.empty(2, 2, dtype=torch.long, device=device) + crf = make_crf(3, device=device) + + with pytest.raises(ValueError) as excinfo: + crf(emissions, tags) + assert ( + 'the first two dimensions of emissions and tags must match, ' + 'got (1, 2) and (2, 2)') in str(excinfo.value) + + def 
test_emissions_last_dimension_not_equal_to_number_of_tags(self): + if torch.cuda.is_available(): + device = torch.device('cuda:0') + else: + pytest.skip("No GPU with CUDA to test on.") + emissions = torch.randn(1, 2, 3).to(device) + tags = torch.empty(1, 2, dtype=torch.long, device=device) + crf = make_crf(10, device=device) + + with pytest.raises(ValueError) as excinfo: + crf(emissions, tags) + assert 'expected last dimension of emissions is 10, got 3' in str(excinfo.value) + + def test_first_timestep_mask_is_not_all_on(self): + if torch.cuda.is_available(): + device = torch.device('cuda:0') + else: + pytest.skip("No GPU with CUDA to test on.") + emissions = torch.randn(3, 2, 4).to(device) + tags = torch.empty(3, 2, dtype=torch.long, device=device) + mask = torch.tensor([[1, 1, 1], [0, 0, 0]], dtype=torch.bool, device=device).transpose(0, 1) + crf = make_crf(4, device=device) + + with pytest.raises(ValueError) as excinfo: + crf(emissions, tags, mask=mask) + assert 'mask of the first timestep must all be on' in str(excinfo.value) + + emissions = emissions.transpose(0, 1) + tags = tags.transpose(0, 1) + mask = mask.transpose(0, 1) + crf = make_crf(4, batch_first=True, device=device) + + with pytest.raises(ValueError) as excinfo: + crf(emissions, tags, mask=mask) + assert 'mask of the first timestep must all be on' in str(excinfo.value) + + def test_invalid_reduction(self): + if torch.cuda.is_available(): + device = torch.device('cuda:0') + else: + pytest.skip("No GPU with CUDA to test on.") + crf = make_crf(device=device) + emissions = make_emissions(crf, device=device) + tags = make_tags(crf, device=device) + + with pytest.raises(ValueError) as excinfo: + crf(emissions, tags, reduction='foo') + assert 'invalid reduction: foo' in str(excinfo.value) + + class TestDecode: def test_works_with_mask(self): crf = make_crf() @@ -348,7 +663,7 @@ def test_works_with_mask(self): # shape: (seq_length, batch_size, num_tags) emissions = make_emissions(crf, seq_length, 
batch_size) # mask should be (seq_length, batch_size) - mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.uint8).transpose(0, 1) + mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.bool).transpose(0, 1) best_tags = crf.decode(emissions, mask=mask) @@ -376,7 +691,7 @@ def test_works_without_mask(self): best_tags_no_mask = crf.decode(emissions) # No mask means mask is all ones best_tags_mask = crf.decode( - emissions, mask=emissions.new_ones(emissions.shape[:2]).byte()) + emissions, mask=emissions.new_ones(emissions.shape[:2]).bool()) assert best_tags_no_mask == best_tags_mask @@ -387,7 +702,7 @@ def test_batched_decode(self): # shape: (seq_length, batch_size, num_tags) emissions = make_emissions(crf, seq_length, batch_size) # shape: (seq_length, batch_size) - mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.uint8).transpose(0, 1) + mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.bool).transpose(0, 1) batched = crf.decode(emissions, mask=mask) @@ -441,7 +756,7 @@ def test_emissions_last_dimension_not_equal_to_number_of_tags(self): def test_emissions_and_mask_size_mismatch(self): emissions = torch.randn(1, 2, 3) - mask = torch.tensor([[1, 1], [1, 0]], dtype=torch.uint8) + mask = torch.tensor([[1, 1], [1, 0]], dtype=torch.bool) crf = make_crf(3) with pytest.raises(ValueError) as excinfo: @@ -452,7 +767,7 @@ def test_emissions_and_mask_size_mismatch(self): def test_first_timestep_mask_is_not_all_on(self): emissions = torch.randn(3, 2, 4) - mask = torch.tensor([[1, 1, 1], [0, 0, 0]], dtype=torch.uint8).transpose(0, 1) + mask = torch.tensor([[1, 1, 1], [0, 0, 0]], dtype=torch.bool).transpose(0, 1) crf = make_crf(4) with pytest.raises(ValueError) as excinfo: @@ -498,7 +813,7 @@ def test_default_forward(self): # shape: (seq_length, batch_size) tags = make_tags(crf, seq_length, batch_size) # mask should have size of (seq_length, batch_size) - mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.uint8).transpose(0, 1) + mask = torch.tensor([[1, 
1, 1], [1, 1, 0]], dtype=torch.bool).transpose(0, 1) llh = crf(emissions, tags, mask=mask) llh_scripted = crf_script(emissions, tags, mask=mask) assert_close(llh_scripted, llh) @@ -526,8 +841,8 @@ def test_all_ones_mask(self): tags = make_tags(crf, seq_length, batch_size) # No mask means the mask is all ones - llh_mask = crf(emissions, tags, mask=torch.ones_like(tags).byte()) - llh_mask_script = crf_script(emissions, tags, mask=torch.ones_like(tags).byte()) + llh_mask = crf(emissions, tags, mask=torch.ones_like(tags).bool()) + llh_mask_script = crf_script(emissions, tags, mask=torch.ones_like(tags).bool()) assert_close(llh_mask_script, llh_mask) def test_batched_forward(self): @@ -578,7 +893,7 @@ def test_reduction_token_mean(self): # shape: (seq_length, batch_size) tags = make_tags(crf) - mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.uint8).transpose(0, 1) + mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.bool).transpose(0, 1) llh = crf(emissions, tags, mask=mask, reduction='token_mean') llh_script = crf_script(emissions, tags, mask=mask, reduction='token_mean') assert_close(llh_script, llh) @@ -620,7 +935,7 @@ def test_with_mask(self): # shape: (seq_length, batch_size, num_tags) emissions = make_emissions(crf, seq_length, batch_size) # mask should be (seq_length, batch_size) - mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.uint8).transpose(0, 1) + mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.bool).transpose(0, 1) best_tags = crf.decode(emissions, mask=mask) best_tags_scripted = crf_script.decode(emissions, mask=mask) assert best_tags == best_tags_scripted diff --git a/torchcrf/__init__.py b/torchcrf/__init__.py index 6ff18e4..af919fd 100644 --- a/torchcrf/__init__.py +++ b/torchcrf/__init__.py @@ -64,7 +64,7 @@ def forward( self, emissions: torch.Tensor, tags: torch.LongTensor, - mask: Optional[torch.ByteTensor] = None, + mask: Optional[torch.BoolTensor] = None, reduction: str = 'sum', ) -> torch.Tensor: """Compute the 
conditional log likelihood of a sequence of tags given emission scores. @@ -76,8 +76,9 @@ def forward( tags (`~torch.LongTensor`): Sequence of tags tensor of size ``(seq_length, batch_size)`` if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. - mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` + mask (`~torch.BoolTensor`): Mask tensor of size ``(seq_length, batch_size)`` if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. + The True values are unmasked, the False values are masked. reduction: Specifies the reduction to apply to the output: ``none|sum|mean|token_mean``. ``none``: no reduction will be applied. ``sum``: the output will be summed over batches. ``mean``: the output will be @@ -91,7 +92,7 @@ def forward( if reduction not in ('none', 'sum', 'mean', 'token_mean'): raise ValueError(f'invalid reduction: {reduction}') if mask is None: - mask = torch.ones_like(tags, dtype=torch.uint8) + mask = torch.ones_like(tags, dtype=torch.bool) if self.batch_first: emissions = emissions.transpose(0, 1) @@ -131,7 +132,7 @@ def decode(self, emissions: torch.Tensor, """ self._validate(emissions, mask=mask) if mask is None: - mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8) + mask = emissions.new_ones(emissions.shape[:2], dtype=torch.bool) if self.batch_first: emissions = emissions.transpose(0, 1) @@ -143,7 +144,7 @@ def _validate( self, emissions: torch.Tensor, tags: Optional[torch.LongTensor] = None, - mask: Optional[torch.ByteTensor] = None) -> None: + mask: Optional[torch.BoolTensor] = None) -> None: if emissions.dim() != 3: raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}') if emissions.size(2) != self.num_tags: @@ -171,7 +172,7 @@ def _validate( def _compute_score( self, emissions: torch.Tensor, tags: torch.LongTensor, - mask: torch.ByteTensor) -> torch.Tensor: + mask: torch.BoolTensor) -> torch.Tensor: # emissions: (seq_length, batch_size, num_tags) # 
tags: (seq_length, batch_size) # mask: (seq_length, batch_size) @@ -182,34 +183,37 @@ def _compute_score( assert mask[0].all() seq_length, batch_size = tags.shape - mask = mask.type_as(emissions) + + # Index helper for picking one entry per batch element, on the same device as tags + batch_range = torch.arange(batch_size, device=tags.device) - # Start transition score and first emission + # Start transition score # shape: (batch_size,) score = self.start_transitions[tags[0]] - score += emissions[0, torch.arange(batch_size), tags[0]] - - for i in range(1, seq_length): - # Transition score to next tag, only added if next timestep is valid (mask == 1) - # shape: (batch_size,) - score += self.transitions[tags[i - 1], tags[i]] * mask[i] - - # Emission score for next tag, only added if next timestep is valid (mask == 1) - # shape: (batch_size,) - score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i] # End transition score # shape: (batch_size,) - seq_ends = mask.long().sum(dim=0) - 1 + seq_ends = mask.sum(dim=0) - 1 # shape: (batch_size,) - last_tags = tags[seq_ends, torch.arange(batch_size)] + last_tags = tags[seq_ends, batch_range] # shape: (batch_size,) score += self.end_transitions[last_tags] + # Emission score of every gold tag; shape: (seq_length, batch_size) + emission_scores = emissions.gather(2, tags.unsqueeze(2)).squeeze(2) + + # Transition into every timestep after the first; slicing (unlike unfold) also handles seq_length == 1 + # shape: (seq_length - 1, batch_size) + transition_scores = self.transitions[tags[:-1], tags[1:]] + + # A masked timestep drops both its emission and its incoming transition + emission_scores[1:, :] += transition_scores + score = score + (emission_scores * mask).sum(0) + return score def _compute_normalizer( - self, emissions: torch.Tensor, mask: torch.ByteTensor) -> torch.Tensor: + self, emissions: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor: # emissions: (seq_length, batch_size, num_tags) # mask: (seq_length, batch_size) assert emissions.dim() == 
3 and mask.dim() == 2 @@ -260,7 +264,7 @@ def _compute_normalizer( return torch.logsumexp(score, dim=1) def _viterbi_decode(self, emissions: torch.FloatTensor, - mask: torch.ByteTensor) -> List[List[int]]: + mask: torch.BoolTensor) -> List[List[int]]: # emissions: (seq_length, batch_size, num_tags) # mask: (seq_length, batch_size) assert emissions.dim() == 3 and mask.dim() == 2 @@ -315,7 +319,7 @@ def _viterbi_decode(self, emissions: torch.FloatTensor, # Now, compute the best path for each sample # shape: (batch_size,) - seq_ends = mask.long().sum(dim=0) - 1 + seq_ends = mask.sum(dim=0) - 1 best_tags_list: List[List[int]] = [] for idx in range(batch_size):