Added support to use Megatron-style TFLOPs calculation
Added config option

Added support for Megatron-style TFLOPs calculation
abhinavgoel95 committed Mar 27, 2024
1 parent 5353a95 commit 3980d41
Showing 2 changed files with 21 additions and 6 deletions.
1 change: 1 addition & 0 deletions MaxText/configs/base.yml
@@ -247,3 +247,4 @@ decode_sampling_temperature: 1.

 eval_interval: -1 # the specific number of train steps between eval steps
 target_eval_loss: 0. # early stop once reaching target eval_loss
+use_megatron_tflops: False # use the tflops calculation from https://arxiv.org/abs/2205.05198 (more accurate for transformer models)
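
The new flag can be enabled in a config file or, assuming MaxText's usual pattern of key=value overrides after the config path, on the command line (the run name below is a placeholder):

python3 MaxText/train.py MaxText/configs/base.yml run_name=my_run use_megatron_tflops=True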
26 changes: 20 additions & 6 deletions MaxText/maxtext_utils.py
@@ -92,19 +92,33 @@ def get_train_input_output_trees(func, input_args, input_kwargs):
 # https://arxiv.org/pdf/2204.02311.pdf Appendix B
 def calculate_tflops_training_per_device(num_model_parameters, config, log=True):
   """ Calculate training TFLOP"""
-  learnable_weight_tflops = 6 * num_model_parameters * config.max_target_length * config.per_device_batch_size \
+  if config.use_megatron_tflops:
+    attention_flops = 2 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim
+    ffn1_flops = 2 * config.per_device_batch_size * config.max_target_length * config.mlp_dim * config.emb_dim * len(config.mlp_activations)
+    ffn2_flops = 2 * config.per_device_batch_size * config.max_target_length * config.mlp_dim * config.emb_dim
+    qkv_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * (config.num_query_heads + 2 * config.num_kv_heads) * config.head_dim
+    proj_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.num_query_heads * config.head_dim
+    embedding_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.vocab_size # computed for reference; not included in the totals below
+
+    # multiplied by 3 to account for fprop (1x) and bprop (roughly 2x) flops
+    learnable_weight_tflops = (ffn1_flops + ffn2_flops + qkv_flops + proj_flops) * config.num_decoder_layers * 3 / 10**12
+    attention_tflops = attention_flops * config.num_decoder_layers * 3 / 10**12 # megatron tflops calculation does not account for causality in attention
+  else:
+    learnable_weight_tflops = 6 * num_model_parameters * config.max_target_length * config.per_device_batch_size \
                             / 10**12
-  noncasual_attention_flops = 12 * config.num_query_heads * config.num_decoder_layers * config.head_dim \
+    noncausal_attention_flops = 12 * config.num_query_heads * config.num_decoder_layers * config.head_dim \
                             * config.max_target_length**2 * config.per_device_batch_size / 10**12
-  causal_attention_tflops = noncasual_attention_flops / 2 # due to causality in attention
-  total_tflops = learnable_weight_tflops + causal_attention_tflops
+    attention_tflops = noncausal_attention_flops / 2 # divided by 2 due to causality in attention
+
+  total_tflops = learnable_weight_tflops + attention_tflops

   if log:
     print('Per train step:\n',
           f'Total TFLOPs: {total_tflops:.2f} \n',
           f'split as {100 * learnable_weight_tflops/total_tflops:.2f}% learnable weight flops',
-          f'and {100 * causal_attention_tflops/total_tflops:.2f}% attention flops')
-  return total_tflops, learnable_weight_tflops, causal_attention_tflops
+          f'and {100 * attention_tflops/total_tflops:.2f}% attention flops')
+  return total_tflops, learnable_weight_tflops, attention_tflops

 # https://arxiv.org/pdf/2204.02311.pdf Appendix B
 def calculate_tflops_prefill(num_model_parameters, prefill_length, config, log=True):
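For intuition, the Megatron-style branch can be exercised standalone. The sketch below mirrors the arithmetic added above; the hyperparameters are illustrative assumptions for a 7B-class decoder, not values from this commit:

# Standalone sketch of the Megatron-style (https://arxiv.org/abs/2205.05198) flop counting above.
# Hyperparameters are illustrative assumptions: multi-head attention (num_kv_heads == num_query_heads)
# and a gated MLP with two input projections.
batch, seq = 1, 4096                                 # per_device_batch_size, max_target_length
layers, heads, kv_heads, head_dim = 32, 32, 32, 128  # num_decoder_layers, num_query_heads, num_kv_heads, head_dim
emb_dim, mlp_dim, num_mlp_activations = 4096, 11008, 2

qkv_flops = 2 * batch * seq * emb_dim * (heads + 2 * kv_heads) * head_dim
attention_flops = 2 * batch * seq**2 * heads * head_dim
proj_flops = 2 * batch * seq * emb_dim * heads * head_dim
ffn1_flops = 2 * batch * seq * mlp_dim * emb_dim * num_mlp_activations
ffn2_flops = 2 * batch * seq * mlp_dim * emb_dim

# x3 for the forward pass plus the roughly 2x-cost backward pass, as in the commit
learnable_weight_tflops = (ffn1_flops + ffn2_flops + qkv_flops + proj_flops) * layers * 3 / 10**12
attention_tflops = attention_flops * layers * 3 / 10**12
print(f'learnable weight TFLOPs: {learnable_weight_tflops:.1f}, attention TFLOPs: {attention_tflops:.1f}')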
