packed errors #2218

Open · chg0901 opened this issue Dec 31, 2024 · 6 comments
Labels: bug (Something isn't working), triaged (This issue has been assigned an owner and appropriate label)

chg0901 commented Dec 31, 2024

I am testing torchtune 0.5 on my Linux PC with 4x A6000 GPUs.

When I try to use a packed dataset to speed up training, I hit the following error.

Error

Using flex attention for attention computation since a BlockMask was passed in.
Traceback (most recent call last):
  File "/home/cine/miniconda3/envs/tune/bin/tune", line 8, in <module>
    sys.exit(main())
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torchtune/_cli/tune.py", line 49, in main
    parser.run(args)
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torchtune/_cli/tune.py", line 43, in run
    args.func(args)
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torchtune/_cli/run.py", line 214, in _run_cmd
    self._run_single_device(args, is_builtin=is_builtin)
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torchtune/_cli/run.py", line 108, in _run_single_device
    runpy.run_path(str(args.recipe), run_name="__main__")
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/runpy.py", line 289, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/runpy.py", line 96, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/recipes/lora_finetune_single_device.py", line 803, in <module>
    sys.exit(recipe_main())
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torchtune/config/_parse.py", line 99, in wrapper
    sys.exit(recipe_main(conf))
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/recipes/lora_finetune_single_device.py", line 798, in recipe_main
    recipe.train()
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/recipes/lora_finetune_single_device.py", line 707, in train
    current_loss.backward()
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torch/_tensor.py", line 581, in backward
    torch.autograd.backward(
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torch/autograd/function.py", line 307, in apply
    return user_fn(self, *args)
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2048, in backward
    out = call_compiled_backward()
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1954, in call_compiled_backward
    CompiledFunction.compiled_bw = aot_config.bw_compiler(
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 51, in _wrapped_bw_compiler
    return disable(disable(bw_compiler)(*args, **kwargs))
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
    return fn(*args, **kwargs)
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1466, in bw_compiler
    return inner_compile(
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 475, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 661, in _compile_fx_inner
    compiled_graph = FxGraphCache.load(
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 1370, in load
    compiled_graph = compile_fx_fn(
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 570, in codegen_and_compile
    compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 878, in fx_codegen_and_compile
    compiled_fn = graph.compile_to_fn()
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1913, in compile_to_fn
    return self.compile_to_module().call
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1839, in compile_to_module
    return self._compile_to_module()
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1867, in _compile_to_module
    mod = PyCodeCache.load_by_key_path(
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 2876, in load_by_key_path
    mod = _reload_python_module(key, path)
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torch/_inductor/runtime/compile_tasks.py", line 45, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_cine/al/calfe7ti75mdabcy4jy6oe7kidirl3nyvoolgrkunuzqik4lzmdn.py", line 830, in <module>
    async_compile.wait(globals())
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torch/_inductor/async_compile.py", line 276, in wait
    scope[key] = result.result()
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 3344, in result
    self.kernel.precompile()
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 250, in precompile
    raise e
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 244, in precompile
    compiled_binary, launcher = self._precompile_config(
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 452, in _precompile_config
    binary._init_handles()
  File "/home/cine/miniconda3/envs/tune/lib/python3.10/site-packages/triton/compiler/compiler.py", line 374, in _init_handles
    raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
triton.runtime.errors.OutOfResources: 
out of resource: shared memory, Required: 131074, Hardware limit: 101376. 
Reducing block sizes or `num_stages` may help.

Config

output_dir: ./lora_single_device_output/Llama-2-7b-hf/ 

# Model Arguments
model:
  _component_: torchtune.models.llama2.lora_llama2_7b
  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
  apply_lora_to_mlp: True
  apply_lora_to_output: False
  lora_rank: 8  # higher increases accuracy and memory
  lora_alpha: 16  # usually alpha=2*rank
  lora_dropout: 0.0

tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  path:  ./models/Llama-2-7b-hf/tokenizer.model
  max_seq_len: 1024

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: ./models/Llama-2-7b-hf
  checkpoint_files: [
    pytorch_model-00001-of-00002.bin,
    pytorch_model-00002-of-00002.bin
  ]
  adapter_checkpoint: null
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  packed: True  # True increases speed
seed: null
shuffle: True
batch_size: 1

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  fused: True
  weight_decay: 0.01
  lr: 3e-4
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 8  # Use to increase effective batch size
compile: True # torch.compile the model + loss, True increases speed + decreases memory

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True

# Environment
device: cuda
dtype: bf16

# Activations Memory
enable_activation_checkpointing: True  # True reduces memory
enable_activation_offloading: True # True reduces memory

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  #Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  #`torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  #trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 5
  active_steps: 2
  num_cycles: 1

I also tried modifying the batch size and related settings:

tokenizer.max_seq_len: 1024
dataset.packed: True

batch_size: 1
gradient_accumulation_steps: 16  # Use to increase effective batch size

You can check a blog post (in Chinese) for more details.

InternLM support

By the way, I would like to open a PR to add support for InternLM. Is there any work in progress?

ebsmothers (Contributor) commented

Hi @chg0901, thanks for creating the issue. Our packed dataset implementation uses flex attention under the hood to support the necessary block-causal mask while still retaining good performance. Unfortunately there are some nuances here: flex attention hardcodes some kernel configs depending on the type of hardware you're using, and these aren't currently optimized for the A6000.

I would check this comment (along with others in the same thread for more context) for one way to get around this in the short term. This is a known issue in PyTorch core (see pytorch/pytorch#133254), and longer term the flex attention authors are working on fixing it; see pytorch/pytorch#137959 (I believe the compute_capability == (8, 6) case in that PR corresponds to the A6000).

So one suggestion is to try hardcoding the kernel options as a temporary fix, as sketched below. If this works, we can try to figure out a way to support this in the interim to make the process a bit less painful.
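
For anyone who wants to try the kernel-option route before the upstream fix lands, here is a minimal standalone sketch (not torchtune's internal code) of what hardcoding the kernel options can look like with flex attention. The specific tile sizes and the toy causal mask are assumptions for illustration only; the linked PyTorch threads discuss which sizes actually fit within the ~100 KB of shared memory on an A6000, and the backward pass (where the error above is raised) has its own BLOCK_M1/BLOCK_N1/BLOCK_M2/BLOCK_N2 keys.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Smaller tile sizes reduce the shared-memory footprint of the generated
# Triton kernels. These values are illustrative assumptions, not recommendations.
kernel_options = {
    "BLOCK_M": 64, "BLOCK_N": 64,      # forward kernel tiles
    "BLOCK_M1": 32, "BLOCK_N1": 64,    # backward kernel tiles
    "BLOCK_M2": 64, "BLOCK_N2": 32,
}

def causal(b, h, q_idx, kv_idx):
    # Toy mask_mod; torchtune builds its own block-causal mask for packed samples.
    return q_idx >= kv_idx

def qkv():
    # Shapes roughly matching Llama-2-7B: 32 heads, head_dim 128, seq len 1024.
    return torch.randn(1, 32, 1024, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)

q, k, v = qkv(), qkv(), qkv()
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=1024, KV_LEN=1024)

compiled_flex = torch.compile(flex_attention, dynamic=False)
out = compiled_flex(q, k, v, block_mask=block_mask, kernel_options=kernel_options)
out.sum().backward()  # the OutOfResources error in this issue surfaced during backward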

Regarding internLM, we aren't currently working on enabling it. Can you open a separate issue with a formal feature request? It'd be helpful to gauge community interest before opening a PR.

dz1iang commented Jan 6, 2025

The following configuration encounters the same issue on L40s.

# Tokenizer
tokenizer:
  _component_: torchtune.models.qwen2_5.qwen2_5_tokenizer
  path: /model/Qwen2.5-7B-Instruct/vocab.json
  merges_file: /model/Qwen2.5-7B-Instruct/merges.txt
  max_seq_len: 1024

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  packed: True  # True increases speed
  source: /data/alpaca-cleaned
seed: null
shuffle: True

compile: False

joecummings added the triaged and bug labels on Jan 6, 2025
dz1iang commented Jan 7, 2025

The same configuration does work with Qwen2.5-0.5B-Instruct, however.

mirceamironenco (Contributor) commented

It likely works on 0.5B because the head_dim is smaller. As another temporary suggestion for anyone blocked, I'm fairly sure you can make it work on L40s if you change the way flex attention is compiled so that it finds a kernel compatible with the CUDA shared memory of the machine, i.e. change:

flex_attention_compiled = torch.compile(flex_attention, dynamic=False)

to:

flex_attention_compiled = torch.compile(flex_attention, dynamic=False, mode="max-autotune")

Alternatively, you can turn off flex attention by hard-coding _SUPPORTS_FLEX_ATTENTION = False, which still allows packed=True.
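
For context, both switches mentioned above live in torchtune's attention utilities; in recent releases that is torchtune/modules/attention_utils.py, but treat the exact file and placement as an assumption for whatever version you have installed. A minimal sketch of the two local edits:

# Sketch of the two temporary workarounds, applied to a local clone of
# torchtune (file location is an assumption; search the repo for where
# flex_attention is wrapped in torch.compile).

import torch
from torch.nn.attention.flex_attention import flex_attention

# Option 1: let Inductor autotune across kernel configs until it finds one
# that fits the device's shared memory (slower on the first step, then cached).
flex_attention_compiled = torch.compile(flex_attention, dynamic=False, mode="max-autotune")

# Option 2: disable flex attention altogether; torchtune then falls back to
# F.scaled_dot_product_attention, and packed=True keeps working.
_SUPPORTS_FLEX_ATTENTION = False

After either edit, reinstall torchtune from the local clone (pip install .), as described in the follow-up comment below.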

chg0901 (Author) commented Jan 7, 2025 via email

mirceamironenco (Contributor) commented Jan 7, 2025

The max-autotune compile mode will likely work for any model/device, assuming the problem is as described (the Triton kernel running out of shared memory). And if it doesn't, setting _SUPPORTS_FLEX_ATTENTION = False should work, since you fall back to F.scaled_dot_product_attention.

To be clear, you will need to uninstall whatever version of torchtune you have, clone the repo, make the change I described (either compile with max-autotune or turn off flex attention), and then pip install . from the local clone.
