Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] add batch inference #584

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,8 @@ vertex_tensorboard_project: ""
# Region to create Vertex AI Tensorboard in for GCE, blank if running via XPK
# Vertex AI supported regions: https://cloud.google.com/vertex-ai/docs/general/locations#available-regions
vertex_tensorboard_region: ""

# Inference
inference_batching_mode: "static"
inference_output_path: ""
inference_profiler_batch_indices: 5,10
257 changes: 257 additions & 0 deletions MaxText/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
export CKPT_13B=gs://runner-maxtext-logs/direct_generate_param_only_llama_2_13b_checkpoint_2024-03-27-04-27/checkpoints/0/items
export DATA_13B=/mnt/disks/persist-data/data/llama13b_openorca_question.json
export idx=$(date +%Y-%m-%d-%H-%M)
python MaxText/inference.py \
MaxText/configs/base.yml \
base_output_directory=gs://morgandu-tpu/maxtext-logs/llama2-13b/inference/static \
run_name=${idx} \
tokenizer_path=assets/tokenizer.llama2 \
dataset_path=${DATA_13B} \
model_name=llama2-13b \
load_parameters_path=${CKPT_13B} \
async_checkpointing=false \
weight_dtype=bfloat16 \
scan_layers=false \
per_device_batch_size=4 \
max_prefill_predict_length=1024 \
max_target_length=2048 \
ici_fsdp_parallelism=1 \
ici_autoregressive_parallelism=-1 \
enable_profiler=true \
profiler_steps=128 \
inference_batching_mode=static \
inference_profiler_batch_indices="5,15,25" \
inference_output_path="/tmp/inference_output.json"
"""

from absl import app

import jax
import numpy as np

from typing import Sequence

import os

import maxengine
import pyconfig


import json
from typing import List

from jetstream.engine import token_utils

import max_utils

_LOGGING_INTERVAL = 10

def delete_pytree(p):
def delete_leaf(leaf):
if isinstance(leaf, jax.Array):
leaf.delete()
del leaf

jax.tree_map(delete_leaf, p)


def validate_inference_config(config):
assert config.load_full_state_path == "", "Decode doesn't operate on full states! Convert to parameter checkpoint first."\
"Using generate_param_only_checkpoint."
assert config.inference_batching_mode in ["static", "continuous"], "Inference batching mode supports only static batching and continuous batching."


def load_openorca_dataset(
dataset_path: str
) -> List[tuple[str]]:
# Load the dataset.

with open(dataset_path) as f:
dataset_json = json.load(f)

# Tokenize the prompts and completions.
prompts = dataset_json["prompts"]
results = dataset_json["results"]

return prompts, results


def static_batching(config, engine, tokenizer, n, all_padded_tokens, true_lengths, stop_tokens):

max_output_length = config.max_target_length - config.max_prefill_predict_length

params = engine.load_params()
decode_state = engine.init_decode_state()

num_slots = engine.max_concurrent_decodes
num_batch = n // num_slots + 1

profiler_batch_indices = [int(batch) for batch in config.inference_profiler_batch_indices.split(",")]
profiler_generate_steps = config.profiler_steps

print(f"{num_slots} slots.")
print(f"Total {num_batch} batches.")

slot_generation_complete = np.zeros(num_slots)
slot_generate_result_tokens = dict()
slot_generate_results = dict()

generate_results = []
all_generate_result_tokens = []

for batch in range(num_batch):
if batch in profiler_batch_indices:
profile_generate_done = False
max_utils.activate_profiler(config)
start_i = num_slots * batch
end_i = min(num_slots * (batch + 1) - 1, n - 1)

for i in range(start_i, end_i+1):
padded_tokens = all_padded_tokens[i]
true_length=true_lengths[i]
slot = i - start_i

prefill_result = engine.prefill(
params=params,
padded_tokens=padded_tokens,
true_length=true_length,
)

decode_state = engine.insert(
prefix=prefill_result,
decode_state=decode_state,
slot=slot
)
jax.block_until_ready(decode_state)
delete_pytree(prefill_result)

for slot in range(num_slots):

slot_generation_complete[slot] = 0.
slot_generate_result_tokens[slot] = []
slot_generate_results[slot] = None

for step in range(max_output_length):
decode_state, result_tokens = engine.generate(
params, decode_state
)
jax.block_until_ready(decode_state)

for i in range(start_i, end_i+1):
slot = i - start_i
slot_data = result_tokens.get_result_at_slot(slot)
slot_tokens = slot_data.tokens
slot_lengths = slot_data.lengths

token_id = slot_tokens[slot, 0].item()
if slot_lengths > max_output_length or token_id in stop_tokens:
slot_generation_complete[slot] = 1

if slot_generation_complete[slot]==0:
slot_generate_result_tokens[slot].append(token_id)
else:
slot_generate_results[slot] = tokenizer.detokenize(slot_generate_result_tokens[slot])

if batch in profiler_batch_indices and (step + 1) == profiler_generate_steps:
profile_generate_done = True
max_utils.deactivate_profiler(config)

if np.sum(slot_generation_complete[:end_i+1-start_i]) == 1:
if batch in profiler_batch_indices and profile_generate_done is False:
profile_generate_done = True
max_utils.deactivate_profiler(config)

if batch % _LOGGING_INTERVAL == 0:
print(f"All generations for batch {batch} are completed at step {step}.")

break

for i in range(start_i, end_i+1):
slot = i - start_i
all_generate_result_tokens.append(slot_generate_result_tokens[slot])
generate_results.append(slot_generate_results[slot])

if batch % _LOGGING_INTERVAL == 0:
print(f"Finished batch {batch} over {num_batch} batches.")

return generate_results, all_generate_result_tokens


def inference(config):

prompts, results = load_openorca_dataset(config.dataset_path)
n = len(prompts)

engine = maxengine.MaxEngine(config)
metadata = engine.get_tokenizer()
vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
tokenizer = vocab.tokenizer
stop_tokens = [vocab.eos_id, vocab.pad_id]
print(f"stop_tokens: {stop_tokens}")

all_padded_tokens = []
true_lengths = []
for prompt in prompts:
padded_tokens, true_length = token_utils.tokenize_and_pad(
prompt,
vocab,
is_bos=True,
prefill_lengths=[config.max_prefill_predict_length]
)
all_padded_tokens.append(padded_tokens)
true_lengths.append(true_length)

if config.inference_batching_mode == "static":
generate_results, all_generate_result_tokens = static_batching(
config,
engine,
tokenizer,
n,
all_padded_tokens,
true_lengths,
stop_tokens
)
elif config.inference_batching_mode == "continuous":
raise NotImplementedError("Continuous batching is not implemented yet.")

inference_output_json = dict()
inference_output_json["prompts"] = prompts
inference_output_json["original_results"] = results
inference_output_json["generate_results"] = generate_results
inference_output_json["all_generate_result_tokens"] = all_generate_result_tokens

if config.inference_output_path:
with open(config.inference_output_path, "w", encoding="utf-8") as f:
json.dump(inference_output_json, f)

return


def main(argv: Sequence[str]) -> None:
jax.config.update('jax_default_prng_impl', 'unsafe_rbg')
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
pyconfig.initialize(argv)
config = pyconfig.config
validate_inference_config(config)
inference(config)


if __name__ == "__main__":
app.run(main)