-
Notifications
You must be signed in to change notification settings - Fork 54
/
Copy pathtask_text_classification_chnsenti.py
110 lines (96 loc) · 4.63 KB
/
task_text_classification_chnsenti.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import csv
from torchblocks.metrics import Accuracy
from torchblocks.callback import TrainLogger
from torchblocks.trainer import TextClassifierTrainer
from torchblocks.processor import TextClassifierProcessor, InputExample
from torchblocks.utils import seed_everything, dict_to_text, build_argparse
from torchblocks.utils import prepare_device, get_checkpoints
from transformers import BertTokenizer, WEIGHTS_NAME
from model.modeling_nezha import NeZhaForSequenceClassification
from model.configuration_nezha import NeZhaConfig
# Registry mapping the lower-cased ``--model_type`` CLI value to the
# (config class, model class, tokenizer class) triple used to load it.
MODEL_CLASSES = dict(
    nezha=(NeZhaConfig, NeZhaForSequenceClassification, BertTokenizer),
)
class ChnSentiProcessor(TextClassifierProcessor):
    """Data processor for the ChnSentiCorp binary sentiment dataset.

    Input files are tab-separated with a header row. Every following row is
    expected to hold the label in column 0 and the sentence in column 1.
    """

    def get_labels(self):
        """Return the label vocabulary: the strings "0" and "1"."""
        return ["0", "1"]

    def read_data(self, input_file):
        """Read a tab-separated file and return all of its rows.

        Args:
            input_file: Path to the TSV file.

        Returns:
            A list with one list of string fields per row, header row included.
        """
        # newline="" lets the csv module handle line endings itself, as its docs
        # recommend. quotechar=None turns quote handling off, so '"' characters
        # in the text are kept as-is. "utf-8-sig" strips a leading BOM if present.
        with open(input_file, "r", encoding="utf-8-sig", newline="") as f:
            reader = csv.reader(f, delimiter="\t", quotechar=None)
            return list(reader)

    def create_examples(self, lines, set_type):
        """Turn raw TSV rows into ``InputExample`` objects.

        Args:
            lines: Rows as returned by ``read_data``. Row 0 is treated as the
                header and skipped.
            set_type: Split name (e.g. "train", "dev"), used as the guid prefix.

        Returns:
            A list of ``InputExample`` with ``texts=[sentence, None]`` and the
            label taken from column 0.
        """
        examples = []
        for i, line in enumerate(lines):
            # Row 0 is the header. Empty rows (e.g. a trailing blank line) hold
            # no example and would otherwise raise IndexError below.
            if i == 0 or not line:
                continue
            guid = "{}-{}".format(set_type, i)
            label, text_a = line[0], line[1]
            # Single-sentence task, so the second text slot is left empty.
            examples.append(
                InputExample(guid=guid, texts=[text_a, None], label=label))
        return examples
def main():
    """Train and/or evaluate a sentence classifier on ChnSentiCorp.

    All settings come from the command line via ``build_argparse()``. Flags
    read here include ``model_name``, ``model_path``, ``model_type``,
    ``output_dir``, ``task_name``, ``gpu``, ``local_rank``, ``seed``,
    ``data_dir``, ``do_lower_case``, ``cache_dir``, ``do_train``, ``do_eval``,
    ``train_max_seq_length``, ``eval_max_seq_length``,
    ``eval_all_checkpoints`` and ``checkpoint_number``.

    Side effects: creates ``<output_dir>/<model_name>`` and logs there. When
    ``do_eval`` is set, it writes the metrics of every evaluated checkpoint to
    ``eval_results.txt`` in that directory.
    """
    args = build_argparse().parse_args()
    if args.model_name is None:
        # normpath strips a trailing separator, so "models/nezha/" gives
        # "nezha" instead of an empty string.
        args.model_name = os.path.basename(os.path.normpath(args.model_path))
    # Give each model its own sub-directory. os.path.join adds the separator
    # that plain concatenation dropped when output_dir had no trailing slash.
    args.output_dir = os.path.join(args.output_dir, args.model_name)
    os.makedirs(args.output_dir, exist_ok=True)

    # logging
    prefix = "_".join([args.model_name, args.task_name])
    logger = TrainLogger(log_dir=args.output_dir, prefix=prefix)

    # device
    logger.info("initializing device")
    args.device, args.n_gpu = prepare_device(args.gpu, args.local_rank)
    seed_everything(args.seed)
    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    # data processor
    logger.info("initializing data processor")
    tokenizer = tokenizer_class.from_pretrained(args.model_path, do_lower_case=args.do_lower_case)
    processor = ChnSentiProcessor(data_dir=args.data_dir, tokenizer=tokenizer, prefix=prefix)
    label_list = processor.get_labels()
    num_labels = len(label_list)
    args.num_labels = num_labels

    # model
    logger.info("initializing model and config")
    config = config_class.from_pretrained(args.model_path, num_labels=num_labels,
                                          cache_dir=args.cache_dir if args.cache_dir else None)
    model = model_class.from_pretrained(args.model_path, config=config)
    model.to(args.device)

    # trainer
    logger.info("initializing trainer")
    trainer = TextClassifierTrainer(logger=logger, args=args, collate_fn=processor.collate_fn,
                                    input_keys=processor.get_input_keys(),
                                    metrics=[Accuracy()])

    # do train
    if args.do_train:
        train_dataset = processor.create_dataset(args.train_max_seq_length, 'train.tsv', 'train')
        eval_dataset = processor.create_dataset(args.eval_max_seq_length, 'dev.tsv', 'dev')
        trainer.train(model, train_dataset=train_dataset, eval_dataset=eval_dataset)

    # do eval -- only when local_rank is -1 or 0 (presumably the
    # non-distributed run or the rank-0 process; confirm with torchblocks).
    if args.do_eval and args.local_rank in [-1, 0]:
        results = {}
        eval_dataset = processor.create_dataset(args.eval_max_seq_length, 'dev.tsv', 'dev')
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints or args.checkpoint_number > 0:
            checkpoints = get_checkpoints(args.output_dir, args.checkpoint_number, WEIGHTS_NAME)
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            # Tag results with the last "-"-separated token of the directory
            # name (the step number for dirs presumably named
            # "checkpoint-<step>"). normpath stops a trailing separator from
            # producing an empty tag.
            global_step = os.path.basename(os.path.normpath(checkpoint)).split("-")[-1]
            model = model_class.from_pretrained(checkpoint, config=config)
            model.to(args.device)
            trainer.evaluate(model, eval_dataset, save_preds=True, prefix=str(global_step))
            # Record every checkpoint's metrics. Previously an empty tag
            # silently dropped that checkpoint from the results file.
            for key, value in trainer.records['result'].items():
                name = "{}_{}".format(global_step, key) if global_step else key
                results[name] = value
        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        dict_to_text(output_eval_file, results)
if __name__ == "__main__":
    # Run training/evaluation only when executed as a script, never on import.
    main()