I'm training some models by following this tutorial:
https://huggingface.co/transformers/training.html
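Besides the evaluation loss and accuracy, I also want to track the training loss and accuracy so I can monitor overfitting. When I run the code in Jupyter, I do see all of this: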
Epoch  Training Loss  Validation Loss  Accuracy             Glue
1      0.096500       0.928782         {'accuracy': 0.625}  {'accuracy': 0.625, 'f1': 0.0}
2      0.096500       1.203832         {'accuracy': 0.625}  {'accuracy': 0.625, 'f1': 0.0}
3      0.096500       1.643788         {'accuracy': 0.625}  {'accuracy': 0.625, 'f1': 0.0}
But when I look at trainer.state.log_history, those values aren't there. This really doesn't make sense to me.
for obj in trainer.state.log_history:
    print(obj)
{'loss': 0.0965, 'learning_rate': 4.5833333333333334e-05, 'epoch': 0.25, 'step': 1}
{'eval_loss': 0.9287818074226379, 'eval_accuracy': {'accuracy': 0.625}, 'eval_glue': {'accuracy': 0.625, 'f1': 0.0}, 'eval_runtime': 1.3266, 'eval_samples_per_second': 6.03, 'eval_steps_per_second': 0.754, 'epoch': 1.0, 'step': 4}
{'eval_loss': 1.2038320302963257, 'eval_accuracy': {'accuracy': 0.625}, 'eval_glue': {'accuracy': 0.625, 'f1': 0.0}, 'eval_runtime': 1.3187, 'eval_samples_per_second': 6.067, 'eval_steps_per_second': 0.758, 'epoch': 2.0, 'step': 8}
{'eval_loss': 1.6437877416610718, 'eval_accuracy': {'accuracy': 0.625}, 'eval_glue': {'accuracy': 0.625, 'f1': 0.0}, 'eval_runtime': 1.3931, 'eval_samples_per_second': 5.742, 'eval_steps_per_second': 0.718, 'epoch': 3.0, 'step': 12}
{'train_runtime': 20.9407, 'train_samples_per_second': 1.146, 'train_steps_per_second': 0.573, 'total_flos': 6314665328640.0, 'train_loss': 0.07855576276779175, 'epoch': 3.0, 'step': 12}
How can I get these values into an object instead of just printing them?
Thanks
Edit: here is the reproducible code:
import numpy as np
from datasets import load_metric, load_dataset
from transformers import TrainingArguments, AutoModelForSequenceClassification, Trainer, AutoTokenizer
from datasets import list_metrics

raw_datasets = load_dataset("imdb")
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(8))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(8))

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)

training_args = TrainingArguments("IntroToBERT", evaluation_strategy="epoch")
training_args.logging_strategy = 'step'  # note: the accepted values are "no", "steps" and "epoch"
training_args.logging_first_step = True
training_args.logging_steps = 1
training_args.num_train_epochs = 3
training_args.per_device_train_batch_size = 2
training_args.eval_steps = 1

metrics = {}
for metric in ['accuracy', 'glue']:
    metrics[metric] = load_metric(metric, 'mrpc')

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    out = {}
    for metric in metrics.keys():
        out[metric] = metrics[metric].compute(predictions=predictions, references=labels)
    return out

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)

trainer.train()
# here the printout is as shown

for obj in trainer.state.log_history:
    print(obj)
# here the logging data is displayed
Posted on 2021-10-30 08:05:23
You can use the log_metrics method to format the logs and the save_metrics method to save them. Here is the code:
# rest of the training args
# ...
training_args.logging_dir = 'logs' # or any dir you want to save logs
# training
train_result = trainer.train()
# compute train results
metrics = train_result.metrics
max_train_samples = len(small_train_dataset)
metrics["train_samples"] = min(max_train_samples, len(small_train_dataset))
# save train results
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
# compute evaluation results
metrics = trainer.evaluate()
max_val_samples = len(small_eval_dataset)
metrics["eval_samples"] = min(max_val_samples, len(small_eval_dataset))
# save evaluation results
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
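If you later want the saved numbers back as Python objects, the JSON files can simply be read again. The sketch below assumes, based on the transformers source, that save_metrics(split, metrics) writes a file named f"{split}_results.json" into TrainingArguments.output_dir ("IntroToBERT" in the question's code); load_saved_metrics is a hypothetical helper, so check the file names against your installed version.

```python
import json
import os
from typing import Any, Dict


def load_saved_metrics(output_dir: str, split: str) -> Dict[str, Any]:
    """Load the metrics dict previously saved by trainer.save_metrics(split, metrics).

    Assumption: the file is <output_dir>/<split>_results.json; verify this
    against the transformers version you have installed.
    """
    path = os.path.join(output_dir, f"{split}_results.json")
    with open(path, encoding="utf-8") as f:
        return json.load(f)


# Usage after the answer's code has run (output_dir is "IntroToBERT" in the question):
# train_metrics = load_saved_metrics("IntroToBERT", "train")
# eval_metrics = load_saved_metrics("IntroToBERT", "eval")
```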
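You can also save all the logs at once by setting the split argument of log_metrics and save_metrics to "all" (i.e. trainer.save_metrics("all", metrics)), but I prefer the approach above because it lets you customize the results as needed. Here is the complete source code provided by transformers, where you can read more.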
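The question also asked how to get the logged values into an object rather than printing them. Since trainer.state.log_history is a plain list of dicts, one minimal approach is to sort its entries by their keys. The sketch below assumes the key names visible in the question's printout ('loss' for training logs, 'eval_loss' for evaluation logs, 'train_runtime' for the final summary); flatten and split_log_history are hypothetical helpers, not part of transformers.

```python
from typing import Any, Dict, List, Tuple

Record = Dict[str, Any]


def flatten(record: Record, prefix: str = "") -> Record:
    """Flatten nested dicts, e.g. {'eval_glue': {'f1': 0.0}} -> {'eval_glue_f1': 0.0}."""
    flat: Record = {}
    for key, value in record.items():
        name = f"{prefix}_{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten(value, name))
        else:
            flat[name] = value
    return flat


def split_log_history(log_history: List[Record]) -> Tuple[List[Record], List[Record], Record]:
    """Sort log_history entries into training logs, evaluation logs and the final summary."""
    train_logs: List[Record] = []
    eval_logs: List[Record] = []
    summary: Record = {}
    for entry in log_history:
        entry = flatten(entry)
        if "loss" in entry:               # periodic training-loss log
            train_logs.append(entry)
        elif "eval_loss" in entry:        # one entry per evaluation
            eval_logs.append(entry)
        elif "train_runtime" in entry:    # summary logged when train() finishes
            summary = entry
    return train_logs, eval_logs, summary


# Usage with a shortened copy of the log_history printed in the question;
# with a real Trainer you would pass trainer.state.log_history instead.
log_history = [
    {"loss": 0.0965, "learning_rate": 4.5833333333333334e-05, "epoch": 0.25, "step": 1},
    {"eval_loss": 0.9287818074226379, "eval_accuracy": {"accuracy": 0.625},
     "eval_glue": {"accuracy": 0.625, "f1": 0.0}, "epoch": 1.0, "step": 4},
    {"eval_loss": 1.2038320302963257, "eval_accuracy": {"accuracy": 0.625},
     "eval_glue": {"accuracy": 0.625, "f1": 0.0}, "epoch": 2.0, "step": 8},
    {"train_runtime": 20.9407, "train_loss": 0.07855576276779175, "epoch": 3.0, "step": 12},
]
train_logs, eval_logs, summary = split_log_history(log_history)
print([(e["epoch"], e["eval_loss"], e["eval_glue_f1"]) for e in eval_logs])
# prints [(1.0, 0.9287818074226379, 0.0), (2.0, 1.2038320302963257, 0.0)]
```

If pandas is installed, pandas.DataFrame(eval_logs) turns the evaluation list into a table with one row per evaluation. Note that train_logs only contains as many entries as the Trainer actually logged.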
https://stackoverflow.com/questions/68806265