Multi-process batch size not calculated correctly #19726

Closed
natbprice opened this issue May 17, 2024 · 6 comments
@natbprice

Describe the bug
I opened a related issue in keras-nlp, but I believe the issue is likely best addressed in keras. See related issue: keras-team/keras-nlp#1630

Currently, the batch size is not calculated correctly when performing multi-process distributed training with JAX backend if the dataset has been pre-processed with a mapping function.

ValueError                                Traceback (most recent call last)
[<ipython-input-9-639e39591e79>](https://localhost:8080/#) in <cell line: 14>()
     12 
     13 model.compile(loss="mse")
---> 14 model.fit(ds, epochs=3)

1 frames
[/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py](https://localhost:8080/#) in error_handler(*args, **kwargs)
    120             # To get the full stack trace, call:
    121             # `keras.config.disable_traceback_filtering()`
--> 122             raise e.with_traceback(filtered_tb) from None
    123         finally:
    124             del filtered_tb

[/usr/local/lib/python3.10/dist-packages/keras/src/distribution/distribution_lib.py](https://localhost:8080/#) in distribute_dataset(self, dataset)
    465         batch_size = tf_data_distribute.compute_batch_size(dataset)
    466         if batch_size.numpy() < 0:
--> 467             raise ValueError(
    468                 "The batch size of the input dataset is "
    469                 "unknown. Please config the batch size for "

ValueError: The batch size of the input dataset is unknown. Please config the batch size for the input dataset, e.g via `dataset.batch(batch_size)`

To Reproduce
See https://colab.research.google.com/drive/1IxVNDcNoIK4SiX2wuDQKfqR_6Or9P40I?usp=sharing

import os

os.environ['KERAS_BACKEND'] = 'jax'

import keras
import tensorflow as tf
import numpy as np

print("keras", keras.__version__)
print("tf", tf.__version__)

data_parallel = keras.distribution.DataParallel()

# Mock multi-process environment
data_parallel._is_multi_process = True

keras.distribution.set_distribution(data_parallel)

inputs = np.random.normal(size=(128, 28, 28, 1))
labels = np.random.normal(size=(128, 10))
ds = tf.data.Dataset.from_tensor_slices((inputs, labels)).batch(16)
ds = ds.map(lambda x,y: (x,y))

inputs = keras.layers.Input(shape=(28, 28, 1))
y = keras.layers.Flatten()(inputs)
y = keras.layers.Dense(units=200, use_bias=False, activation="relu")(y)
y = keras.layers.Dropout(0.4)(y)
y = keras.layers.Dense(units=10, activation="softmax")(y)
model = keras.Model(inputs=inputs, outputs=y)

model.compile(loss="mse")
model.fit(ds, epochs=3)

Expected behavior
A batched tf.data.Dataset object is recognized as batched.

Would you like to help us fix it?
I would like to try to fix this if it is not too complex. Maybe we can just replace the call to tensorflow.python.data.experimental.ops.distribute.compute_batch_size() with dataset._input_dataset._batch_size?

@SuryanarayanaY
Collaborator

Hi @natbprice ,

Thanks for reporting. I have tested the code snippet and reproduced the reported behaviour. Gist attached for reference.

@SuryanarayanaY SuryanarayanaY added type:Bug backend:jax keras-team-review-pending Pending review by a Keras team member. labels May 20, 2024
@mattdangerw mattdangerw removed the keras-team-review-pending Pending review by a Keras team member. label May 23, 2024
@hertschuh
Contributor

@natbprice ,

Thanks for the report and the investigation. After looking into it in detail, I came to the conclusion that this works as expected.

I saw your proposed fix:

  if isinstance(ds, _MapDataset) or isinstance(ds, _ParallelMapDataset):
    return ds._input_dataset._batch_size

But that's the batch size of the input dataset, before the map. The issue is that there is no constraint on what the function passed to map is allowed to do, so there is no guarantee that what comes out of map has the same batch size as what went in.
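To make this concrete, here is a minimal sketch of a map function that silently changes the batch size, which is why the batch size cannot in general be inferred through a map:

```python
import numpy as np
import tensorflow as tf

inputs = np.random.normal(size=(128, 8))
labels = np.random.normal(size=(128, 2))

# Batch first, then map: the map function is free to change the
# leading (batch) dimension, here by keeping only 4 of 16 examples.
ds = tf.data.Dataset.from_tensor_slices((inputs, labels)).batch(16)
mapped = ds.map(lambda x, y: (x[:4], y[:4]))

x, y = next(iter(mapped))
print(x.shape[0])  # effective batch size is 4, not the 16 set via .batch()
```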

Now, why does this only happen when using multi-process distribution? That's because Keras is able to train with an unknown batch size in the normal case and only tries to determine the batch size if distribution is turned on.

What's the fix? Well, the standard pattern I've seen is to batch last, after map, shuffle, etc.:

ds = tf.data.Dataset.from_tensor_slices((inputs, labels))
ds = ds.map(lambda x,y: (x,y))
ds = ds.batch(16)
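As a sanity check, the helper that Keras calls internally, tf_data_distribute.compute_batch_size (a private TensorFlow API that may change between versions), can be invoked directly to see the difference between the two orderings; a sketch:

```python
import numpy as np
import tensorflow as tf
# Private TF API, used by Keras internally; subject to change across versions.
from tensorflow.python.data.experimental.ops import distribute as tf_data_distribute

inputs = np.random.normal(size=(128, 8))
labels = np.random.normal(size=(128, 2))

# batch() before map(): the batch size is not statically recoverable.
ds_bad = tf.data.Dataset.from_tensor_slices((inputs, labels)).batch(16)
ds_bad = ds_bad.map(lambda x, y: (x, y))

# batch() last: the batch size is recoverable.
ds_good = tf.data.Dataset.from_tensor_slices((inputs, labels))
ds_good = ds_good.map(lambda x, y: (x, y))
ds_good = ds_good.batch(16)

print(tf_data_distribute.compute_batch_size(ds_bad).numpy())   # negative, i.e. unknown
print(tf_data_distribute.compute_batch_size(ds_good).numpy())  # 16
```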

Let me know if you have further questions.


@natbprice
Author

@hertschuh thanks for investigating this. Based on your conclusion, it sounds like this issue should instead be resolved in keras-team/keras-nlp#1630? In that case, a preprocessor is being mapped over the data internally so there doesn't appear to be an easy workaround.

Sorry if I created extra work. I guess I should not have opened a related issue here.

@hertschuh
Contributor

> @hertschuh thanks for investigating this. Based on your conclusion, it sounds like this issue should instead be resolved in keras-team/keras-nlp#1630? In that case, a preprocessor is being mapped over the data internally so there doesn't appear to be an easy workaround.
>
> Sorry if I created extra work. I guess I should not have opened a related issue here.

@natbprice ,

Yes, I think the fix should be in keras-nlp. One should simply apply batch_size after the map and not in _convert_inputs_to_dataset. Do you want me to follow up in the keras-nlp bug?

@natbprice
Author

@hertschuh if you don't mind following up in keras-nlp, that would be great! I think I understand the solution you are proposing, but I can't quite figure out the best way for keras-nlp API to function. In particular, it seems like there are several combinations of (1) distribution strategy, (2) input types (e.g., tf.data.Dataset, NumPy arrays), and (3) batching (e.g., pre-batched dataset, explicit batch_size).

Currently, _convert_inputs_to_dataset raises an error if you attempt to pass a tf.data.Dataset together with an explicit batch_size argument. It also looks like there is error handling to prevent you from passing unbatched inputs, but the string matching on the error message may be outdated and no longer functioning.
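For what it's worth, a batch-last version of such a preprocessing pipeline could look roughly like the sketch below. The helper name convert_inputs_to_dataset and its signature are hypothetical, not the actual keras-nlp API; it only illustrates mapping a preprocessor over unbatched elements and batching afterwards:

```python
import numpy as np
import tensorflow as tf

def convert_inputs_to_dataset(x, y, preprocessor, batch_size):
    # Hypothetical sketch: map the preprocessor over unbatched elements,
    # then batch last so the batch size stays statically recoverable.
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    ds = ds.map(
        lambda a, b: (preprocessor(a), b),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Usage with an identity "preprocessor":
ds = convert_inputs_to_dataset(
    np.zeros((10, 3), dtype="float32"),
    np.zeros((10,), dtype="float32"),
    lambda a: a,
    batch_size=4,
)
xb, yb = next(iter(ds))
print(xb.shape)  # (4, 3)
```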
