Commit 1c5204d7 authored by Gaëtan Caillaut

code for word2vec

parent dbbe6de1
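# Train (or load) a Word2Vec model on the SNCF corpus, then export the nearest
# neighbours of each acronym and lexicon entry to CSV files for manual evaluation.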
import os
from pathlib import Path
from gensim.models import Word2Vec, KeyedVectors
from datasets import *
import pandas as pd
import numpy as np
from collections import Counter
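# Split the neighbours table into one shuffled CSV per theme, keeping only the
# rows whose token belongs to that theme.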
def split_neighbors(df, outdir):
themes = {
"PS": [
"<acr::vul>",
"<acr::cgc>",
"<acr::sami>",
"<lex::détaxe>",
"<acr::cnil>",
"<acr::cnam>",
"<lex::accréditation>",
"<acr::mif>",
"<lex::dédire>",
"<acr::cre>",
"<lex::camionnage>",
"<acr::cgc>",
"<acr::cram>",
"<lex::pièce_de_rechange>",
"<lex::organisme_habiliter>",
"<acr::ild>",
"<acr::cfas>",
"<acr::rpp>",
"<lex::amarrage>",
"<acr::cre>",
"<acr::cfas>",
"<acr::cnam>",
"<acr::dom>",
"<acr::agff>",
"<acr::sami>",
"<lex::réseau_principal>",
"<acr::dom>",
"<acr::cesf>",
"<acr::sncs>",
"<acr::cborm>",
"<acr::mif>",
"<acr::rpp>",
"<lex::payeur>",
],
"S": [
"<lex::wagon_isoler>",
"<lex::sncf_réseau>",
"<lex::freinage_de_service>",
"<acr::cri>",
"<acr::vic>",
"<lex::contre_sens|sen_et_contre_voie>",
"<lex::temps_de_attente>",
"<acr::bdr>",
"<acr::ajecta>",
"<acr::jama>",
"<acr::ago>",
"<acr::sprc>",
"<acr::geq>",
"<acr::cristal>",
"<acr::bdr>",
"<acr::cri>",
"<acr::gco>",
"<lex::en_amont_en_aval>",
"<acr::esf>",
"<acr::dera>",
"<acr::ago>",
"<acr::epa>",
"<acr::sprc>",
"<lex::consigne_commun_temporaire_travail>",
"<lex::alarmer_danger>",
"<acr::ccp>",
"<acr::tbl>",
"<lex::marche_au_pas>",
"<acr::sre>",
"<acr::epa>",
"<acr::esf>",
"<lex::alarme_simple>",
"<acr::eoqa>",
"<lex::zone_fret>",
"<acr::vic>",
],
"TroncCommun": [
"<acr::gpf>",
"<lex::automoteur>",
"<lex::embrancher>",
"<lex::expéditeur>",
"<acr::est>",
"<acr::sud>",
"<lex::mécanicien>",
"<acr::pcd>",
"<lex::repérer>",
"<acr::epsf>",
"<lex::chargement>",
"<acr::dbc>",
"<lex::agent_sédentaire>",
"<lex::calage>",
"<lex::escale>",
"<acr::rff>",
"<lex::automoteur>",
"<acr::eic>",
"<acr::rff>",
"<acr::tcc>",
"<acr::pgt>",
"<lex::obstacle>",
"<lex::tâche_de_sécurité>",
"<acr::pgt>",
"<acr::prg>",
"<lex::gare_origine>",
"<acr::chsct>",
"<lex::attestation>",
"<lex::agent_formation>",
"<lex::zone_dangereux>",
"<lex::wagon>",
"<acr::eev>",
"<acr::rst>",
"<acr::rst>",
"<lex::train_de_voyageur>",
"<acr::dcf>",
"<lex::cabine_de_conduite>",
"<lex::agent_du_train>",
"<acr::rfn>",
"<lex::régularité>",
"<lex::opérateur>",
"<lex::agent_de_accompagnement>",
"<acr::suivi>",
"<lex::canton>",
"<lex::installation_terminal_embrancher>",
"<acr::prg>",
"<acr::cle>",
"<acr::rfn>",
"<acr::cosec>",
"<acr::eex>",
"<acr::pcd>",
"<acr::gpf>",
"<acr::kvb>",
"<acr::tvm>",
"<acr::prg>",
"<lex::qualité>",
"<acr::tvm>",
"<acr::eic>",
"<acr::ccl>",
"<acr::pmv>",
"<lex::norme>",
"<acr::eex>",
"<lex::arrière>",
"<acr::cle>",
"<acr::datzd>",
"<acr::ccl>",
"<acr::rst>",
"<acr::lgv>",
"<lex::voie_banaliser>",
"<acr::eev>",
"<acr::est>",
"<acr::gsm>",
"<lex::locotracteur>",
"<lex::stationnement>",
"<acr::tcc>",
"<lex::signalisation>",
"<lex::triage>",
"<acr::gid>",
"<lex::dysfonctionnement>",
"<acr::pgt>",
"<acr::sncf>",
"<acr::dcf>",
"<acr::ccl>",
"<lex::agent-circulation>",
"<acr::sud>",
"<lex::embranchement_particulier>",
"<lex::procédure>",
"<lex::locomotif>",
"<lex::convoi>",
"<lex::double_voie>",
"<acr::pmv>",
"<lex::mise_en_marche>",
"<acr::tbl>",
"<lex::standard>",
"<acr::sgtc>",
"<acr::fer>",
]
}
for th, x in themes.items():
subdf = df[df["token"].isin(x)]
outpath = str(Path(outdir, f"{th}_word2vec.csv"))
subdf.sample(frac=1).to_csv(outpath, index=False)
if __name__ == "__main__":
max_seq_size = 256
seq_stride = 256
wv_path = "output/w2v/vectors.w2v"
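    # Reuse previously trained vectors when available, otherwise train Word2Vec on
    # the SNCF corpus. Note: `size=` and `wv.vocab` below follow the gensim 3.x API
    # (gensim >= 4 renamed them to `vector_size=` and `wv.key_to_index`).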
if os.path.exists(wv_path):
wv = KeyedVectors.load(wv_path)
else:
dataset = SncfDataset("../minibert-sncf/data", max_seq_size, seq_stride, split_sentences=True)
model = Word2Vec(sentences=dataset, size=32, window=5, min_count=8, workers=12)
model.wv.save(wv_path)
wv = model.wv
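    # Evaluation data: acronym and lexicon entries with their raw and lemmatized definitions.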
acronyms = pd.read_csv("data/acr_eval.csv")
acronyms.dropna(inplace=True)
lexicon = pd.read_csv("data/lex_eval.csv")
lexicon.dropna(inplace=True)
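    # Number of definitions listed for each acronym; acronyms with several rows are ambiguous.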
counter = Counter(acronyms["Token"])
    # Build one vector per (acronym, definition) pair: if the acronym has a single
    # definition, use the embedding of the acronym token itself (assumed to be in
    # the vocabulary); otherwise average the embeddings of its definition words.
definitions = {}
for acr, lem_definition, definition in zip(acronyms["Token"], acronyms["Lemmatized"], acronyms["Definition"]):
words = [lem for lem in lem_definition.split() if lem in wv.vocab]
if len(words) == 0:
continue
if acr not in definitions:
definitions[acr] = { "vectors": [], "definitions": [], "lemmatized": [], "words_in_def": [] }
if counter[acr] == 1:
definitions[acr]["vectors"].append(wv[acr])
definitions[acr]["definitions"].append(definition)
definitions[acr]["lemmatized"].append(lem_definition)
definitions[acr]["words_in_def"].append(words)
else:
x = np.stack([wv[w] for w in words], axis=0)
mean_vec = np.mean(x, axis=0)
definitions[acr]["vectors"].append(mean_vec)
definitions[acr]["definitions"].append(definition)
definitions[acr]["lemmatized"].append(lem_definition)
definitions[acr]["words_in_def"].append(words)
cols = {
"type": [],
"word": [],
"token": [],
"definitions": [],
"0": [],
"1": [],
"2": [],
"3": [],
}
    # Acronyms: for each (acronym, definition) vector, keep the four most similar
    # words, skipping ignored tokens, the acronym itself, words already in the
    # definition, compound tokens and duplicate forms of the same acronym.
def acr_tok_to_word(s):
import re
s = re.sub(r"^<acr::", "", s)
s = re.sub(r">$", "", s)
return s.upper()
def is_already_in_neighbors(n, s):
n2 = [acr_tok_to_word(x) for x in n]
return acr_tok_to_word(s) in n2
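    # Words never reported as neighbours: auxiliary verbs, special tokens and very short strings.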
ignored = [w for w in wv.vocab if w in ("avoir", "etre", "<unk>", "<pad>", "<mask>") or len(w) < 3]
for acr, defs in definitions.items():
for i in range(len(defs["definitions"])):
cols["type"].append("acr")
cols["word"].append(acr_tok_to_word(acr))
cols["token"].append(acr)
cols["definitions"].append(defs["definitions"][i])
sims = wv.similar_by_vector(defs["vectors"][i], topn=30)
neighs = []
for w, _ in sims:
if w in ignored or w == acr or w in defs["words_in_def"][i] or "*" in w or "_" in w or is_already_in_neighbors(neighs, w):
continue
else:
neighs.append(w)
if len(neighs) == 4:
break
cols["0"].append(neighs[0])
cols["1"].append(neighs[1])
cols["2"].append(neighs[2])
cols["3"].append(neighs[3])
# lexicon
def to_lex_token(s):
if s.startswith("<lex::"):
return s
else:
return f"<lex::{s}>"
lex_tokens = [tok if len(lem.split()) > 1 else lem for (tok, lem) in zip(lexicon["Token"], lexicon["Lemmatized"])]
for k, term, lexdef, lemm in zip(lex_tokens, lexicon["Term"], lexicon["Definition"], lexicon["Lemmatized"]):
if k not in wv.vocab:
continue
v = wv[k]
sims = wv.similar_by_vector(v, topn=30)
neighs = []
for w, _ in sims:
if w in ignored or w == k:
continue
else:
neighs.append(w)
if len(neighs) == 4:
break
cols["type"].append("lex")
cols["word"].append(term)
cols["token"].append(to_lex_token(k))
cols["definitions"].append(lexdef)
cols["0"].append(neighs[0])
cols["1"].append(neighs[1])
cols["2"].append(neighs[2])
cols["3"].append(neighs[3])
csv_path = "output/w2v/neighbors.csv"
df = pd.DataFrame(cols)
    df = df.sort_values("token")
df.to_csv(csv_path, index=False)
split_neighbors(df, "output/w2v")