Commit 9e88de8c authored by Gaëtan Caillaut
Browse files

script to generate train/test corpora

parent ca8d895c
from train_semeval import *
import argparse
import torch
from collections import Counter
from pathlib import Path
if __name__ == "__main__":
    # Split a corpus into a train set and a test set, then print both or
    # write them to <output>/train.txt and <output>/test.txt.
    #
    # A document is moved to the test set only if, after removing it, every
    # one of its tokens still has a strictly positive remaining count, i.e.
    # no token of a test document disappears from the train set.
    parser = argparse.ArgumentParser()
    parser.add_argument("corpus", help="Input corpus")
    parser.add_argument("-s", "--simplify", action="store_true")
    parser.add_argument(
        "-o", "--output", help="Output directory", required=False)
    args = parser.parse_args()

    # load_corpus comes from train_semeval (star import).
    # NOTE(review): presumably dtm() is a (documents x vocabulary) count
    # matrix whose columns follow the order of vocabulary() -- confirm.
    crps = load_corpus(args.corpus, args.simplify)
    dtm = crps.dtm()
    voc = crps.vocabulary()
    voc2idx = {x: i for i, x in enumerate(voc)}

    # Running per-token counts of what is still left in the train set.
    # This tensor is updated in place each time a document is moved out.
    freqs = torch.sum(dtm, axis=0)

    # Greedy single pass in corpus order: earlier documents get priority
    # for the test set. A set keeps the membership test below O(1).
    test_idx = set()
    for i, doc in enumerate(crps):
        counter = Counter(doc.split())
        if all(freqs[voc2idx[t]] - c > 0 for t, c in counter.items()):
            test_idx.add(i)
            # Commit the removal so later documents see the reduced counts.
            for t, c in counter.items():
                freqs[voc2idx[t]] -= c

    train_crps = [doc for i, doc in enumerate(crps) if i not in test_idx]
    test_crps = [doc for i, doc in enumerate(crps) if i in test_idx]

    if args.output is None:
        print("TRAIN")
        print("-----")
        for x in train_crps:
            print(x)
        print("\n")
        print("TEST")
        print("----")
        for x in test_crps:
            print(x)
    else:
        path = Path(args.output)
        # parents=True so a nested output directory can be created directly.
        path.mkdir(parents=True, exist_ok=True)
        with open(str(Path(path, "train.txt")), "wt", encoding="UTF-8") as f:
            f.write("\n".join(train_crps))
        with open(str(Path(path, "test.txt")), "wt", encoding="UTF-8") as f:
            f.write("\n".join(test_crps))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment