Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds temperature, top_k decoding, and top_p decoding to decode_seq2seq.py #44

Open
wants to merge 8 commits into
base: master
Choose a base branch
from

Conversation

Phirefly9
Copy link

This PR adds the capability to provide different token decoding strategies.

Pretty much a transplant of some code from https://github.com/huggingface/transformers/blob/master/examples/run_generation.py

@msftclas
Copy link

msftclas commented Dec 5, 2019

CLA assistant check
All CLA requirements met.

@Phirefly9
Copy link
Author

This is ready for review.

arguments were set so that results of the paper will not change by default when running eval.

@@ -1449,10 +1453,15 @@ def forward(self, input_ids, token_type_ids, position_ids, attention_mask, task_
last_hidden = new_encoded_layers[-1][:, -1:, :]
prediction_scores, _ = self.cls(
last_hidden, None, task_idx=task_idx)
prediction_scores = prediction_scores[:, -1, :] / (self.temperature if self.temperature > 0 else 1.)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prediction_scores[:, -1, :] reduces the dimension from 3 to 2, while other places assume the dim=3 (such as prediction_scores[:, :, token_id].fill_(-10000.0)).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch, I checked and prediction_scores only needed to be 3 dimensional for the not_predict_set if conditional, and is not used after that, so I modified the access to work properly with 2 dimensions.

unfortunately due to to way torch.multinomial works I cannot leave it as 3 dimensions, so had to make the modification that way

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants