Commit

gptbigcode forward type fixes
Signed-off-by: Davis Wertheimer <[email protected]>
daviswer committed Oct 10, 2024
1 parent 3adfb7d commit cf93f60
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions speculator/train_speculator_utils.py
@@ -1,7 +1,7 @@
import os
import re
import time
-from typing import Any, Callable, Mapping, MutableMapping, Optional, Tuple, Union
+from typing import Any, Callable, List, MutableMapping, Optional, Tuple, Union

import torch
import torch.distributed as dist
@@ -437,11 +437,12 @@ class EmbedGPTBigCode(GPTBigCode):
    # Overrides the forward function of GPTBigCode to allow returning embedding vectors
    def forward(
        self,
-        x: torch.LongTensor,
+        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value_states: Optional[Tuple[torch.FloatTensor,]] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        past_key_value_states: Optional[List[Tuple[torch.Tensor,]]] = None,
        use_cache: bool = False,
        only_last_token: bool = False,
        attn_algorithm: Optional[str] = None,
        include_embeds: bool = False,
    ):
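For context, a minimal sketch of the kind of call site the relaxed annotations describe. `greedy_steps` is a hypothetical helper, the `model` argument is assumed to be an EmbedGPTBigCode instance, and the (logits, cache) return shape under use_cache=True is an assumption for illustration, not something this commit specifies.

# Illustrative sketch only, not part of this commit: a call site exercising the
# loosened annotations. The (logits, cache) return shape under use_cache=True is
# an assumption made for illustration.
import torch


def greedy_steps(model, x: torch.Tensor, steps: int = 2) -> torch.Tensor:
    # `x` holds token ids as a plain torch.Tensor, matching the relaxed
    # annotation (torch.Tensor rather than torch.LongTensor).
    cache = None  # matches Optional[List[Tuple[torch.Tensor,]]]
    for _ in range(steps):
        logits, cache = model(x, past_key_value_states=cache, use_cache=True)
        x = logits[:, -1:].argmax(dim=-1)  # only the newest token goes back in
    return x

Typing the cache as a List of per-layer tensor tuples matches a cache that is built up layer by layer, which is presumably why the unused Mapping import gives way to List in the typing line above.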
