
Strange conditional logic in _get_mask_indices_dims #5273

Open
hanbyul-kim opened this issue Aug 1, 2023 · 0 comments · May be fixed by #5274

@hanbyul-kim

🐛 Bug

To Reproduce

Steps to reproduce the behavior (always include the command you ran):

  1. Run cmd '....'
  2. See error

Code sample

I think the conditional check is meant to test whether size is already present in the cache map, so that a previously computed result can be reused.
However, the current condition tests membership in self.feature_encoder_spec instead, which doesn't make sense.

    def _get_mask_indices_dims(self, size, padding=0, dilation=1):
        if size not in self.feature_encoder_spec:
            L_in = size
            for (_, kernel_size, stride) in self.feature_encoder_spec:
                L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
                L_out = 1 + L_out // stride
                L_in = L_out
            self._features_size_map[size] = L_out
        return self._features_size_map[size]

Expected behavior

The condition needs to verify membership in self._features_size_map (i.e., if size not in self._features_size_map:), so that cached output sizes are actually reused instead of never being hit.
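A minimal sketch of the proposed fix, with the cache lookup testing self._features_size_map rather than self.feature_encoder_spec. The FeatureExtractor class and the example spec below are hypothetical stand-ins for illustration, not the actual fairseq module:

```python
class FeatureExtractor:
    """Hypothetical stand-in holding a conv spec and an output-size cache."""

    def __init__(self, feature_encoder_spec):
        # Each entry: (out_channels, kernel_size, stride)
        self.feature_encoder_spec = feature_encoder_spec
        self._features_size_map = {}

    def _get_mask_indices_dims(self, size, padding=0, dilation=1):
        # Fixed condition: check the cache map, not the conv spec.
        if size not in self._features_size_map:
            L_in = size
            for (_, kernel_size, stride) in self.feature_encoder_spec:
                # Standard 1-D conv output-length formula.
                L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
                L_out = 1 + L_out // stride
                L_in = L_out
            self._features_size_map[size] = L_out
        return self._features_size_map[size]


# Example: a wav2vec-style spec (values are illustrative only).
fe = FeatureExtractor([(512, 10, 5), (512, 3, 2), (512, 3, 2)])
out = fe._get_mask_indices_dims(16000)  # computed once, then served from the cache
```

With the fix, the second call for the same size returns the cached value instead of recomputing the loop.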

Environment

  • fairseq Version (e.g., 1.0 or main):
  • PyTorch Version (e.g., 1.0)
  • OS (e.g., Linux):
  • How you installed fairseq (pip, source):
  • Build command you used (if compiling from source):
  • Python version:
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

@hanbyul-kim hanbyul-kim linked a pull request Aug 1, 2023 that will close this issue