Commit 53eb4577 authored by Gaëtan Caillaut's avatar Gaëtan Caillaut
Browse files

Squelette pour l'entrainement

parent 41f51b95
from .corpus import *
import sys
import torch
import os
from corpus import *
from minibert import *
from minibert import *
def build_batches(seqs, bs=5):
seqs = sorted(seqs, key=len)
res = []
b = []
prev_len = len(seqs[0])
i = 0
for x in seqs:
if len(x) != prev_len or i >= bs:
prev_len = len(x)
b = []
i = 0
i = i + 1
return res
def build_tensor_batches(batches, voc2idx):
res = []
for b in batches:
tensor_batch = torch.tensor([
[voc2idx[x] for x in sent] for sent in b
], dtype=torch.long)
return res
if __name__ == "__main__":
src_dir = os.path.dirname(os.path.realpath(__file__))
crps = Corpus(os.path.join(src_dir, "trial_corpus.xml"))
voc = list(crps.compute_vocabulary())
voc2idx = {x: i for i, x in enumerate(voc)}
tokenized = [sent.split() for sent in crps]
batches = build_batches(tokenized)
train_tensors = build_tensor_batches(batches, voc2idx)
emb_dim = 50
voc_size = len(voc)
model = MiniBert(emb_dim, voc_size)
for epoch in range(10):
for x in train_tensors:
output = model(x)
import os
from xml.etree import ElementTree
__all__ = [
class Corpus:
def __init__(self, path):
self.path = path
def __iter__(self):
tree = ElementTree.parse(self.path)
root = tree.getroot()
for sentence in root.iter("sentence"):
yield sentence.attrib.get("s", "")
def compute_vocabulary(self, tokenizer=str.split):
res = set()
for s in self:
return res
This diff is collapsed.
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment