Skip to content

Commit

Permalink
Merge pull request #520 from google:jonbolin/test
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 616874915
  • Loading branch information
maxtext authors committed Mar 18, 2024
2 parents 3aebce9 + c07d1c7 commit 5353a95
Show file tree
Hide file tree
Showing 11 changed files with 32 additions and 32 deletions.
28 changes: 14 additions & 14 deletions .github/workflows/UnitTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,39 +77,39 @@ jobs:
- name: Test train.py with c4
run: |
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false'
'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false'
- name: Test train.py with synthetic data
run: |
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false dataset_type=synthetic'
'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false dataset_type=synthetic'
- name: Test train.py with per_device_batch_size < 1
run: |
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 per_device_batch_size=0.25 ici_tensor_parallelism=4 enable_checkpointing=false'
'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 per_device_batch_size=0.25 ici_tensor_parallelism=4 enable_checkpointing=false'
- name: Test decode.py
run: |
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
'python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=dot_product enable_checkpointing=false max_target_length=128 per_device_batch_size=1'
'python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=dot_product enable_checkpointing=false max_target_length=128 per_device_batch_size=1'
- name: Test decode.py with per_device_batch_size < 1
run: |
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
'python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=dot_product enable_checkpointing=false max_target_length=128 per_device_batch_size=.25'
'python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=dot_product enable_checkpointing=false max_target_length=128 per_device_batch_size=.25'
- name: Test standalone_dataloader.py
run: |
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
'python3 MaxText/standalone_dataloader.py MaxText/configs/base.yml run_name=standalone_dataloader_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=100 enable_checkpointing=false'
'python3 MaxText/standalone_dataloader.py MaxText/configs/base.yml run_name=standalone_dataloader_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=100 enable_checkpointing=false'
- name: Test standalone_checkpointer.py
run: |
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
'python3 MaxText/standalone_checkpointer.py MaxText/configs/base.yml run_name=standalone_checkpointer_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=200 checkpoint_period=50 enable_checkpointing=True async_checkpointing=False'
'python3 MaxText/standalone_checkpointer.py MaxText/configs/base.yml run_name=standalone_checkpointer_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=200 checkpoint_period=50 enable_checkpointing=True async_checkpointing=False'
- name: Test int8_training
run: |
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset quantization=int8 steps=2 enable_checkpointing=false'
'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset quantization=int8 steps=2 enable_checkpointing=false'
- name: Test generate_param_only_checkpoint
run: |
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
'bash end_to_end/test_generate_param_only_checkpoint.sh -r runner_$(date +%Y-%m-%d-%H-%M) -o gs://runner-maxtext-logs -d gs://maxtext-dataset -i 4'
'bash end_to_end/test_generate_param_only_checkpoint.sh -r runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} -o gs://runner-maxtext-logs -d gs://maxtext-dataset -i 4'
- name: Test grain checkpoint determinism
run: |
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
Expand Down Expand Up @@ -145,20 +145,20 @@ jobs:
- name: Test train.py
run: |
docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false attention=dot_product'
'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false attention=dot_product'
- name: Test train.py with per_device_batch_size < 1
run: |
docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 per_device_batch_size=0.25 ici_tensor_parallelism=4 enable_checkpointing=false attention=dot_product'
'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 per_device_batch_size=0.25 ici_tensor_parallelism=4 enable_checkpointing=false attention=dot_product'
- name: Test int8_training
run: |
docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset quantization=int8 steps=2 enable_checkpointing=false attention=dot_product'
'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset quantization=int8 steps=2 enable_checkpointing=false attention=dot_product'
- name: Test decode.py
run: |
docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
'python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=dot_product enable_checkpointing=false max_target_length=128 per_device_batch_size=1'
'python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=dot_product enable_checkpointing=false max_target_length=128 per_device_batch_size=1'
- name: Test decode.py with per_device_batch_size < 1
run: |
docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
'python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=dot_product enable_checkpointing=false max_target_length=128 per_device_batch_size=.25'
'python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=dot_product enable_checkpointing=false max_target_length=128 per_device_batch_size=.25'
2 changes: 1 addition & 1 deletion end_to_end/llama_finetuning_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# - Using the checkpoint generated from train.py or available one in open source (https://llama.meta.com/llama-downloads/).

set -e
idx=$(date +%Y-%m-%d-%H-%M)
idx=$(date +%Y-%m-%d-%H-%M)-$RANDOM

base_ckpt_path=gs://maxtext-llama/test/2024-01-15-06-49/decode-ckpt-maxtext/0/items
BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs
Expand Down
2 changes: 1 addition & 1 deletion end_to_end/test_checkpoint_compatibility.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ if [ -f "run_*_metrics.txt" ]; then
echo "removed existing run_*_metrics.txt"
fi

RUN_NAME=${1}-$(date +%Y-%m-%d-%H-%M)
RUN_NAME=${1}-$(date +%Y-%m-%d-%H-%M)-${RANDOM}
OUTPUT_PATH=${2}
DATASET_PATH=${3}
model_params=" base_emb_dim=384 base_num_query_heads=8 base_num_kv_heads=8 base_mlp_dim=192 base_num_decoder_layers=8 head_dim=128"
Expand Down
2 changes: 1 addition & 1 deletion end_to_end/test_checkpointing.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ if [ -f "restored_metrics.txt" ]; then
echo "removed existing restored_metrics.txt"
fi

RUN_NAME=${1}-${4}-$(date +%Y-%m-%d-%H-%M)
RUN_NAME=${1}-${4}-$(date +%Y-%m-%d-%H-%M)-${RANDOM}
OUTPUT_PATH=${2}
DATASET_PATH=${3}
COLLECT_STACK_TRACE=${4}
Expand Down
4 changes: 2 additions & 2 deletions end_to_end/test_decode.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ DATASET_PATH=${3}

if [ -z ${4} ]
then
RUN_NAME=${USER}_$(date +%Y-%m-%d-%H-%M-%S)
RUN_NAME=${USER}_$(date +%Y-%m-%d-%H-%M-%S)-${RANDOM}
else
RUN_NAME=${4}_$(date +%Y-%m-%d-%H)
RUN_NAME=${4}_$(date +%Y-%m-%d-%H)-${RANDOM}
fi

if [ -z ${5} ]
Expand Down
12 changes: 6 additions & 6 deletions end_to_end/test_gemma.sh
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
#!/bin/bash
# End-to-end Gemma check: for the 2b and then the 7b model, convert the
# checkpoint under gs://maxtext-gemma/flax/<size> with
# MaxText/convert_gemma_chkpt.py into a per-run path, then run
# MaxText/decode.py against the converted parameters with the prompt
# "I love to" and an autoregressive_decode_assert string.
# Presumably decode.py fails when the generated text differs from
# autoregressive_decode_assert -- confirm in decode.py.
#
# NOTE(review): this text looks like a rendered diff with the +/- markers
# stripped: assignments and decode commands appear in pairs, one using
# $(date ...) directly and one using ${idx}. Read literally as a script, every
# line below executes -- confirm which lines belong in the real file.

# -e: abort on the first command that exits non-zero.
# -x: print each command before it runs.
set -ex
# Minute-resolution timestamp used to build per-run paths and run names.
idx=$(date +%Y-%m-%d-%H-%M)
# Re-assigns idx (overriding the line above) with a ${RANDOM} suffix;
# presumably so that two runs started within the same minute do not share the
# same output path -- confirm intent.
idx=$(date +%Y-%m-%d-%H-%M)-${RANDOM}
# Convert the Gemma 2b checkpoint (called "2.5B" in the original comment) and
# write the result under a path made unique by ${idx}.
export base_model_path=gs://maxtext-gemma/flax/2b
export maxtext_model_path=gs://maxtext-gemma/2b/${idx}
python MaxText/convert_gemma_chkpt.py --base_model_path ${base_model_path} --maxtext_model_path ${maxtext_model_path} --model_size 2b
# Test Gemma 2b decode from ${maxtext_model_path}/0/items.
# Four variants are run: no sampling options, weighted with temperature
# .00001, nucleus with p=0, and topk with k=1. All four use the same
# autoregressive_decode_assert string; presumably each setting degenerates to
# greedy decoding -- confirm against decode.py.
# The first four commands compute run_name from a fresh $(date ...) call.
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${maxtext_model_path}/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product prompt="I love to" autoregressive_decode_assert=" travel and I love to write. I"
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${maxtext_model_path}/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product prompt="I love to" autoregressive_decode_assert=" travel and I love to write. I" decode_sampling_strategy=weighted decode_sampling_temperature=.00001
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${maxtext_model_path}/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product prompt="I love to" autoregressive_decode_assert=" travel and I love to write. I" decode_sampling_strategy=nucleus decode_sampling_nucleus_p=0
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${maxtext_model_path}/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product prompt="I love to" autoregressive_decode_assert=" travel and I love to write. I" decode_sampling_strategy=topk decode_sampling_top_k=1
# The same four variants again, but with run_name=runner_${idx}, so the run
# name reuses the identifier (including the ${RANDOM} suffix) computed once at
# the top of the script.
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${maxtext_model_path}/0/items per_device_batch_size=1 run_name=runner_${idx} max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product prompt="I love to" autoregressive_decode_assert=" travel and I love to write. I"
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${maxtext_model_path}/0/items per_device_batch_size=1 run_name=runner_${idx} max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product prompt="I love to" autoregressive_decode_assert=" travel and I love to write. I" decode_sampling_strategy=weighted decode_sampling_temperature=.00001
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${maxtext_model_path}/0/items per_device_batch_size=1 run_name=runner_${idx} max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product prompt="I love to" autoregressive_decode_assert=" travel and I love to write. I" decode_sampling_strategy=nucleus decode_sampling_nucleus_p=0
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${maxtext_model_path}/0/items per_device_batch_size=1 run_name=runner_${idx} max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product prompt="I love to" autoregressive_decode_assert=" travel and I love to write. I" decode_sampling_strategy=topk decode_sampling_top_k=1

# Convert the Gemma 7b checkpoint; base_model_path and maxtext_model_path are
# re-exported, replacing the 2b values set earlier.
export base_model_path=gs://maxtext-gemma/flax/7b
export maxtext_model_path=gs://maxtext-gemma/7b/${idx}
python MaxText/convert_gemma_chkpt.py --base_model_path ${base_model_path} --maxtext_model_path ${maxtext_model_path} --model_size 7b
# Test Gemma 7b decode: same prompt, different autoregressive_decode_assert
# string. As above, the first command uses a fresh $(date ...) for run_name and
# the second uses ${idx}.
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${maxtext_model_path}/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-7b attention=dot_product prompt="I love to" autoregressive_decode_assert=" use this product in my hair. It"
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${maxtext_model_path}/0/items per_device_batch_size=1 run_name=runner_${idx} max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-7b attention=dot_product prompt="I love to" autoregressive_decode_assert=" use this product in my hair. It"
2 changes: 1 addition & 1 deletion end_to_end/test_llama2.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# - Using the checkpoint generated from train.py or available one in open source (https://llama.meta.com/llama-downloads/).

set -ex
idx=$(date +%Y-%m-%d-%H-%M)
idx=$(date +%Y-%m-%d-%H-%M)-${RANDOM}

export M_ENABLE_CHECKPOINTING=true
export M_BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs
Expand Down
2 changes: 1 addition & 1 deletion end_to_end/test_mistral.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# - Using the checkpoint generated from train.py or available one in open source (i.e. https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar).

set -ex
idx=$(date +%Y-%m-%d-%H-%M)
idx=$(date +%Y-%m-%d-%H-%M)-${RANDOM}

export M_ENABLE_CHECKPOINTING=true
export M_BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs
Expand Down
2 changes: 1 addition & 1 deletion end_to_end/test_mixtral.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# - Using the checkpoint generated from train.py or available one in open source (i.e. https://files.mixtral-8x7b-v0-1.mistral.ai/Mixtral-8x7B-v0.1-Instruct.tar).

set -ex
idx=$(date +%Y-%m-%d-%H-%M)
idx=$(date +%Y-%m-%d-%H-%M)-${RANDOM}

export M_ENABLE_CHECKPOINTING=true
export M_BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs
Expand Down
4 changes: 2 additions & 2 deletions end_to_end/test_tflops.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ DATASET_PATH=${4}

if [ -z ${5} ]
then
RUN_NAME=${USER}_$(date +%Y-%m-%d-%H-%M-%S)
RUN_NAME=${USER}_$(date +%Y-%m-%d-%H-%M-%S)-${RANDOM}
else
RUN_NAME=${5}_$(date +%Y-%m-%d-%H)
RUN_NAME=${5}_$(date +%Y-%m-%d-%H)-${RANDOM}
fi

#Train
Expand Down
4 changes: 2 additions & 2 deletions gke/gpu/start_training.sh
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,10 @@ set -e
# Start MaxText/train.py in the background once per LOCAL_DEVICE_ID in
# 0..GPUS_PER_NODE-1, collect the background PIDs, then hand them to
# wait_all_success_or_exit.
PIDS=()
for ((LOCAL_DEVICE_ID=0; LOCAL_DEVICE_ID <= $((GPUS_PER_NODE - 1)); LOCAL_DEVICE_ID++)); do
# Index of this process across nodes: GPUS_PER_NODE * NODE_RANK + LOCAL_DEVICE_ID.
PROCESS_ID=$(($GPUS_PER_NODE*$NODE_RANK + $LOCAL_DEVICE_ID))
# LOCAL_DEVICE_ID and PROCESS_ID are passed to train.py as per-command
# environment variables; the trailing & runs it in the background.
# NOTE(review): two launch lines appear back to back, differing only in the
# -${RANDOM} suffix on run_name. Read literally, both start a process and $!
# below captures only the second one's PID. Presumably this is the
# before/after pair of a rendered diff -- confirm only one belongs in the file.
# NOTE(review): ${RANDOM} is expanded separately on every loop iteration (and,
# presumably, in every node's copy of this script), so the launched processes
# would not receive the same run_name -- confirm that is acceptable for a
# multi-process job.
LOCAL_DEVICE_ID=$LOCAL_DEVICE_ID PROCESS_ID=$PROCESS_ID python MaxText/train.py MaxText/configs/base.yml hardware=gpu run_name=${RUN_NAME}_$(date +%Y-%m-%d-%H-%M) &
LOCAL_DEVICE_ID=$LOCAL_DEVICE_ID PROCESS_ID=$PROCESS_ID python MaxText/train.py MaxText/configs/base.yml hardware=gpu run_name=${RUN_NAME}_$(date +%Y-%m-%d-%H-%M)-${RANDOM} &
# $! is the PID of the most recently started background command.
PID=$!
PIDS+=($PID)
echo "Launched MaxText/train.py for local_device_id: $LOCAL_DEVICE_ID process_id: $PROCESS_ID and PID $PID"
done

# wait_all_success_or_exit is defined outside this excerpt; presumably it waits
# on every PID and exits if any process failed -- confirm against its
# definition. The call appears twice here (likely the same before/after diff
# pairing as above -- confirm).
wait_all_success_or_exit "${PIDS[@]}"
wait_all_success_or_exit "${PIDS[@]}"

0 comments on commit 5353a95

Please sign in to comment.