Skip to content

Commit

Permalink
[usability] use_auth_token deprecation update
Browse files Browse the repository at this point in the history
  • Loading branch information
yizhenjia committed Nov 5, 2024
1 parent 97e1d3a commit e23ad92
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 14 deletions.
14 changes: 5 additions & 9 deletions src/lmflow/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,8 @@ class ModelArguments:
a string representing the specific model version to use (can be a
branch name, tag name, or commit id).
use_auth_token : bool
a boolean indicating whether to use the token generated when running
huggingface-cli login (necessary to use this script with private models).
token : Optional[str]
Necessary when accessing a private model/dataset.
torch_dtype : str
a string representing the dtype to load the model under. If auto is
Expand Down Expand Up @@ -180,13 +179,10 @@ class ModelArguments:
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
use_auth_token: bool = field(
default=False,
token: Optional[str] = field(
default=None,
metadata={
"help": (
"Will use the token generated when running `huggingface-cli login` (necessary to use this script "
"with private models)."
)
"help": ("Necessary to specify when accessing a private model/dataset.")
},
)
trust_remote_code: bool = field(
Expand Down
1 change: 0 additions & 1 deletion src/lmflow/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def __init__(self, data_args: DatasetArguments=None, backend: str="huggingface",
data_files=data_files,
field=KEY_INSTANCES,
split="train",
use_auth_token=None,
)
self.backend_dataset = raw_dataset
self._check_data_format()
Expand Down
4 changes: 2 additions & 2 deletions src/lmflow/models/hf_decoder_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,7 @@ def get_peft_without_qlora(self):
config_kwargs = {
"cache_dir": self.model_args.cache_dir,
"revision": self.model_args.model_revision,
"use_auth_token": True if self.model_args.use_auth_token else None,
"token": self.model_args.token,
}
config = AutoConfig.from_pretrained(self.model_args.model_name_or_path, **config_kwargs)
device_map = "auto"
Expand All @@ -632,7 +632,7 @@ def get_peft_without_qlora(self):
config=config,
cache_dir=self.model_args.cache_dir,
revision=self.model_args.model_revision,
use_auth_token=True if self.model_args.use_auth_token else None,
token=self.model_args.token,
torch_dtype=torch_dtype,
device_map=device_map,
trust_remote_code = self.model_args.trust_remote_code,
Expand Down
4 changes: 2 additions & 2 deletions src/lmflow/models/hf_model_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def __prepare_tokenizer(
"cache_dir": model_args.cache_dir,
"use_fast": model_args.use_fast_tokenizer,
"revision": model_args.model_revision,
"use_auth_token": True if model_args.use_auth_token else None,
"token": model_args.token,
"trust_remote_code": model_args.trust_remote_code,
}
if model_args.padding_side != 'auto':
Expand Down Expand Up @@ -203,7 +203,7 @@ def __prepare_model_config(
"attn_implementation": "flash_attention_2" if model_args.use_flash_attention else None,
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"use_auth_token": True if model_args.use_auth_token else None,
"token": model_args.token,
"trust_remote_code": model_args.trust_remote_code,
"from_tf": bool(".ckpt" in model_args.model_name_or_path),
}
Expand Down

0 comments on commit e23ad92

Please sign in to comment.