optimum
https://github.com/huggingface/optimum
https://huggingface.co/docs/trl/v0.10.1/en/sft_trainer#using-flash-attention-1
より
pip install optimum
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
の下でtrainer.train()
facebook/opt-350m がOOMにならなくなった報告
疑問:実装の変更も必要なのでは?