Skip to content

Commit

Permalink
layout api
Browse files Browse the repository at this point in the history
  • Loading branch information
morgandu committed May 18, 2024
1 parent 8930215 commit 2ca0d24
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions MaxText/inference_microbenchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

"""Inference microbenchmark for prefill and autoregressive steps."""
import datetime
import re

import jax
import json
import sys
Expand All @@ -31,6 +33,15 @@
_WARMUP_ITERS = 2


pattern = re.compile(r"\{(.*?):")

# Extract minor_to_major from str(layout) because layout doesn't have a
# minor_to_major property yet.
def extract_minor_to_major(l):
match = re.search(pattern, str(l))
return tuple(int(i) for i in match.groups()[0].split(','))


def prefill_benchmark_loop(engine, params, tokens, true_length, iters):
"""Inner loop for benchmarking prefill step."""
start = datetime.datetime.now()
Expand Down Expand Up @@ -121,12 +132,17 @@ def ar_benchmark_loop(config, engine, params, decode_state, iters, profile_name)
return (end - start).total_seconds(), decode_state


def ar_lowering(engine, params, decode_state):
lowered_generate = engine.generate.lower(params, decode_state)
compiled_generate = lowered_generate.compile()
return compiled_generate


def ar_benchmark(config, engine, params, decode_state, global_batch_size, cache_size, model_size, iters):
"""Handles warmup, running ar benchmark, and printing results."""
for _ in range(_WARMUP_ITERS):
decode_state, _ = engine.generate(params, decode_state)
jax.block_until_ready(decode_state)

time_in_s, decode_state = ar_benchmark_loop(config, engine, params, decode_state, iters, profile_name="autoregress")
seconds_per_step = time_in_s / iters
ar_average_ms = seconds_per_step * 1000
Expand Down Expand Up @@ -270,8 +286,11 @@ def main(config):
)

if "generate" in stages_to_benchmark:
benchmark_results["AutoRegressive"], decode_state = ar_benchmark(
config, engine, params, decode_state, engine.max_concurrent_decodes, cache_size, model_size, benchmark_loop_iters)
compiled_generate = ar_lowering(engine, params, decode_state)
breakpoint()

# benchmark_results["AutoRegressive"], decode_state = ar_benchmark(
# config, engine, params, decode_state, engine.max_concurrent_decodes, cache_size, model_size, benchmark_loop_iters)

results = collate_results(config, benchmark_results, model_size, cache_size, num_model_params)
write_results(results, filename=config.inference_microbenchmark_log_file_path)
Expand Down

0 comments on commit 2ca0d24

Please sign in to comment.