Commit 228248f4 authored by Gaëtan Caillaut's avatar Gaëtan Caillaut
Browse files

wsd avec MiniBERT

parent d9128d02
from pathlib import Path
import math
import shutil
import json
import random
import torch
from import DataLoader, IterableDataset, Dataset
from tokenizers import Tokenizer
class LineDataset(IterableDataset):
def __init__(self, path, lines_count=None):
self.path = str(path)
self.lines_count = lines_count
self.vocabulary = None
def __len__(self):
if self.lines_count is None:
with open(self.path, "rt", encoding="UTF-8") as f:
i = 0
for i, _ in enumerate(f, 1):
self.lines_count = i
return self.lines_count
def __iter__(self):
with open(self.path, "rt", encoding="UTF-8") as f:
for l in f:
yield l
def get_vocabulary(self):
if self.vocabulary is None:
self.vocabulary = set()
for x in self:
return sorted(self.vocabulary)
class JsonDataset(Dataset):
def __init__(self, path):
self.path = str(path)
self.vocabulary = None
with open(self.path, 'r') as f:
json_data =
self.json = json.loads(json_data)
def __getitem__(self, i):
return self.json[i]["phrase_lem"]
def __len__(self):
return len(self.json)
class SncfDataset(IterableDataset):
def __init__(self, path, size, stride=None):
self.path = Path(path).expanduser()
self.files = list(self.path.glob("**/*.txt"))
self.vocabulary = None
self.size = size
self.stride = stride or size
self._data = None
def __iter__(self):
if self._data is None:
self._data = []
for fp in self.files:
with open(str(fp), "rt", encoding="UTF-8") as f:
for l in f:
spl = l.strip().split()
for i in range(0, len(spl), self.stride):
self._data.append(" ".join(spl[i:i + self.size]))
return iter(self._data)
def split(self, output_dir, train=0.6, dev=0.2, test=0.2):
assert train + dev + test == 1
n = len(self.files)
last_train = math.floor(train * n)
last_dev = last_train + math.floor(dev * n)
shuffled_files = self.files.copy()
# train_files, dev_files, test_files
splitted_files = shuffled_files[:last_train], shuffled_files[last_train:last_dev], shuffled_files[last_dev:]
output_dir = Path(output_dir).expanduser()
dirs = (
Path(output_dir, "train"),
Path(output_dir, "dev"),
Path(output_dir, "test")
for d, files in zip(dirs, splitted_files):
for f in files:
shutil.copy(f, d)
return tuple(SncfDataset(d) for d in dirs)
def get_vocabulary(self):
if self.vocabulary is None:
self.vocabulary = set()
for x in self:
return sorted(self.vocabulary)
def read_splitted_datasets(input_dir, *args, **kwargs):
paths = (
Path(input_dir, x)
for x in ("train", "dev", "test")
return tuple(SncfDataset(p, *args, **kwargs) for p in paths)
def split_sentence(sentence, size, stride=None):
if stride is None:
stride = size
words = sentence.split()
return [" ".join(words[i:i+size]) for i in range(0, len(words), stride)]
def split_sentence_batch(sentence, size, stride=None):
res = []
for s in sentence:
res.extend(split_sentence(s, size, stride=stride))
return res
class SncfCollater:
def __init__(self, tokenizer, pad_token="<pad>"):
self.tokenizer = tokenizer
self.pad_id = self.tokenizer.token_to_id(pad_token)
self.tokenizer.enable_padding(pad_id=self.pad_id, pad_token=pad_token)
def __call__(self, sentences):
encoded = self.tokenizer.encode_batch(sentences)
n = max(len(x.ids) for x in encoded)
shape = (len(encoded), n)
sequence_tensor = torch.tensor(
[x.ids for x in encoded], dtype=torch.long)
attention_mask_tensor = torch.tensor(
[x.attention_mask for x in encoded], dtype=torch.float)
word_ids_tensor = torch.tensor(
[[-1 if id is None else id for id in x.word_ids] for x in encoded], dtype=torch.long)
# sequence_tensor = torch.full(shape, self.pad_id, dtype=torch.long)
# attention_mask_tensor = torch.zeros(shape, dtype=torch.float)
# word_ids_tensor = torch.full(shape, -1, dtype=torch.long)
# for i, x in enumerate(encoded):
# x_len = len(x.ids)
# sequence_tensor[i, :x_len] = torch.tensor(x.ids[:x_len])
# attention_mask_tensor[i, :x_len] = torch.tensor(
# x.attention_mask[:x_len])
# for j, wid in enumerate(x.word_ids[:x_len]):
# if wid is None:
# break
# word_ids_tensor[i, j] = wid
# assert torch.equal(sequence_tensor, sequence_tensor2)
# assert torch.equal(attention_mask_tensor, attention_mask_tensor2)
# assert torch.equal(word_ids_tensor, word_ids_tensor2)
return sequence_tensor, attention_mask_tensor, word_ids_tensor
class AcronymsCollater:
def __init__(self, tokenizer, pad_token="<pad>"):
self.sncf_collater = SncfCollater(tokenizer, pad_token=pad_token)
def __call__(self, sentences):
tokenizer = self.sncf_collater.tokenizer
sequence_tensor, attention_mask_tensor, word_ids_tensor = self.sncf_collater(
acr_data = []
for i, seqt in enumerate(sequence_tensor):
acr_for_i = []
for ii, x in enumerate(seqt):
x = x.item()
tok = tokenizer.id_to_token(x)
if tok.startswith("<acr::"):
acr_for_i.append((tok, x, ii))
return sequence_tensor, attention_mask_tensor, word_ids_tensor, acr_data, sentences
import torch
import os
import json
import pandas as pd
import seaborn
import matplotlib.pyplot as plt
from sklearn.metrics import pairwise
from tokenizers import Tokenizer
from import DataLoader
from operator import is_not
from functools import partial
from collections import defaultdict
from datasets import *
from minibert import *
import sys
from minibert import *
# Retourne les embeddings de definitions
# {
# acr_token: {
# "vectors": matrice, une ligne = un embedding,
# "definitions": liste de chaines de caractères
# "lemmatized": liste de chaines de caractères
# }
# }
def get_acr_def_dict(model, tokenizer, acronyms, device="cuda"):
definitions_file = "output/minibert/definitions.tar"
if os.path.exists(definitions_file):
definitions = torch.load(definitions_file)
definitions = {}
for acr, lem_definition, definition in zip(acronyms["Token"], acronyms["Lemmatized"], acronyms["Definition"]):
encoded = tokenizer.encode(lem_definition)
x = torch.tensor([encoded.ids], device=device)
attention_mask = torch.tensor([encoded.attention_mask], device=device)
# wids = torch.tensor([encoded.word_ids], device=device)
output = model.minibert(x, attention_mask)
mean_vec = torch.mean(output, dim=1)
if acr not in definitions:
definitions[acr] = { "vectors": [], "definitions": [], "lemmatized": [] }
for acr in definitions.keys():
definitions[acr]["vectors"] = torch.tensor(definitions[acr]["vectors"], device=device), definitions_file)
return definitions
def minibert_wsd(args):
device = "cuda"
pin_memory = device != "cpu"
mask_token = "<mask>"
pad_token = "<pad>"
max_seq_size = 32
seq_stride = 32
bs = 32
acronyms = pd.read_csv("data/sens.csv")
tokenizer = Tokenizer.from_file("../minibert-sncf/data/tokenizer.json")
collater = SncfCollater(tokenizer, pad_token)
checkpoint = torch.load(args.model, map_location=torch.device(device))
configuration_dict = checkpoint["configuration"]
device = checkpoint["device"]
configuration = MiniBertForMLMConfiguration(**configuration_dict)
model = MiniBertForMLM(configuration).to(device)
definitions = get_acr_def_dict(model, tokenizer, acronyms, device)
json_path = "data/annotation.json"
with open(json_path, "r", encoding="UTF-8") as f:
json_data =
annotated = json.loads(json_data)
for isent, sent in enumerate(annotated):
x, attention_mask, wid = collater([sent["phrase_lem"]])
x =
attention_mask =
embeddings = model.minibert(x, attention_mask)
for iacr, tok in enumerate(sent["acronymes"]):
i = tok["position"]
acr = tok["token"]
v = embeddings[0, i, :].view(-1, 1)
dists = torch.matmul(definitions[acr]["vectors"], v).view(-1)
idef = torch.argmax(dists).item()
annotated[isent]["acronymes"][iacr]["prediction"] = definitions[acr]["lemmatized"][idef]
predictions_path = "output/minibert/predictions.json"
with open(predictions_path, "w", encoding="UTF-8") as f:
f.write(json.dumps(annotated, indent=4, ensure_ascii=False))
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
wsd_parser = subparsers.add_parser("wsd")
wsd_parser.add_argument("-m", "--model", default="../minibert-sncf/models/d64_self-attention_fixed_gelu_norm/checkpoint-00100.tar")
args = parser.parse_args()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment