Commit 7c17cb48 authored by Gaëtan Caillaut

Option to mask whole words or tokens

parent b92ff5c1
import torch
import random
def eval_model(model, loader, mask_idx, device):
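# Accuracy at <mask> positions: for every masked position, take the argmax over the vocabulary and count it as correct when it matches the gold label in y. Returns (accuracy, number of correct predictions).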
pos = 0
total = 0
with torch.no_grad():
for (x, y) in loader:
output = model(x.to(device))
masks = torch.nonzero(x == mask_idx)
i, j = masks[:, 0], masks[:, 1]
output_probs = output[i, j, :]
output_labels = torch.argmax(output_probs, dim=1)
pos += torch.sum(output_labels == y.to(device)).item()
total += x.size(0)
return pos / total, pos
def mask_random_position(x, mask_idx, pad_idx):
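# Mask exactly one non-padding position per row: sample it uniformly via multinomial over the not-pad mask, keep the original token as the label, then overwrite that position with the mask id.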
not_pad = x != pad_idx
masking_indices = torch.multinomial(not_pad.float(), 1)
labels = torch.gather(x, 1, masking_indices)
x[torch.arange(masking_indices.size(0)),
masking_indices.flatten()] = mask_idx
return (x, labels)
def eval_model_k(model, loader, mask_idx, pad_idx, device, k=1):
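# Top-k evaluation: mask one random position per sequence, then count a hit at rank r if the gold token id is among the r highest-scoring predictions (r = 1..k). Returns per-rank accuracies and raw hit counts.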
pos = [0 for _ in range(k)]
total = 0
with torch.no_grad():
for (x, attention_mask, wids) in loader:
x = x.to(device)
x, y = mask_random_position(x, mask_idx, pad_idx)
output = model(x)
masks = torch.nonzero(x == mask_idx)
i, j = masks[:, 0], masks[:, 1]
output_probs = output[i, j, :]
k_output_labels = torch.argsort(
output_probs, dim=1, descending=True)[:, :k]
for label, neighbors in zip(y, k_output_labels):
# print(label.item())
for kk in range(k):
if label.item() in neighbors[:(kk + 1)]:
pos[kk] += 1
total += x.size(0)
return [p / total for p in pos], pos
import gzip
from pathlib import Path
import torch
from torch.utils.data import DataLoader, IterableDataset
from tokenizers import Tokenizer
class LineDataset(IterableDataset):
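# Iterable dataset streaming one line at a time from a gzip-compressed text file; the line count used by __len__ is computed lazily on first use and cached.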
def __init__(self, path, lines_count=None):
self.path = str(path)
self.lines_count = lines_count
def __len__(self):
if self.lines_count is None:
with gzip.open(self.path, "rt", encoding="UTF-8") as f:
i = 0
for i, _ in enumerate(f, 1):
pass
self.lines_count = i
return self.lines_count
def __iter__(self):
with gzip.open(self.path, "rt", encoding="UTF-8") as f:
for l in f:
yield l
class TrainData:
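# Wraps a preprocessed corpus directory: loads <path>/tokenizer.json, enables padding to max_seq_size, and exposes the gzip train/dev/test splits plus the collate function (to_tensor) used by the DataLoaders.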
def __init__(self, path, pad_token="<pad>", max_seq_size=256):
self.path = Path(path)
tokenizer_path = Path(path, "tokenizer.json")
self.tokenizer = Tokenizer.from_file(str(tokenizer_path))
self.pad_token = pad_token
self.pad_id = self.tokenizer.token_to_id(pad_token)
self.tokenizer.enable_padding(
pad_id=self.pad_id, pad_token=pad_token, length=max_seq_size)
# assumption: truncate overlong lines so the fixed-size tensors built in to_tensor always fit
self.tokenizer.enable_truncation(max_length=max_seq_size)
self.max_seq_size = max_seq_size
self.train_dataset_path = Path(path, "datasets", "train.txt.gz")
self.dev_dataset_path = Path(path, "datasets", "dev.txt.gz")
self.test_dataset_path = Path(path, "datasets", "test.txt.gz")
self.train_dataset = LineDataset(self.train_dataset_path)
self.dev_dataset = LineDataset(self.dev_dataset_path)
self.test_dataset = LineDataset(self.test_dataset_path)
def iter_train(self):
return self.train_dataset
def iter_dev(self):
return self.dev_dataset
def iter_test(self):
return self.test_dataset
def encode(self, sentence):
return self.tokenizer.encode(sentence)
def decode(self, token_ids, *args, **kwargs):
return self.tokenizer.decode(token_ids, *args, **kwargs)
def to_tensor(self, sentences):
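# Collate raw sentences into three tensors: token ids (padded with pad_id), the attention mask, and per-token word ids (-1 where no word id is available, e.g. padding).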
encoded = self.tokenizer.encode_batch(sentences)
sequence_tensor = torch.full(
(len(encoded), self.max_seq_size), self.pad_id, dtype=torch.long)
attention_mask_tensor = torch.zeros(
(len(encoded), self.max_seq_size), dtype=torch.float)
word_ids_tensor = torch.full(
(len(encoded), self.max_seq_size), -1, dtype=torch.long)
for i, enc in enumerate(encoded):
sequence_tensor[i, :] = torch.tensor(enc.ids)
attention_mask_tensor[i, :] = torch.tensor(enc.attention_mask)
for j, wid in enumerate(enc.word_ids):
if wid is None:
break
word_ids_tensor[i, j] = wid
return sequence_tensor, attention_mask_tensor, word_ids_tensor
if __name__ == "__main__":
from itertools import islice
datasets = TrainData("output/oscar-100")
tokenizer = datasets.tokenizer
train_loader = DataLoader(datasets.iter_train(),
collate_fn=datasets.to_tensor)
dev_loader = DataLoader(datasets.iter_dev(),
collate_fn=datasets.to_tensor)
test_loader = DataLoader(datasets.iter_test(),
collate_fn=datasets.to_tensor)
from masking import TensorMasker
mask_id = tokenizer.token_to_id("<mask>")
min_word_id = 4
max_word_id = tokenizer.get_vocab_size() - 1
tm = TensorMasker(mask_id, min_word_id, max_word_id)
for (x, attention_mask, word_ids) in islice(dev_loader, 3):
masked, masking_mask = tm(x, attention_mask, word_ids)
print(datasets.decode(x.flatten().tolist()))
print(datasets.decode(masked.flatten().tolist(), skip_special_tokens=True))
print(masking_mask)
from pathlib import Path
from tokenizers import Tokenizer
import tokenizers
from load import *
from evaluation import *
if __name__ == "__main__":
import argparse
import sys
sys.path.append("C:/Users/gaeta/Documents/Recherches/software/minibert")
from minibert import *
parser = argparse.ArgumentParser()
parser.add_argument("model")
parser.add_argument("datadir")
parser.add_argument("-m", "--mask", default="<mask>")
parser.add_argument("-c", "--checkpoint", action="store_true")
args = parser.parse_args()
if args.checkpoint:
checkpoint = torch.load(args.model)
configuration_dict = checkpoint["configuration"]
device = checkpoint["device"]
configuration = MiniBertForMLMConfiguration(**configuration_dict)
prev_epoch = checkpoint["epoch"]
model = MiniBertForMLM(configuration).to(device)
model.load_state_dict(checkpoint["model_state_dict"])
else:
raise NotImplementedError()
model.eval()
data = TrainData(args.datadir)
tokenizer = data.tokenizer
while True:
#sentence = input("Sentence: ").strip()
sentence = "je avoir vouloir reprendre je me être aperçu cone que je avoir supprimer mail contenir premier épisode"
spl = sentence.split()
sentence_is_masked = args.mask in sentence
if sentence_is_masked:
mask_position = 0
for i, x in enumerate(spl):
if x == args.mask:
mask_position = i
break
sentences = [sentence]
else:
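# No <mask> in the input: build one candidate sentence per position, each with that single word replaced by <mask>.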
sentences = [" ".join(spl[:i] + ["<mask>"] + spl[(i+1):])
for i in range(len(spl))]
x, _, _ = data.to_tensor(sentences)
output = model(x.to(device))
for i, y in enumerate(output):
if not sentence_is_masked:
mask_position = i
j = torch.argmax(y[mask_position, :]).item()
# print(y.size())
# print(i)
found = tokenizer.decode([j])
print(sentences[i])
print(f"-> {found} | {j}")
break
# je avoir vouloir reprendre je me être aperçu cone que je avoir supprimer mail contenir premier épisode
import torch
# def is_subword_token(token, prefix):
# return token.startswith(prefix)
# def find_first_word_position(tokens, pos, prefix):
# for i, tok in enumerate(tokens[pos:0:-1]):
# if not is_subword_token(tok, prefix):
# return pos - i
# return 0 # Should be unreachable
# def find_last_word_position(tokens, pos, prefix):
# for i, tok in enumerate(tokens[pos+1:]):
# if not is_subword_token(tok, prefix):
# return pos + i
# return len(tokens) - 1
# def mask_positions(tokens, positions, prefix, mask_token):
# for i in positions:
# # Skip if already masked
# if tokens[i] == mask_token:
# continue
# fst = find_first_word_position(tokens, i, prefix)
# lst = find_last_word_position(tokens, i, prefix)
# for j in range(fst, lst + 1):
# tokens[j] = mask_token
# def mask_sentence(sentence, tokenizer):
# pass
class TensorMasker:
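# BERT-style masking on already-collated tensors.
# masking_ratio: fraction of positions (or whole words) selected for prediction;
# of the selected tokens, ~masking_probability are replaced by mask_id,
# ~corruption_probability by a random id in [min_word_id, max_word_id),
# and the remainder are left unchanged.
# mask_whole_words switches between whole-word and per-token selection.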
def __init__(self, mask_id, min_word_id, max_word_id, masking_ratio=0.15, masking_probability=0.8, corruption_probability=0.1, mask_whole_words=True):
self.mask_id = mask_id
self.masking_ratio = masking_ratio
self.masking_probability = masking_probability
self.corruption_probability = corruption_probability
self.mask_whole_words = mask_whole_words
self.min_word_id = min_word_id
self.max_word_id = max_word_id
def __call__(self, sequence_tensor, attention_mask, word_ids=None):
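# Mask each sequence in the batch, restricted to non-padding positions (attention_mask > 0); word_ids is only needed when mask_whole_words is True. Returns the masked copy and a boolean mask of the selected positions.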
masking_mask = torch.zeros_like(sequence_tensor, dtype=torch.bool)
masked = sequence_tensor.clone().detach()
for i in range(len(masked)):
cols = attention_mask[i, :] > 0
if self.mask_whole_words:
masked[i, cols], masking_mask[i, cols] = self.mask_words(
masked[i, cols], word_ids[i, cols])
else:
masked[i, cols], masking_mask[i,
cols] = self.mask_tokens(masked[i, cols])
return masked, masking_mask
def mask_words(self, tokens, word_ids):
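# Whole-word masking: draw word ids with probability masking_ratio, select every token belonging to a drawn word, then replace each selected token independently (mask / random id / keep).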
unique_word_ids = torch.unique(word_ids)
words_to_mask = torch.rand_like(
unique_word_ids, dtype=torch.float) <= self.masking_ratio
words_to_mask = unique_word_ids[words_to_mask]
i = torch.tensor([wid in words_to_mask for wid in word_ids])
masking_mask = torch.rand_like(i, dtype=torch.float)
i_masked = i & (masking_mask <= self.masking_probability)
i_corrupted = i & ((1.0 - self.corruption_probability) <= masking_mask)
tokens[i_masked] = self.mask_id
tokens[i_corrupted] = torch.randint(
self.min_word_id, self.max_word_id, (torch.sum(i_corrupted),))
return tokens, i
def mask_tokens(self, tokens):
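# Per-token masking: select positions independently with probability masking_ratio, then apply the same mask / random id / keep treatment.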
i = torch.rand_like(
tokens, dtype=torch.float) <= self.masking_ratio
masking_mask = torch.rand_like(i, dtype=torch.float)
i_masked = i & (masking_mask <= self.masking_probability)
i_corrupted = i & ((1.0 - self.corruption_probability) <= masking_mask)
tokens[i_masked] = self.mask_id
tokens[i_corrupted] = torch.randint(
self.min_word_id, self.max_word_id, (torch.sum(i_corrupted),))
return tokens, i
# def mask_raw_text(self, sentence):
# if self.tokenizer is None:
# raise Exception("A tokenizer is required to mask raw text.")
# x = self.tokenizer.encode(sentence)
# return self.mask_tokenized(x)
# def mask_positions_(self, tokens, positions):
# for i in positions:
# # Skip if already masked
# if tokens[i] == self.mask_token:
# continue
# fst = find_first_word_position(tokens, i, self.subword_prefix)
# lst = find_last_word_position(tokens, i, self.subword_prefix)
# for j in range(fst, lst + 1):
# tokens[j] = self.mask_token
# def mask_positions(self, tokens, positions):
# copy = tokens.copy()
# self.mask_positions_(copy, positions)
# return copy
# def find_first_word_position(self, tokens, pos):
# for i, tok in enumerate(tokens[pos:0:-1]):
# if not self.is_subword_token(tok):
# return pos - i
# return 0 # Should be unreachable
# def find_last_word_position(self, tokens, pos):
# for i, tok in enumerate(tokens[pos+1:]):
# if not self.is_subword_token(tok):
# return pos + i
# return len(tokens) - 1
# def is_subword_token(self, token):
# return token.startswith(self.subword_prefix)
# if __name__ == "__main__":
# sm = SentenceMasker()
# tokens = ["mini", "__be", "__r", "__t", "ça", "craint", "gra", "__ve"]
# sm.mask_positions(tokens, [1])
# print(tokens)
@@ -54,8 +54,12 @@ if __name__ == "__main__":
docs = nlp.pipe(file_iterator, n_process=args.nprocess)
with gzip.open(args.output, "wt", encoding="UTF-8") as outfile:
outfile.writelines(
f"{_doc2str(d)}\n" for d in tqdm(docs, total=args.n))
i = 0
for i, d in enumerate(tqdm(docs, total=args.n), 1):
outfile.write(_doc2str(d))
outfile.write("\n")
print(f"I lemmatized {i} documents.")
return i
def _split(args):
train_ratio, dev_ratio, test_ratio = args.train, args.dev, args.test
@@ -103,7 +107,9 @@ if __name__ == "__main__":
def _train_tokenizer(args):
special_tokens = ["<mask>", "<pad>", "<unk>", "<sep>"]
tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
bpe = models.BPE(unk_token="<unk>")
tokenizer = Tokenizer(bpe)
tokenizer.normalizer = normalizers.Sequence([
NFD(),
StripAccents(),
@@ -124,6 +130,29 @@ if __name__ == "__main__":
tokenizer.train_from_iterator(infile, trainer=trainer)
tokenizer.save(args.output, pretty=True)
def _full_pipeline(args):
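# Run the whole preprocessing chain into one output directory: lemmatize the corpus, split it into train/dev/test, then train the tokenizer on the train split.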
outdir = Path(args.output).expanduser()
outdir.mkdir(exist_ok=args.force)
lemmatized_path = str(Path(outdir, "lemmatized.txt.gz"))
datasets_dir_path = Path(outdir, "datasets")
datasets_dir_path.mkdir(exist_ok=args.force)
tokenizer_path = str(Path(outdir, "tokenizer.json"))
args.output = lemmatized_path
n = _lemmatize(args)
args.input = lemmatized_path
args.output = datasets_dir_path
args.n = n
_split(args)
args.input = Path(datasets_dir_path, "train.txt.gz")
args.output = tokenizer_path
_train_tokenizer(args)
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
@@ -149,7 +178,23 @@ if __name__ == "__main__":
train_tokenizer_parser.add_argument("output")
train_tokenizer_parser.add_argument("--vocab", default=20000, type=int)
train_tokenizer_parser.add_argument("--minfreq", default=1, type=int)
train_tokenizer_parser.add_argument(
"-s", "--subwords-prefix", default="__")
train_tokenizer_parser.add_argument("-e", "--end-suffix", default="")
train_tokenizer_parser.set_defaults(func=_train_tokenizer)
pipeline_parser = subparsers.add_parser("full-pipeline")
pipeline_parser.add_argument("input")
pipeline_parser.add_argument("output")
pipeline_parser.add_argument("-n", required=False, type=int)
pipeline_parser.add_argument("--nprocess", default=1, type=int)
pipeline_parser.add_argument("-t", "--train", default=0.8)
pipeline_parser.add_argument("-d", "--dev", default=0.1)
pipeline_parser.add_argument("-T", "--test", default=0.1)
pipeline_parser.add_argument("-f", "--force", action="store_true")
pipeline_parser.add_argument("--vocab", default=20000, type=int)
pipeline_parser.add_argument("--minfreq", default=1, type=int)
pipeline_parser.set_defaults(func=_full_pipeline)
args = parser.parse_args()
args.func(args)
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter
from load import *
from evaluation import *
if __name__ == "__main__":
import argparse
import sys
sys.path.append("C:/Users/gaeta/Documents/Recherches/software/minibert")
from minibert import *
parser = argparse.ArgumentParser()
parser.add_argument("datadir")
parser.add_argument("-o", "--outdir", type=str)
parser.add_argument("-d", type=int, default=64)
parser.add_argument("--bs", type=int, default=128)
parser.add_argument("-e", "--epochs", type=int, default=500)
parser.add_argument("--attention", type=str, default="self-attention")
parser.add_argument("--position", type=str, default="fixed")
parser.add_argument("--dont-normalize", action="store_true")
parser.add_argument("--activation", type=str, default="gelu")
parser.add_argument("--device", type=str, default="cpu")
parser.add_argument("--logdir", type=str, required=False)
parser.add_argument("-c", "--checkpoint", type=str, required=False)
parser.add_argument("--init-w2v", action="store_true")
position_mapper = {
"none": PositionalEmbeddingType.NONE,
"fixed": PositionalEmbeddingType.FIXED,
"trained": PositionalEmbeddingType.TRAINED
}
attention_mapper = {
"self-attention": AttentionType.SelfAttention,
"non-transforming": AttentionType.NonTransformingAttention
}
args = parser.parse_args()
attention_type = attention_mapper[args.attention.lower()]
position_type = position_mapper[args.position.lower()]
device = args.device
pin_memory = device != "cpu"
datasets = TrainData(args.datadir)
tokenizer = datasets.tokenizer
train_loader = DataLoader(datasets.iter_train(
), collate_fn=datasets.to_tensor, batch_size=args.bs, pin_memory=pin_memory)
dev_loader = DataLoader(datasets.iter_dev(),
collate_fn=datasets.to_tensor, batch_size=args.bs, pin_memory=pin_memory)
test_loader = DataLoader(datasets.iter_test(
), collate_fn=datasets.to_tensor, batch_size=args.bs, pin_memory=pin_memory)
mask_token = "<mask>"
pad_token = "<pad>"
mask_idx = tokenizer.token_to_id(mask_token)
pad_idx = tokenizer.token_to_id(pad_token)
if args.checkpoint is None:
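# No checkpoint: build a fresh configuration from the CLI arguments; otherwise resume model, optimizer and epoch counter from the checkpoint file.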
configuration_dict = dict(
vocabulary=tokenizer.get_vocab(),
embedding_dim=args.d,
hidden_dim=args.d,
position_embeddings_count=datasets.max_seq_size,
position_type=position_type,
normalize_embeddings=not args.dont_normalize,
pad_idx=pad_idx,
pad_token=pad_token,
mask_idx=mask_idx,
mask_token=mask_token,
mask_prob=0.15,
keep_mask_prob=0.8,
corrupt_mask_prob=0.1,
reveal_mask_prob=0.1,
first_layer_output_size=args.d,
activation_fun=args.activation,
key_is_query=False,
attention_type=attention_type,
min_word_id=4,
mask_whole_words=True
)
configuration = MiniBertForMLMConfiguration(**configuration_dict)
model = MiniBertForMLM(configuration).to(device)
optimizer = torch.optim.Adam(model.parameters())
prev_epoch = 0
else:
checkpoint = torch.load(args.checkpoint)
configuration_dict = checkpoint["configuration"]
device = checkpoint["device"]
configuration = MiniBertForMLMConfiguration(**configuration_dict)
prev_epoch = checkpoint["epoch"]
model = MiniBertForMLM(configuration).to(device)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer = torch.optim.Adam(model.parameters())
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
run_name = "_".join([
f"d{args.d}",
args.attention,
args.position,
args.activation,
"nonorm" if args.dont_normalize else "norm"
])
if args.logdir is None:
writer = SummaryWriter(log_dir=f"runs/{run_name}")
else:
writer = SummaryWriter(log_dir=args.logdir)
writer.add_text("Command", " ".join(sys.argv))
outdir = Path(args.outdir, run_name)
if args.checkpoint is None:
outdir.mkdir()
for epoch in range(prev_epoch + 1, prev_epoch + 1 + args.epochs):
model.train()
cumloss = 0
n_train = 0
for batch_id, (x, attention_mask, wids) in enumerate(train_loader):
optimizer.zero_grad()
output, loss = model(x.to(device), attention_mask.to(device), wids.to(device))
loss.backward()
optimizer.step()
cumloss += loss.item()
writer.add_scalar("Loss/train", cumloss / len(train_loader), epoch)
print(f"EPOCH {epoch:04} - Loss: {cumloss / len(train_loader)}")
if epoch % 10 == 0:
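# Every 10 epochs, report top-k masked-token accuracy (k=5) on the train and dev sets.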
model.eval()
k = 5
train_precision, train_tp = eval_model_k(
model, train_loader, mask_idx, pad_idx, device, k=k)
dev_precision, dev_tp = eval_model_k(
model, dev_loader, mask_idx, pad_idx, device, k=k)