# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
# Taken from https://github.com/ml-explore/mlx-examples/blob/main/clip/tokenizer.py
# with modifications to the docstring and typing.

import json
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union

import mlx.core as mx
import regex


class CLIPTokenizer:
    """Convert text into arrays of token indices.

    A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/
    """

    def __init__(self, bpe_ranks: Dict[Tuple[str, str], int], vocab: Dict[str, int]):
        self.bpe_ranks = bpe_ranks
        self.vocab = vocab
        self.pat = regex.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            regex.IGNORECASE,
        )
        # Special tokens map to single-element lists so that `bpe` always
        # returns a list of sub-word strings.
        self._cache = {self.bos: [self.bos], self.eos: [self.eos]}
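
    # The compiled pattern matches, in priority order: the two special
    # tokens, common English contraction suffixes, runs of letters, single
    # digits, and runs of any other non-space characters.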

    @property
    def bos(self) -> str:
        return "<|startoftext|>"

    @property
    def bos_token(self) -> int:
        return self.vocab[self.bos]

    @property
    def eos(self) -> str:
        return "<|endoftext|>"

    @property
    def eos_token(self) -> int:
        return self.vocab[self.eos]

    def bpe(self, text: str) -> List[str]:
        if text in self._cache:
            return self._cache[text]

        unigrams = list(text[:-1]) + [text[-1] + "</w>"]
        unique_bigrams = set(zip(unigrams, unigrams[1:]))

        if not unique_bigrams:
            return unigrams

        # In every iteration try to merge the two most likely bigrams. If none
        # was merged we are done.
        #
        # Ported from https://github.com/huggingface/transformers/blob/e74d793a3c3c0bc9bf3fb94bb31dd16934b1b0db/src/transformers/models/clip/tokenization_clip.py
        while unique_bigrams:
            bigram = min(
                unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
            )
            if bigram not in self.bpe_ranks:
                break

            new_unigrams = []
            skip = False
            for a, b in zip(unigrams, unigrams[1:]):
                if skip:
                    skip = False
                    continue

                if (a, b) == bigram:
                    new_unigrams.append(a + b)
                    skip = True
                else:
                    new_unigrams.append(a)

            if not skip:
                new_unigrams.append(b)

            unigrams = new_unigrams
            unique_bigrams = set(zip(unigrams, unigrams[1:]))

        self._cache[text] = unigrams

        return unigrams
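
    # Illustrative walk-through: bpe("low") starts from ['l', 'o', 'w</w>'];
    # if ('l', 'o') is the best-ranked bigram it merges to ['lo', 'w</w>'],
    # and merging repeats until no remaining bigram appears in `bpe_ranks`.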

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.tokenize(*args, **kwargs)

    def tokenize(
        self,
        text: Union[str, List[str]],
        prepend_bos: bool = True,
        append_eos: bool = True,
    ) -> mx.array:
        if isinstance(text, list):
            return mx.array([self.tokenize(t, prepend_bos, append_eos) for t in text])

        # Lower case, cleanup, and split. Hugging Face does a much more
        # thorough job here but this should suffice for 95% of cases.
        clean_text = regex.sub(r"\s+", " ", text.lower())
        tokens = regex.findall(self.pat, clean_text)

        # Split the tokens according to the byte-pair merge file
        bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]

        # Map to token ids and return
        token_ids = []
        if prepend_bos:
            token_ids.append(self.bos_token)
        token_ids.extend(self.vocab[t] for t in bpe_tokens)
        if append_eos:
            token_ids.append(self.eos_token)

        return mx.array(token_ids)
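
    # Note: for list inputs the per-item sequences are stacked into a single
    # array, so every item must tokenize to the same length; otherwise
    # tokenize prompts individually (or pad) before stacking.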

    @staticmethod
    def from_pretrained(path: str) -> "CLIPTokenizer":
        path = Path(path)

        with open(path / "vocab.json", encoding="utf-8") as f:
            vocab = json.load(f)

        with open(path / "merges.txt", encoding="utf-8") as f:
            # The first line of merges.txt is a version header; the upper
            # slice bound mirrors the Hugging Face CLIP tokenizer.
            bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]

        bpe_merges = [tuple(m.split()) for m in bpe_merges]
        bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))

        return CLIPTokenizer(bpe_ranks, vocab)
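

# Example usage (a minimal sketch): assumes a local directory containing the
# CLIP `vocab.json` and `merges.txt` files; the path below is hypothetical.
if __name__ == "__main__":
    tokenizer = CLIPTokenizer.from_pretrained("path/to/clip-tokenizer")
    tokens = tokenizer.tokenize("a photo of an astronaut riding a horse")
    print(tokens)  # 1-D mx.array of token ids, including BOS and EOS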