-
Notifications
You must be signed in to change notification settings - Fork 515
/
byteformer.py
448 lines (393 loc) · 16.7 KB
/
byteformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
import argparse
import math
from typing import Dict, Optional, Tuple, Union
import numpy as np
import torch
from torch import Tensor, nn
from torch.nn import init
from corenet.modeling.layers import (
LinearLayer,
embedding,
get_normalization_layer,
normalization,
positional_embedding,
token_merging,
)
from corenet.modeling.models import MODEL_REGISTRY, BaseAnyNNModel
from corenet.modeling.models.classification.config.byteformer import get_configuration
from corenet.modeling.modules import WindowedTransformerEncoder
def unfold_tokens(t: Tensor, kernel_size: int) -> Tensor:
    """
    Slice @t into overlapping windows along its sequence dimension.

    Windows are @kernel_size tokens long and start every @kernel_size // 2
    tokens, so neighbouring windows share @kernel_size // 2 tokens. The
    windows of all batch elements are stacked along the leading dimension.

    Args:
        t: A tensor of shape [batch_size, sequence_length, num_channels].
        kernel_size: The number of tokens in each window.

    Returns:
        A tensor of shape [batch_size * num_windows, kernel_size, num_channels],
        where num_windows = (sequence_length - kernel_size)
        // (kernel_size // 2) + 1.
    """
    stride = kernel_size // 2
    # Tensor.unfold appends the window dimension last:
    # [batch_size, num_windows, num_channels, kernel_size].
    windows = t.unfold(dimension=1, size=kernel_size, step=stride)
    batch_size, num_windows, num_channels, _ = windows.shape
    stacked = windows.reshape(batch_size * num_windows, num_channels, kernel_size)
    # Put the tokens of each window before the channels.
    return stacked.transpose(1, 2)
@MODEL_REGISTRY.register(name="byteformer", type="classification")
class ByteFormer(BaseAnyNNModel):
    """
    This class defines the `ByteFormer <https://arxiv.org/abs/2306.00238>`_ architecture.

    The model embeds raw input bytes, optionally shortens the sequence with a
    strided 1D convolution, adds positional embeddings, runs a stack of
    windowed transformer blocks (optionally followed by token merging), and
    classifies the masked mean of the resulting token features.
    """

    def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None:
        """
        Build the ByteFormer layers from the command-line options.

        Args:
            opts: Command-line arguments.

        Raises:
            ValueError: If the window sizes, window shifts, or downsampling
                flags do not contain one entry per transformer layer.
        """
        super().__init__(opts, *args, **kwargs)
        byteformer_config = get_configuration(opts)
        embed_dim = byteformer_config["embed_dim"]
        ffn_dim = byteformer_config["ffn_dim"]
        n_transformer_layers = byteformer_config["n_transformer_layers"]
        num_heads = byteformer_config["n_attn_heads"]
        attn_dropout = byteformer_config["attn_dropout"]
        dropout = byteformer_config["dropout"]
        ffn_dropout = byteformer_config["ffn_dropout"]
        norm_layer = byteformer_config["norm_layer"]

        # This is usually 257 in the case of byte inputs (2**8 + 1 mask token).
        vocab_size = getattr(opts, "model.classification.byteformer.vocab_size")
        self.embeddings = embedding.Embedding(
            opts, num_embeddings=vocab_size, embedding_dim=embed_dim, padding_idx=-1
        )
        # Reinitialize everything except the padding index.
        init.trunc_normal_(self.embeddings.weight[:-1], std=math.sqrt(1.0 / embed_dim))

        self.dummy_input_token_length = getattr(
            opts, "model.classification.byteformer.dummy_input_token_length"
        )

        # Add token reduction convolution.
        self.conv_kernel_size = getattr(
            opts, "model.classification.byteformer.conv_kernel_size"
        )
        if not self.conv_kernel_size:
            # A kernel size of 0 (or None) means we skip the convolution. This
            # must be an if/else: a separate `is not None` check is also true
            # for 0 and would replace the None with an invalid Conv1d.
            self.token_reduction_net = None
        else:
            self.token_reduction_net = nn.Conv1d(
                embed_dim,
                embed_dim,
                kernel_size=self.conv_kernel_size,
                stride=self.conv_kernel_size // 2,
                bias=False,
            )

        # Add the positional embeddings.
        self.max_num_tokens = getattr(
            opts, "model.classification.byteformer.max_num_tokens"
        )
        self.sinusoidal_pos_embed = getattr(
            opts, "model.classification.byteformer.sinusoidal_pos_emb"
        )
        self.pos_embed = positional_embedding.PositionalEmbedding(
            opts=opts,
            num_embeddings=self.max_num_tokens,
            embedding_dim=embed_dim,
            sequence_first=False,
            padding_idx=None,
            is_learnable=not self.sinusoidal_pos_embed,
            interpolation_mode="bilinear",
        )
        pos_emb_drop_p = getattr(opts, "model.classification.byteformer.dropout")
        self.emb_dropout = nn.Dropout(p=pos_emb_drop_p)

        # Build the transformer backbone.
        window_sizes = getattr(opts, "model.classification.byteformer.window_sizes")
        window_shifts = getattr(opts, "model.classification.byteformer.window_shifts")
        downsample = getattr(opts, "model.classification.byteformer.downsample")
        if len(window_sizes) == 1:
            # A single window size is shared by every layer.
            window_sizes = window_sizes * n_transformer_layers
        if downsample is None:
            # No downsampling flags means no layer is followed by token merging.
            downsample = [False] * n_transformer_layers

        for x in [window_sizes, window_shifts, downsample]:
            if len(x) != n_transformer_layers:
                raise ValueError(
                    f"Invalid argument length {len(x)} != {n_transformer_layers}"
                )

        stochastic_dropout = getattr(
            opts, "model.classification.byteformer.stochastic_dropout"
        )
        # The stochastic drop rate grows linearly from 0 to @stochastic_dropout
        # over the depth of the network.
        per_layer_stochastic_drop_rate = [
            round(x, 3)
            for x in np.linspace(0, stochastic_dropout, n_transformer_layers)
        ]

        blocks = []
        self.downsamplers = nn.ModuleDict()
        for layer_idx in range(n_transformer_layers):
            blocks.append(
                WindowedTransformerEncoder(
                    opts=opts,
                    embed_dim=embed_dim,
                    ffn_latent_dim=ffn_dim,
                    num_heads=num_heads,
                    attn_dropout=attn_dropout,
                    dropout=dropout,
                    ffn_dropout=ffn_dropout,
                    transformer_norm_layer=norm_layer,
                    stochastic_dropout=per_layer_stochastic_drop_rate[layer_idx],
                    window_size=window_sizes[layer_idx],
                    window_shift=window_shifts[layer_idx],
                )
            )
            if downsample[layer_idx]:
                self.downsamplers[self.get_downsampler_name(layer_idx)] = (
                    token_merging.TokenMerging(embed_dim)
                )
        self.transformer = nn.Sequential(*blocks)
        self.post_transformer_norm = get_normalization_layer(
            opts=opts, num_features=embed_dim, norm_type=norm_layer
        )

        num_classes = getattr(opts, "model.classification.n_classes")
        self.classifier = LinearLayer(embed_dim, num_classes)

    @staticmethod
    def _str_to_bool(value: Union[str, bool]) -> bool:
        """
        Parse a command-line token into a boolean.

        `type=bool` cannot be used with argparse because bool() of any
        non-empty string, including "False", is True.

        Args:
            value: The token to parse. Booleans are returned unchanged.

        Returns:
            The parsed boolean.

        Raises:
            argparse.ArgumentTypeError: If @value is not a recognized boolean
                literal.
        """
        if isinstance(value, bool):
            return value
        normalized = str(value).strip().lower()
        if normalized in ("true", "t", "yes", "y", "1"):
            return True
        if normalized in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(
            f"Expected a boolean value (e.g. True or False), got {value!r}."
        )

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """
        Register the ByteFormer-specific command-line arguments.

        Args:
            parser: The parser to add the arguments to.

        Returns:
            The parser that was passed in.
        """
        if cls != ByteFormer:
            # Don't re-register arguments in subclasses that don't override `add_arguments()`.
            return parser
        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            "--model.classification.byteformer.dropout",
            type=float,
            default=0.0,
            help="Dropout in Byteformer layers. Defaults to 0.0.",
        )
        group.add_argument(
            "--model.classification.byteformer.stochastic-dropout",
            type=float,
            default=0.0,
            help="Probability of applying stochastic dropout to "
            "TransformerEncoder submodules. Defaults to 0.0.",
        )
        group.add_argument(
            "--model.classification.byteformer.norm-layer",
            type=str,
            default="layer_norm",
            help="Normalization layer in Byteformer. Defaults to LayerNorm.",
            choices=list(normalization.NORM_LAYER_REGISTRY.keys()),
        )
        group.add_argument(
            "--model.classification.byteformer.sinusoidal-pos-emb",
            action="store_true",
            default=False,
            help="Use sinusoidal instead of learnable positional encoding. Defaults to False.",
        )
        group.add_argument(
            "--model.classification.byteformer.use-pytorch-mha",
            action="store_true",
            default=False,
            help="Use PyTorch's native multi-head attention. Defaults to False.",
        )
        group.add_argument(
            "--model.classification.byteformer.mode",
            type=str,
            default="tiny",
            help="Byteformer mode, which determines the model size. Defaults to tiny.",
            choices=("tiny", "small", "base", "huge"),
        )
        group.add_argument(
            "--model.classification.byteformer.vocab-size",
            type=int,
            help="The vocab size of the token embedding. Defaults to 257, "
            "corresponding to the number of unique bytes (256) plus 1 "
            "more for the mask token.",
            default=257,
        )
        group.add_argument(
            "--model.classification.byteformer.max-num-tokens",
            type=int,
            help="The maximum number of tokens that can be input to the network. Defaults to 10000.",
            default=10000,
        )
        group.add_argument(
            "--model.classification.byteformer.conv-kernel-size",
            type=int,
            default=16,
            help="The size of the kernel of the initial downsampling conv1d. Defaults to 16.",
        )
        group.add_argument(
            "--model.classification.byteformer.window-sizes",
            type=int,
            nargs="*",
            default=[128],
            help="A list of window sizes used in shifted window attention. If the "
            "list is length 1, the same window size is used for all windows. "
            "Defaults to 128 for all windows.",
        )
        group.add_argument(
            "--model.classification.byteformer.window-shifts",
            type=int,
            nargs="*",
            default=[0, 64] * 6,
            help="A list of shifts used in shifted window attention. Defaults to values that alternate between 0 and 64.",
        )
        default_downsampling = [True, True] + ([False, True] * 4) + [False, False]
        group.add_argument(
            "--model.classification.byteformer.downsample",
            # Not `type=bool`: that would parse the string "False" as True.
            type=cls._str_to_bool,
            nargs="*",
            default=default_downsampling,
            help="A list of boolean values, where the i'th element specifies "
            "whether to downsample after the transformer block with index i. "
            f"Defaults to {default_downsampling}.",
        )
        group.add_argument(
            "--model.classification.byteformer.padding-index",
            default=-1,
            type=int,
            help="The index used for padding tokens. Defaults to -1.",
        )
        group.add_argument(
            "--model.classification.byteformer.dummy-input-token-length",
            default=48564,
            type=int,
            help="The token length to use for dummy inputs. Defaults to 48564, "
            "corresponding to the average length of 224x224 JPEG images from "
            "ImageNet.",
        )
        return parser

    def dummy_input_and_label(self, batch_size: int) -> Dict:
        """
        Get a dummy input and label that could be passed to the model.

        Args:
            batch_size: The batch size to use for the generated inputs.

        Returns:
            A dict with
                {
                    "samples": tensor of shape [batch_size, sequence_length],
                    "targets": tensor of shape [batch_size],
                }
        """
        n_labels = 10
        # Exclusive upper bound of the sampled token ids (matches the default
        # vocab size of 257).
        max_value = 257
        samples = torch.randint(
            0, max_value, [batch_size, self.dummy_input_token_length]
        )
        targets = torch.randint(low=0, high=n_labels, size=(batch_size,)).long()
        return {"samples": samples, "targets": targets}

    def apply_token_reduction_net(
        self, x: Tensor, x_mask: Optional[Tensor]
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Apply the portion of the network used to reduce sequence lengths before
        the transformer backbone.

        Args:
            x: The input token embeddings of shape [batch_size, sequence_length,
                embed_dim].
            x_mask: The input mask of shape [batch_size, sequence_length], or
                None if there is no mask.

        Returns:
            New versions of @x and @x_mask, downsampled along the sequence
            dimension by the token reduction net. If the token reduction net
            is disabled, the inputs are returned unchanged.
        """
        if self.token_reduction_net is None:
            return x, x_mask

        B, N, _ = x.shape
        # Conv1d expects [batch_size, channels, sequence_length].
        x = self.token_reduction_net(x.permute(0, 2, 1)).permute(0, 2, 1)
        if x_mask is not None:
            # Window the mask exactly like the convolution windows the tokens.
            x_mask = unfold_tokens(
                x_mask.reshape(B, N, 1).float(), self.conv_kernel_size
            )
            # The mask is now [B * num_windows, kernel_size, 1]. It contains
            # values in {0, -inf}. Taking the max keeps an output token
            # unmasked if any input token in its window was unmasked.
            x_mask = x_mask.max(dim=1).values.view(x.shape[0], x.shape[1])
            assert x.shape[:2] == x_mask.shape
        return x, x_mask

    def get_backbone_inputs(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Convert input bytes into embeddings to be passed to the network's
        transformer backbone.

        Args:
            x: The input bytes as an integer tensor of shape [batch_size,
                sequence_length]. Integer tensors are expected (rather than byte
                tensors) since -1 is usually used for padding. The tensor is
                not modified.

        Returns:
            The embeddings of shape [batch_size, new_sequence_length, embed_dim]
            and a mask tensor of shape [batch_size, new_sequence_length]. The
            mask contains 0 at unmasked positions and float(-inf) at masked
            positions.
        """
        is_padding = x == -1
        # Build the mask out of place. `mask[is_padding].fill_(...)` would fill
        # a temporary copy created by boolean indexing and leave the mask all
        # zeros.
        mask = torch.zeros_like(x, dtype=torch.float).masked_fill(
            is_padding, float("-inf")
        )
        # Remap padding to the embedding's padding index without mutating the
        # caller's tensor.
        x = x.masked_fill(is_padding, self.embeddings.padding_idx)
        x = self.embeddings(x)
        x, mask = self.apply_token_reduction_net(x, mask)
        x = x + self.pos_embed(self.max_num_tokens)[:, : x.shape[1]]
        x = self.emb_dropout(x)
        return x, mask

    def backbone_forward(
        self, x: Tensor, key_padding_mask: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Execute the forward pass of the network's transformer backbone.

        Args:
            x: The input embeddings as a [batch_size, sequence_length, embed_dim] tensor.
            key_padding_mask: The mask tensor of shape [batch_size, sequence_length].

        Returns:
            The outputs of the backbone as a tuple. The first element is the feature
            tensor, and the second element is the updated key_padding_mask.
        """
        B, S, _ = x.shape
        assert key_padding_mask.shape == (B, S)
        for layer_idx, elem in enumerate(self.transformer):
            x = elem(x, key_padding_mask=key_padding_mask)
            downsampler = self.get_downsampler(layer_idx)
            if downsampler is not None:
                # Downsampling shortens both the features and their mask.
                x, key_padding_mask = downsampler(x, key_padding_mask)
        x = self.post_transformer_norm(x)
        return x, key_padding_mask

    def get_downsampler_name(self, idx: int) -> str:
        """
        Get the name of the downsampling layer with index @idx.

        Args:
            idx: The index of the downsampling layer.

        Returns:
            A string representing the name of the downsampling layer.
        """
        return f"downsample_{idx}"

    def get_downsampler(self, idx: int) -> Optional[nn.Module]:
        """
        Get the module that performs downsampling after transformer layer @idx.
        If no downsampling occurs after that layer, return None.

        Args:
            idx: The desired index.

        Returns:
            The downsampling layer, or None.
        """
        name = self.get_downsampler_name(idx)
        if name not in self.downsamplers:
            return None
        return self.downsamplers[name]

    def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
        """
        Perform a forward pass on input bytes. The tensor is
        stored as an integer tensor of shape [batch_size, sequence_length].
        Integer tensors are used because @x usually contains mask tokens.

        Args:
            x: The input tensor of shape [batch_size, sequence_length].

        Returns:
            The output logits.
        """
        x, key_padding_mask = self.get_backbone_inputs(x)
        x, attn_mask = self.backbone_forward(x, key_padding_mask)

        # Average the features over the unmasked tokens only.
        attn_mask = attn_mask.view(x.shape[0], x.shape[1], 1)
        x = x.masked_fill(attn_mask == float("-inf"), 0)
        # Clamp so that a sample without any unmasked token does not divide
        # by zero.
        norms = (attn_mask == 0).sum(dim=1).clamp(min=1)
        x = torch.sum(x, dim=1) / norms
        x = self.classifier(x)
        return x

    @classmethod
    def build_model(cls, opts: argparse.Namespace, *args, **kwargs) -> BaseAnyNNModel:
        """
        Helper function to build a model.

        Args:
            opts: Command-line arguments.

        Returns:
            An instance of `corenet.modeling.models.BaseAnyNNModel`.
        """
        model = cls(opts, *args, **kwargs)
        if getattr(opts, "model.classification.freeze_batch_norm"):
            cls.freeze_norm_layers(opts=opts, model=model)
        return model