From 45ee4e32c5d2b428a81323b88b45e23a50727120 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Louis=20Brul=C3=A9=20Naudet?= Date: Mon, 19 Feb 2024 19:59:34 +0100 Subject: [PATCH] In-code documentation and update MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Dear Developers, I'm pleased to inform you that I have completed the documentation update for the load, model and registry files. The updated documentation provides clear explanations of function parameters, return types, and expected behavior. Additionally, it adheres to consistent formatting and organization, ensuring ease of understanding for both current and future developers. Please review the updated documentation at your earliest convenience. If you have any feedback or suggestions for further improvements, please don't hesitate to let me know. Thank you for your attention to this matter. Best regards, Louis Brulé Naudet --- tiktoken/load.py | 99 ++++++++++++++++++++++++++++++++++++++++++++ tiktoken/model.py | 57 +++++++++++++++++++++++-- tiktoken/registry.py | 54 ++++++++++++++++++++++++ 3 files changed, 206 insertions(+), 4 deletions(-) diff --git a/tiktoken/load.py b/tiktoken/load.py index cc0a6a6d..11032dfd 100644 --- a/tiktoken/load.py +++ b/tiktoken/load.py @@ -12,6 +12,19 @@ def read_file(blobpath: str) -> bytes: + """ + Reads the contents of a file specified by the given blobpath. + + Parameters + ---------- + blobpath : str + The path or URL to the file to be read. + + Returns + ------- + bytes + The binary content of the file. + """ if not blobpath.startswith("http://") and not blobpath.startswith("https://"): try: import blobfile @@ -28,11 +41,44 @@ def read_file(blobpath: str) -> bytes: def check_hash(data: bytes, expected_hash: str) -> bool: + """ + Checks if the SHA-256 hash of the given data matches the expected hash. + + Parameters + ---------- + data : bytes + The binary data to be hashed. + + expected_hash : str + The expected SHA-256 hex digest of the data. 
+ + Returns + ------- + bool + True if the actual hash matches the expected hash, False otherwise. + """ actual_hash = hashlib.sha256(data).hexdigest() return actual_hash == expected_hash def read_file_cached(blobpath: str, expected_hash: Optional[str] = None) -> bytes: + """ + Reads the contents of a file specified by the given blobpath from cache if available, + otherwise fetches it from the source, caches it, and returns the content. + + Parameters + ---------- + blobpath : str + The path or URL to the file to be read. + + expected_hash : str, optional + The expected hash value of the file content. Default is None. + + Returns + ------- + bytes + The binary content of the file. + """ user_specified_cache = True if "TIKTOKEN_CACHE_DIR" in os.environ: cache_dir = os.environ["TIKTOKEN_CACHE_DIR"] @@ -88,6 +134,28 @@ def data_gym_to_mergeable_bpe_ranks( vocab_bpe_hash: Optional[str] = None, encoder_json_hash: Optional[str] = None, ) -> dict[bytes, int]: + """ + Converts a vocab BPE file and an encoder JSON file into mergeable BPE ranks. + + Parameters + ---------- + vocab_bpe_file : str + The path to the vocabulary BPE file. + + encoder_json_file : str + The path to the encoder JSON file. + + vocab_bpe_hash : str, optional + The expected hash value of the vocabulary BPE file. Default is None. + + encoder_json_hash : str, optional + The expected hash value of the encoder JSON file. Default is None. + + Returns + ------- + dict[bytes, int] + A dictionary mapping mergeable BPE tokens to their ranks. + """ # NB: do not add caching to this function rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "] @@ -129,6 +197,21 @@ def decode_data_gym(value: str) -> bytes: def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> None: + """ + Dumps the mergeable BPE ranks to a TikToken BPE file. + + Parameters + ---------- + bpe_ranks : dict[bytes, int] + A dictionary mapping mergeable BPE tokens to their ranks. 
+ + tiktoken_bpe_file : str + The path to the TikToken BPE file. + + Returns + ------- + None + """ try: import blobfile except ImportError as e: @@ -143,6 +226,22 @@ def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> No def load_tiktoken_bpe( tiktoken_bpe_file: str, expected_hash: Optional[str] = None ) -> dict[bytes, int]: + """ + Loads mergeable BPE ranks from a TikToken BPE file. + + Parameters + ---------- + tiktoken_bpe_file : str + The path to the TikToken BPE file. + + expected_hash : str, optional + The expected hash value of the file content. Default is None. + + Returns + ------- + dict[bytes, int] + A dictionary mapping mergeable BPE tokens to their ranks. + """ # NB: do not add caching to this function contents = read_file_cached(tiktoken_bpe_file, expected_hash) return { diff --git a/tiktoken/model.py b/tiktoken/model.py index 17532aee..6711e00e 100644 --- a/tiktoken/model.py +++ b/tiktoken/model.py @@ -69,9 +69,44 @@ def encoding_name_for_model(model_name: str) -> str: - """Returns the name of the encoding used by a model. + """ + Returns the name of the encoding used by a model. + + Parameters + ---------- + model_name : str + The name of the model. + + Returns + ------- + encoding_name : str + The name of the encoding used by the model. + + Raises + ------ + KeyError + If the model name is not recognized or cannot be mapped to an encoding. + + Notes + ----- + This function checks if the provided model name is directly mapped to an encoding in MODEL_TO_ENCODING. + If not, it attempts to match the model name with known prefixes in MODEL_PREFIX_TO_ENCODING. + If a match is found, it returns the corresponding encoding name. + + If the model name cannot be mapped to any encoding, it raises a KeyError. - Raises a KeyError if the model name is not recognised. 
+ Examples + -------- + >>> encoding_name_for_model("gpt2") + 'gpt2' + + >>> encoding_name_for_model("gpt-4-0314") + 'cl100k_base' + + >>> encoding_name_for_model("nonexistent-model") + Traceback (most recent call last): + ... + KeyError: 'Could not automatically map nonexistent-model to a tokeniser. Please use `tiktoken.get_encoding` to explicitly get the tokeniser you expect.' """ encoding_name = None if model_name in MODEL_TO_ENCODING: @@ -94,8 +129,22 @@ def encoding_name_for_model(model_name: str) -> str: def encoding_for_model(model_name: str) -> Encoding: - """Returns the encoding used by a model. + """ + Returns the encoding used by a model. + + Parameters + ---------- + model_name : str + The name of the model. + + Returns + ------- + encoding : Encoding + The encoding used by the model. - Raises a KeyError if the model name is not recognised. + Raises + ------ + KeyError + If the model name is not recognized or cannot be mapped to an encoding. """ return get_encoding(encoding_name_for_model(model_name)) diff --git a/tiktoken/registry.py b/tiktoken/registry.py index a753ce67..6d06f449 100644 --- a/tiktoken/registry.py +++ b/tiktoken/registry.py @@ -17,6 +17,14 @@ @functools.lru_cache() def _available_plugin_modules() -> Sequence[str]: + """ + Returns a sequence of available plugin modules. + + Returns + ------- + Sequence[str] + A sequence of available plugin modules. + """ # tiktoken_ext is a namespace package # submodules inside tiktoken_ext will be inspected for ENCODING_CONSTRUCTORS attributes # - we use namespace package pattern so `pkgutil.iter_modules` is fast @@ -30,6 +38,22 @@ def _available_plugin_modules() -> Sequence[str]: def _find_constructors() -> None: + """ + Finds encoding constructors from available plugin modules and populates the ENCODING_CONSTRUCTORS dictionary. 
+ + Parameters + ---------- + None + + Returns + ------- + None + + Raises + ------ + ValueError + If a plugin module does not define ENCODING_CONSTRUCTORS or if there are duplicate encoding names. + """ global ENCODING_CONSTRUCTORS with _lock: if ENCODING_CONSTRUCTORS is not None: @@ -53,6 +77,24 @@ def _find_constructors() -> None: def get_encoding(encoding_name: str) -> Encoding: + """ + Retrieves an Encoding object for the specified encoding name. + + Parameters + ---------- + encoding_name : str + The name of the encoding. + + Returns + ------- + Encoding + The Encoding object for the specified encoding name. + + Raises + ------ + ValueError + If the specified encoding name is unknown. + """ if encoding_name in ENCODINGS: return ENCODINGS[encoding_name] @@ -76,6 +118,18 @@ def get_encoding(encoding_name: str) -> Encoding: def list_encoding_names() -> list[str]: + """ + Lists available encoding names. + + Parameters + ---------- + None + + Returns + ------- + list[str] + A list of available encoding names. + """ with _lock: if ENCODING_CONSTRUCTORS is None: _find_constructors()