diff --git a/qlora.py b/qlora.py index 59e2a701..552b9aaa 100644 --- a/qlora.py +++ b/qlora.py @@ -28,7 +28,7 @@ LlamaTokenizer ) -from datasets import load_dataset, Dataset +from datasets import load_dataset, Dataset, load_from_disk import evaluate from peft import ( @@ -481,6 +481,8 @@ def local_dataset(dataset_name): full_dataset = Dataset.from_pandas(pd.read_csv(dataset_name)) elif dataset_name.endswith('.tsv'): full_dataset = Dataset.from_pandas(pd.read_csv(dataset_name, delimiter='\t')) + elif dataset_name.endswith('/') or dataset_name.endswith('\\'): + full_dataset = load_from_disk(dataset_name) else: raise ValueError(f"Unsupported dataset format: {dataset_name}")