Skip to content

Commit

Permalink
Add flash attention
Browse files Browse the repository at this point in the history
  • Loading branch information
chajath committed Mar 15, 2024
1 parent 61950fd commit 3fe081c
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
2 changes: 2 additions & 0 deletions constraints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ psutil==5.9.8
pyasn1==0.5.1
pyasn1-modules==0.3.0
pycnite==2023.10.11
pydantic==1.10.14
pydot==2.0.0
Pygments==2.17.2
pylint==3.1.0
Expand Down Expand Up @@ -123,6 +124,7 @@ tomli==2.0.1
tomlkit==0.12.4
toolz==0.12.1
tqdm==4.66.2
transformer-engine @ git+https://github.com/NVIDIA/TransformerEngine.git@0fbc76af3733ae997394eaf82b78ff9c0498fe9
typeguard==2.13.3
typing-inspect==0.9.0
typing_extensions==4.5.0
Expand Down
2 changes: 1 addition & 1 deletion setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ if [[ "$MODE" == "stable" || ! -v MODE ]]; then
pip3 install --no-cache-dir "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -c constraints.txt
fi
export NVTE_FRAMEWORK=jax
pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable
pip3 install --no-cache-dir git+https://github.com/NVIDIA/TransformerEngine.git@0fbc76af3733ae997394eaf82b78ff9c0498fe9 -c constraints.txt
fi
elif [[ $MODE == "nightly" ]]; then
# Nightly mode
Expand Down

0 comments on commit 3fe081c

Please sign in to comment.