Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

In-code documentation update #260

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions tiktoken/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,19 @@


def read_file(blobpath: str) -> bytes:
"""
Reads the contents of a file specified by the given blobpath.

Parameters
----------
blobpath : str
The path or URL to the file to be read.

Returns
-------
bytes
The binary content of the file.
"""
if not blobpath.startswith("http://") and not blobpath.startswith("https://"):
try:
import blobfile
Expand All @@ -28,11 +41,44 @@ def read_file(blobpath: str) -> bytes:


def check_hash(data: bytes, expected_hash: str) -> bool:
    """
    Verify that the SHA-256 digest of the given data matches an expected value.

    Parameters
    ----------
    data : bytes
        The binary data whose integrity should be verified.

    expected_hash : str
        The hex-encoded SHA-256 digest the data is expected to have.

    Returns
    -------
    bool
        True if the computed digest equals ``expected_hash``, False otherwise.
    """
    return hashlib.sha256(data).hexdigest() == expected_hash


def read_file_cached(blobpath: str, expected_hash: Optional[str] = None) -> bytes:
"""
Reads the contents of a file specified by the given blobpath from cache if available,
otherwise fetches it from the source, caches it, and returns the content.

Parameters
----------
blobpath : str
The path or URL to the file to be read.

expected_hash : str, optional
The expected hash value of the file content. Default is None.

Returns
-------
bytes
The binary content of the file.
"""
user_specified_cache = True
if "TIKTOKEN_CACHE_DIR" in os.environ:
cache_dir = os.environ["TIKTOKEN_CACHE_DIR"]
Expand Down Expand Up @@ -88,6 +134,28 @@ def data_gym_to_mergeable_bpe_ranks(
vocab_bpe_hash: Optional[str] = None,
encoder_json_hash: Optional[str] = None,
) -> dict[bytes, int]:
"""
Converts a vocab BPE file and an encoder JSON file into mergeable BPE ranks.

Parameters
----------
vocab_bpe_file : str
The path to the vocabulary BPE file.

encoder_json_file : str
The path to the encoder JSON file.

vocab_bpe_hash : str, optional
The expected hash value of the vocabulary BPE file. Default is None.

encoder_json_hash : str, optional
The expected hash value of the encoder JSON file. Default is None.

Returns
-------
dict[bytes, int]
A dictionary mapping mergeable BPE tokens to their ranks.
"""
# NB: do not add caching to this function
rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "]

Expand Down Expand Up @@ -129,6 +197,21 @@ def decode_data_gym(value: str) -> bytes:


def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> None:
"""
Dumps the mergeable BPE ranks to a TikToken BPE file.

Parameters
----------
bpe_ranks : dict[bytes, int]
A dictionary mapping mergeable BPE tokens to their ranks.

tiktoken_bpe_file : str
The path to the TikToken BPE file.

Returns
-------
None
"""
try:
import blobfile
except ImportError as e:
Expand All @@ -143,6 +226,22 @@ def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> No
def load_tiktoken_bpe(
tiktoken_bpe_file: str, expected_hash: Optional[str] = None
) -> dict[bytes, int]:
"""
Loads mergeable BPE ranks from a TikToken BPE file.

Parameters
----------
tiktoken_bpe_file : str
The path to the TikToken BPE file.

expected_hash : str, optional
The expected hash value of the file content. Default is None.

Returns
-------
dict[bytes, int]
A dictionary mapping mergeable BPE tokens to their ranks.
"""
# NB: do not add caching to this function
contents = read_file_cached(tiktoken_bpe_file, expected_hash)
return {
Expand Down
57 changes: 53 additions & 4 deletions tiktoken/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,44 @@


def encoding_name_for_model(model_name: str) -> str:
"""Returns the name of the encoding used by a model.
"""
Returns the name of the encoding used by a model.

Parameters
----------
model_name : str
The name of the model.

Returns
-------
encoding_name : str
The name of the encoding used by the model.

Raises
------
KeyError
If the model name is not recognized or cannot be mapped to an encoding.

Notes
-----
This function checks if the provided model name is directly mapped to an encoding in MODEL_TO_ENCODING.
If not, it attempts to match the model name with known prefixes in MODEL_PREFIX_TO_ENCODING.
If a match is found, it returns the corresponding encoding name.

If the model name cannot be mapped to any encoding, it raises a KeyError.

Raises a KeyError if the model name is not recognised.
Examples
--------
>>> encoding_name_for_model("gpt2")
'gpt2'

>>> encoding_name_for_model("roberta-large")
'roberta'

>>> encoding_name_for_model("nonexistent-model")
Traceback (most recent call last):
...
KeyError: "Could not automatically map nonexistent-model to a tokeniser. Please use `tiktoken.get_encoding` to explicitly get the tokeniser you expect."
"""
encoding_name = None
if model_name in MODEL_TO_ENCODING:
Expand All @@ -94,8 +129,22 @@ def encoding_name_for_model(model_name: str) -> str:


def encoding_for_model(model_name: str) -> Encoding:
    """
    Returns the encoding used by a model.

    Parameters
    ----------
    model_name : str
        The name of the model.

    Returns
    -------
    encoding : Encoding
        The encoding used by the model.

    Raises
    ------
    KeyError
        If the model name is not recognized or cannot be mapped to an encoding.
    """
    # Delegate the name-to-encoding lookup, then construct/fetch the Encoding.
    return get_encoding(encoding_name_for_model(model_name))
54 changes: 54 additions & 0 deletions tiktoken/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@

@functools.lru_cache()
def _available_plugin_modules() -> Sequence[str]:
"""
Returns a sequence of available plugin modules.

Returns
-------
Sequence[str]
A sequence of available plugin modules.
"""
# tiktoken_ext is a namespace package
# submodules inside tiktoken_ext will be inspected for ENCODING_CONSTRUCTORS attributes
# - we use namespace package pattern so `pkgutil.iter_modules` is fast
Expand All @@ -30,6 +38,22 @@ def _available_plugin_modules() -> Sequence[str]:


def _find_constructors() -> None:
"""
Finds encoding constructors from available plugin modules and populates the ENCODING_CONSTRUCTORS dictionary.

Parameters
----------
None

Returns
-------
None

Raises
------
ValueError
If a plugin module does not define ENCODING_CONSTRUCTORS or if there are duplicate encoding names.
"""
global ENCODING_CONSTRUCTORS
with _lock:
if ENCODING_CONSTRUCTORS is not None:
Expand All @@ -53,6 +77,24 @@ def _find_constructors() -> None:


def get_encoding(encoding_name: str) -> Encoding:
"""
Retrieves an Encoding object for the specified encoding name.

Parameters
----------
encoding_name : str
The name of the encoding.

Returns
-------
Encoding
The Encoding object for the specified encoding name.

Raises
------
ValueError
If the specified encoding name is unknown.
"""
if encoding_name in ENCODINGS:
return ENCODINGS[encoding_name]

Expand All @@ -76,6 +118,18 @@ def get_encoding(encoding_name: str) -> Encoding:


def list_encoding_names() -> list[str]:
"""
Lists available encoding names.

Parameters
----------
None

Returns
-------
list[str]
A list of available encoding names.
"""
with _lock:
if ENCODING_CONSTRUCTORS is None:
_find_constructors()
Expand Down