Skip to content

Commit

Permalink
Update logs and error messages
Browse files Browse the repository at this point in the history
Signed-off-by: Jan Lasek <[email protected]>
  • Loading branch information
janekl committed Jan 21, 2025
1 parent 886a8c8 commit 0080371
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
14 changes: 10 additions & 4 deletions nemo/export/tensorrt_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,19 +318,25 @@ def export(
model_type = get_model_type(nemo_checkpoint_path)

if model_type is None:
raise Exception("Parameter model_type needs to be specified, got None.")
raise Exception(
"Parameter model_type needs to be provided and cannot be inferred from the checkpoint. "
"Please specify it explicitely."
)

if model_type not in self.get_supported_models_list:
raise Exception(
"Model {0} is not currently a supported model type. "
"Supported model types are: {1}.".format(model_type, self.get_supported_models_list)
f"Model {model_type} is not currently a supported model type. "
f"Supported model types are: {self.get_supported_models_list}."
)

if dtype is None:
dtype = get_weights_dtype(nemo_checkpoint_path)

if dtype is None:
raise Exception("Parameter dtype needs to be specified, got None.")
raise Exception(
"Parameter dtype needs to be provided and cannot be inferred from the checkpoint. "
"Please specify it explicitely."
)

model, model_config, self.tokenizer = load_nemo_model(
nemo_checkpoint_path, nemo_export_dir, use_mcore_path
Expand Down
6 changes: 5 additions & 1 deletion nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,7 @@ def get_weights_dtype(nemo_ckpt: Union[str, Path]) -> Optional[str]:
"""
model_config = load_nemo_config(nemo_ckpt)
torch_dtype = None
dtype = None

is_nemo2 = "_target_" in model_config
if is_nemo2:
Expand All @@ -502,7 +503,10 @@ def get_weights_dtype(nemo_ckpt: Union[str, Path]) -> Optional[str]:
dtype = torch_dtype.removeprefix("torch.")
LOGGER.info(f"Determined weights dtype='{dtype}' for {nemo_ckpt} checkpoint.")
else:
LOGGER.warning(f"Parameter dtype for model weights cannot be determined for {nemo_ckpt} checkpoint.")
LOGGER.warning(
f"Parameter dtype for model weights cannot be determined for {nemo_ckpt} checkpoint. "
"There is no 'precision' field specified in the model_config.yaml file."
)

return dtype

Expand Down

0 comments on commit 0080371

Please sign in to comment.