Support both Path and str for APIs (#11865)
* support both path and str for APIs

Signed-off-by: Maanu Grover <[email protected]>

* cleanup

Signed-off-by: Maanu Grover <[email protected]>

* fix cleanup

Signed-off-by: Maanu Grover <[email protected]>

---------

Signed-off-by: Maanu Grover <[email protected]>
maanug-nv authored Jan 18, 2025
1 parent ad807ae commit 8bf5144
Showing 1 changed file with 30 additions and 11 deletions.
nemo/collections/llm/api.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
-import os
 import warnings
 from copy import deepcopy
 from pathlib import Path
@@ -47,6 +46,7 @@
 
 
 TokenizerType = Any
+AnyPath = Union[Path, str]
 
 
 @run.cli.entrypoint(namespace="llm")
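
The AnyPath alias added here is the backbone of the whole commit: every entry point below swaps its bare Path annotation for it, then normalizes the argument on entry. A minimal sketch of the idiom in isolation (normalize_path is an illustrative name, not a helper from the repo):

from pathlib import Path
from typing import Union

AnyPath = Union[Path, str]

def normalize_path(p: AnyPath) -> Path:
    # Coerce str to Path; leave an existing Path untouched.
    if not isinstance(p, Path):
        p = Path(p)
    return p

normalize_path("ckpts/llama")       # PosixPath('ckpts/llama')
normalize_path(Path("/tmp/model"))  # PosixPath('/tmp/model')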
@@ -322,14 +322,14 @@ def ptq(
 
 @run.cli.entrypoint(namespace="llm")
 def deploy(
-    nemo_checkpoint: Path = None,
+    nemo_checkpoint: AnyPath = None,
     model_type: str = "llama",
     triton_model_name: str = "triton_model",
     triton_model_version: Optional[int] = 1,
     triton_http_port: int = 8000,
     triton_grpc_port: int = 8001,
     triton_http_address: str = "0.0.0.0",
-    triton_model_repository: Path = None,
+    triton_model_repository: AnyPath = None,
     num_gpus: int = 1,
     tensor_parallelism_size: int = 1,
     pipeline_parallelism_size: int = 1,
@@ -376,6 +376,11 @@ def deploy(
 
     unset_environment_variables()
 
+    if not isinstance(nemo_checkpoint, Path):
+        nemo_checkpoint = Path(nemo_checkpoint)
+    if not isinstance(triton_model_repository, Path):
+        triton_model_repository = Path(triton_model_repository)
+
     triton_deployable = get_trtllm_deployable(
         nemo_checkpoint,
         model_type,
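
Callers can now hand deploy either form of path. Note that this hunk coerces unconditionally, so despite the `= None` defaults both arguments are effectively required: Path(None) raises TypeError. A quick illustration (checkpoint paths are made up):

from pathlib import Path

# Both spellings now reach get_trtllm_deployable as equivalent Path objects:
#   deploy(nemo_checkpoint="ckpts/llama-8b", triton_model_repository="/tmp/trt_repo")
#   deploy(nemo_checkpoint=Path("ckpts/llama-8b"), triton_model_repository=Path("/tmp/trt_repo"))

try:
    Path(None)  # the default None survives no guard here, and Path() rejects it
except TypeError as e:
    print(e)    # argument should be a str or an os.PathLike object, not NoneType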
@@ -421,7 +426,7 @@
 
 
 def evaluate(
-    nemo_checkpoint_path: Path,
+    nemo_checkpoint_path: AnyPath,
     url: str = "grpc://0.0.0.0:8001",
     triton_http_port: int = 8000,
     model_name: str = "triton_model",
@@ -442,7 +447,8 @@ def evaluate(
     Args:
         nemo_checkpoint_path (Path): Path for nemo 2.0 checkpoint. This is used to get the tokenizer from the ckpt
             which is required to tokenize the evaluation input and output prompts.
-        url (str): grpc service url that were used in the deploy method above in the format: grpc://{grpc_service_ip}:{grpc_port}.
+        url (str): grpc service url that were used in the deploy method above
+            in the format: grpc://{grpc_service_ip}:{grpc_port}.
         triton_http_port (int): HTTP port that was used for the PyTriton server in the deploy method. Default: 8000.
             Please pass the triton_http_port if using a custom port in the deploy method.
         model_name (str): Name of the model that is deployed on PyTriton server. It should be the same as
@@ -478,6 +484,9 @@
 
     from nemo.collections.llm import evaluation
 
+    if not isinstance(nemo_checkpoint_path, Path):
+        nemo_checkpoint_path = Path(nemo_checkpoint_path)
+
     # Get tokenizer from nemo ckpt. This works only with NeMo 2.0 ckpt.
     tokenizer = io.load_context(nemo_checkpoint_path + "/context", subpath="model.tokenizer")
     # Wait for server to be ready before starting evaluation
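
One thing to watch once the coercion is in place: Path and str compose differently. Path joining uses the / operator, while + only concatenates strings, so an expression shaped like the `nemo_checkpoint_path + "/context"` context line above only works while the value is still a str. A small demonstration of the two joining styles (the path is illustrative):

from pathlib import Path

ckpt = Path("ckpts/llama-8b")
print(ckpt / "context")        # PosixPath('ckpts/llama-8b/context') -- Path joining
print(str(ckpt) + "/context")  # 'ckpts/llama-8b/context' -- plain str concatenation

try:
    ckpt + "/context"          # Path objects do not implement +
except TypeError as e:
    print(e)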
@@ -499,7 +508,7 @@
 def import_ckpt(
     model: pl.LightningModule,
     source: str,
-    output_path: Optional[Path] = None,
+    output_path: Optional[AnyPath] = None,
     overwrite: bool = False,
 ) -> Path:
     """
@@ -557,6 +566,9 @@ def import_ckpt(
         ValueError: If the model does not implement ConnectorMixin, indicating a lack of
             necessary importer functionality.
     """
+    if output_path and not isinstance(output_path, Path):
+        output_path = Path(output_path)
+
     output = io.import_ckpt(model=model, source=source, output_path=output_path, overwrite=overwrite)
 
     console = Console()
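
For genuinely optional parameters like output_path, the guard adds a truthiness check so the default None passes through to io.import_ckpt untouched instead of failing inside Path(). A sketch of that variant of the idiom (coerce_optional is an illustrative name):

from pathlib import Path
from typing import Optional, Union

AnyPath = Union[Path, str]

def coerce_optional(output_path: Optional[AnyPath] = None) -> Optional[Path]:
    # Only coerce when a value was actually supplied; None stays None.
    if output_path and not isinstance(output_path, Path):
        output_path = Path(output_path)
    return output_path

print(coerce_optional())                # None
print(coerce_optional("export/llama"))  # PosixPath('export/llama')

A side effect of testing truthiness rather than `is not None` is that an empty string is also left uncoerced.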
@@ -569,15 +581,17 @@
     return output
 
 
-def load_connector_from_trainer_ckpt(path: Path, target: str) -> io.ModelConnector:
+def load_connector_from_trainer_ckpt(path: AnyPath, target: str) -> io.ModelConnector:
+    if not isinstance(path, Path):
+        path = Path(path)
     return io.load_context(path, subpath="model").exporter(target, path)
 
 
 @run.cli.entrypoint(name="export", namespace="llm")
 def export_ckpt(
-    path: Path,
+    path: AnyPath,
     target: str,
-    output_path: Optional[Path] = None,
+    output_path: Optional[AnyPath] = None,
     overwrite: bool = False,
     load_connector: Callable[[Path, str], io.ModelConnector] = load_connector_from_trainer_ckpt,
 ) -> Path:
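
The isinstance guards repeated across these hunks only save re-wrapping an object that is already a Path; Path() itself is idempotent in value terms, so an unconditional `path = Path(path)` would behave identically. The guard reads as a small optimization and readability choice rather than a correctness requirement:

from pathlib import Path

p = Path("ckpts/llama-8b")
print(Path(p) == p)                 # True: re-wrapping a Path yields an equal Path
print(Path("ckpts/llama-8b") == p)  # True: str input yields the same value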
@@ -628,6 +642,11 @@ def export_ckpt(
         ValueError: If the model does not implement ConnectorMixin, indicating a lack of
             necessary exporter functionality.
     """
+    if not isinstance(path, Path):
+        path = Path(path)
+    if output_path and not isinstance(output_path, Path):
+        output_path = Path(output_path)
+
     output = io.export_ckpt(path, target, output_path, overwrite, load_connector)
 
     console = Console()
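
With the above in place, the import/export entry points accept plain strings end to end. A hypothetical call, assuming a NeMo install, an existing checkpoint directory, and a registered "hf" exporter for the model (paths and target are illustrative, not verified against the repo):

from nemo.collections import llm

# str arguments are accepted and normalized to Path internally
llm.export_ckpt(
    path="ckpts/llama-8b",      # trainer checkpoint; previously had to be a Path
    target="hf",                # assumed exporter name
    output_path="export/llama-hf",
)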
@@ -638,7 +657,7 @@
 
 @run.cli.entrypoint(name="generate", namespace="llm")
 def generate(
-    path: Union[Path, str],
+    path: AnyPath,
     trainer: nl.Trainer,
     prompts: Optional[list[str]] = None,
     encoder_prompts: Optional[list[str]] = None,
@@ -650,7 +669,7 @@
     inference_batch_times_seqlen_threshold: int = 1000,
     inference_params: Optional["CommonInferenceParams"] = None,
     text_only: bool = False,
-    output_path: Optional[Union[Path, str]] = None,
+    output_path: Optional[AnyPath] = None,
 ) -> list[Union["InferenceRequest", str]]:
     """
     Generates text using a NeMo LLM model.
