Fine-tune BERT using Trainer vs. PyTorch Training Loop

  1. Fine-tune BERT using Trainer vs. PyTorch Training Loop
  2. Environment setup
  3. Load the model locally and test it
  4. An evaluation template for classification
  5. Fine-tuning with a native PyTorch training loop
  6. Reference

Fine-tune BERT using Trainer vs. PyTorch Training Loop

1. Fine-tune a BERT model on a multi-class text-classification task with Hugging Face transformers.
2. Compare two implementations: the high-level Trainer API vs. a native PyTorch training loop.

Environment setup

%pip install transformers datasets huggingface_hub accelerate bitsandbytes
%pip install trl peft loralib wandb evaluate
%pip install ipywidgets widgetsnbextension
import subprocess
import wandb
import os

ON_AUTODL_ENV = True
if ON_AUTODL_ENV:
    result = subprocess.run('bash -c "source /etc/network_turbo && env | grep proxy"', shell=True, capture_output=True, text=True)
    output = result.stdout
    for line in output.splitlines():
        if '=' in line:
            var, value = line.split('=', 1)
            os.environ[var] = value
print(f"ON_AUTODL_ENV: {ON_AUTODL_ENV}")

TRAINER_EXP_NAME = "finetune_bert_using_Trainer_vs_pytorch_train_loop"
RUN_NOTEBOOK_NAME = "finetune_bert_using_Trainer_vs_pytorch_train_loop.ipynb"
print(f'TRAINER_EXP_NAME: {TRAINER_EXP_NAME}')
print(f'RUN_NOTEBOOK_NAME: {RUN_NOTEBOOK_NAME}')

from huggingface_hub import login
import time
int_timestamp = str(int(time.time()))
print(f'int_timestamp: {int_timestamp}')

login(token="your keys")
os.environ["WANDB_NOTEBOOK_NAME"] = f"{RUN_NOTEBOOK_NAME}_{int_timestamp}.ipynb"  # 替换为实际文件名
wandb.init(project=f"{RUN_NOTEBOOK_NAME}")
ON_AUTODL_ENV: True
TRAINER_EXP_NAME: finetune_bert_using_Trainer_vs_pytorch_train_loop
RUN_NOTEBOOK_NAME: finetune_bert_using_Trainer_vs_pytorch_train_loop.ipynb
int_timestamp: 1746370141


wandb: WARNING WANDB_NOTEBOOK_NAME should be a path to a notebook file, couldn't find finetune_bert_using_Trainer_vs_pytorch_train_loop.ipynb_1746370141.ipynb.
wandb: Currently logged in as: goldandrabbit (goldandrabbit-g-r) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin

Tracking run with wandb version 0.19.10

Run data is saved locally in /root/z_notebooks/wandb/run-20250504_224903-ig9gqolf

Syncing run stellar-tauntaun-7 to Weights & Biases (docs)

View project at https://wandb.ai/goldandrabbit-g-r/finetune_bert_using_Trainer_vs_pytorch_train_loop.ipynb

View run at https://wandb.ai/goldandrabbit-g-r/finetune_bert_using_Trainer_vs_pytorch_train_loop.ipynb/runs/ig9gqolf

from datasets import load_dataset

dataset = load_dataset("yelp_review_full", cache_dir="/root/autodl-tmp/datasets")
print(dataset)
DatasetDict({
    train: Dataset({
        features: ['label', 'text'],
        num_rows: 650000
    })
    test: Dataset({
        features: ['label', 'text'],
        num_rows: 50000
    })
})
dataset["train"][1]
{'label': 1,
 'text': "Unfortunately, the frustration of being Dr. Goldberg's patient is a repeat of the experience I've had with so many other doctors in NYC -- good doctor, terrible staff.  It seems that his staff simply never answers the phone.  It usually takes 2 hours of repeated calling to get an answer.  Who has time for that or wants to deal with it?  I have run into this problem with many other doctors and I just don't get it.  You have office workers, you have patients with medical needs, why isn't anyone answering the phone?  It's incomprehensible and not work the aggravation.  It's with regret that I feel that I have to give Dr. Goldberg 2 stars."}
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

tokenized_datasets = dataset.map(tokenize_function, batched=True)
# sample small subsets to keep the walkthrough fast
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(5000))
small_eval_dataset  = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
print(small_train_dataset)
print(small_eval_dataset)
print(small_train_dataset[0])
Dataset({
    features: ['label', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 5000
})
Dataset({
    features: ['label', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 1000
})
{'label': 4, 'text': "I stalk this truck.  I've been to industrial parks where I pretend to be a tech worker standing in line, strip mall parking lots, and of course the farmer's market.  The bowls are so so absolutely divine.  The owner is super friendly and he makes each bowl by hand with an incredible amount of pride.  You gotta eat here guys!!!", 'input_ids': [101, 146, 27438, 1142, 4202, 119, 146, 112, 1396, 1151, 1106, 3924, 8412, 1187, 146, 9981, 1106, 1129, 170, 13395, 7589, 2288, 1107, 1413, 117, 6322, 8796, 5030, 7424, 117, 1105, 1104, 1736, 1103, 9230, 112, 188, 2319, 119, 1109, 20400, 1132, 1177, 1177, 7284, 10455, 119, 1109, 3172, 1110, 7688, 4931, 1105, 1119, 2228, 1296, 7329, 1118, 1289, 1114, 1126, 10965, 2971, 1104, 8188, 119, 1192, 13224, 3940, 1303, 3713, 106, 106, 106, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}
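A side note: padding="max_length" pads every review to BERT's 512-token limit, which is why the sample above is mostly zeros. If you prefer dynamic padding, a sketch (not used in the rest of this post) is to skip padding at tokenization time and let a collator pad each batch only to its longest sample:

from transformers import DataCollatorWithPadding

def tokenize_without_padding(examples):
    # only truncate here; the collator pads each batch on the fly
    return tokenizer(examples["text"], truncation=True)

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# pass data_collator=data_collator to Trainer, or collate_fn=data_collator to a DataLoader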
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "google-bert/bert-base-cased",
    num_labels=5
)
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
import numpy as np
import evaluate

metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
from transformers import Trainer, TrainingArguments

if ON_AUTODL_ENV:
    output_dir = f"/root/autodl-tmp/{TRAINER_EXP_NAME}"
else:
    output_dir = f"/Users/gold/repos/z_notebooks/tmp_file/{TRAINER_EXP_NAME}"
print(f"output_dir: {output_dir}")

trainer = Trainer(
    model=model,
    args=TrainingArguments(
        run_name=TRAINER_EXP_NAME,
        output_dir=output_dir,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=8,
        learning_rate=5e-5,
        num_train_epochs=3,
        logging_strategy="steps",
        logging_steps=0.05,
        logging_dir=os.path.join(output_dir, "logs"),
    ),
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)
output_dir: /root/autodl-tmp/finetune_bert_using_Trainer_vs_pytorch_train_loop
trainer.train()
[939/939 03:10, Epoch 3/3]

Step  Training Loss
47    1.488200
94    1.131200
141   1.130800
188   1.078700
235   1.061300
282   1.024400
329   0.951700
376   0.800300
423   0.789600
470   0.790900
517   0.764000
564   0.701000
611   0.773500
658   0.545700
705   0.525800
752   0.527300
799   0.401600
846   0.437600
893   0.408700

TrainOutput(global_step=939, training_loss=0.7916636461901843, metrics={'train_runtime': 190.9139, 'train_samples_per_second': 78.569, 'train_steps_per_second': 4.918, 'total_flos': 3946772136960000.0, 'train_loss': 0.7916636461901843, 'epoch': 3.0})
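The run above only logs training loss. If you also want validation metrics during training, TrainingArguments can run evaluation on a schedule; a minimal sketch of the extra argument (assuming a recent transformers version where it is named eval_strategy):

eval_args = TrainingArguments(
    run_name=TRAINER_EXP_NAME,
    output_dir=output_dir,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=8,
    learning_rate=5e-5,
    num_train_epochs=3,
    eval_strategy="epoch",  # run compute_metrics on eval_dataset after every epoch
    logging_strategy="steps",
    logging_steps=0.05,
)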
trainer.evaluate(small_eval_dataset)
[125/125 00:04]
{'eval_loss': 1.0120339393615723,
 'eval_accuracy': 0.635,
 'eval_runtime': 4.6131,
 'eval_samples_per_second': 216.775,
 'eval_steps_per_second': 27.097,
 'epoch': 3.0}
model_output_dir = output_dir + '/model'
print(f'model_output_dir:', model_output_dir)
model.save_pretrained(model_output_dir)
tokenizer.save_pretrained(model_output_dir)
push_repo_id = TRAINER_EXP_NAME
print(f'push_repo_id:', push_repo_id)
model.push_to_hub(push_repo_id)
tokenizer.push_to_hub(push_repo_id)
model_output_dir: /root/autodl-tmp/finetune_bert_using_Trainer_vs_pytorch_train_loop/model
push_repo_id: finetune_bert_using_Trainer_vs_pytorch_train_loop



No files have been modified since last commit. Skipping to prevent empty commit.

CommitInfo(commit_url='https://huggingface.co/goldandrabbit/finetune_bert_using_Trainer_vs_pytorch_train_loop/commit/797b90ad9ab50d08544eb738fadbfd997b0ab492', commit_message='Upload tokenizer', commit_description='', oid='797b90ad9ab50d08544eb738fadbfd997b0ab492', pr_url=None, repo_url=RepoUrl('https://huggingface.co/goldandrabbit/finetune_bert_using_Trainer_vs_pytorch_train_loop', endpoint='https://huggingface.co', repo_type='model', repo_id='goldandrabbit/finetune_bert_using_Trainer_vs_pytorch_train_loop'), pr_revision=None, pr_num=None)

Load the model locally and test it

1. Single-sample prediction: pass the tokenized input directly to the model.
2. Dataset-level evaluation: instantiate a new trainer, called eval_trainer here, and then call eval_trainer.evaluate().

from transformers import AutoTokenizer, AutoModelForSequenceClassification

local_model_path = model_output_dir
tokenizer = AutoTokenizer.from_pretrained(local_model_path)
model     = AutoModelForSequenceClassification.from_pretrained(local_model_path)
import torch
import torch.nn.functional as F
s = "The was awesome and I loved it"
tt = tokenizer(s, return_tensors="pt", padding=True, truncation=True)
model.eval()
with torch.no_grad():
    outputs=model(**tt)
print(outputs)

logits = outputs.logits
print("Logits:", logits)

# class probabilities
probabilities = F.softmax(logits, dim=-1)
print("Probabilities:", probabilities)

# predicted class index
predicted_class = torch.argmax(probabilities, dim=-1)
print("Predicted Class:", predicted_class.item())
SequenceClassifierOutput(loss=None, logits=tensor([[-2.4610, -2.6981, -1.1480,  3.0334,  3.9818]]), hidden_states=None, attentions=None)
Logits: tensor([[-2.4610, -2.6981, -1.1480,  3.0334,  3.9818]])
Probabilities: tensor([[0.0011, 0.0009, 0.0042, 0.2775, 0.7163]])
Predicted Class: 4
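For quick ad-hoc checks, the same model and tokenizer can also be wrapped in a text-classification pipeline; a minimal sketch (labels come back as LABEL_0 ... LABEL_4, matching the dataset's 0-4 labels):

from transformers import pipeline

clf = pipeline("text-classification", model=model, tokenizer=tokenizer)
print(clf("The was awesome and I loved it"))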

An evaluation template for classification

1. Load each metric with the evaluate library: accuracy, precision, recall, and F1.
2. Write a compute_metrics() function and pass it to the Trainer's compute_metrics argument.
3. Instantiate a trainer exactly as for training.
4. Call eval_trainer.evaluate().

accuracy_metric  = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
recall_metric    = evaluate.load("recall")
f1_metric        = evaluate.load("f1")

def compute_classification_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)

    # multi-class metrics (the average argument must be specified)
    results = {
        "accuracy": accuracy_metric.compute(predictions=predictions, references=labels)["accuracy"],
        "precision": precision_metric.compute(predictions=predictions, references=labels, average="macro")["precision"],
        "recall": recall_metric.compute(predictions=predictions, references=labels, average="macro")["recall"],
        "f1": f1_metric.compute(predictions=predictions, references=labels, average="macro")["f1"]
    }
    return results
eval_trainer = Trainer(
    model=model,
    args=TrainingArguments(
        run_name="local_eval_dataset",
        output_dir=os.path.join(output_dir, "local_eval_results"),
        per_device_eval_batch_size=16
    ),
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_classification_metrics
)
results = eval_trainer.evaluate()
[63/63 00:04]
from pprint import pprint
pprint(results)
{'eval_accuracy': 0.635,
 'eval_f1': 0.6364144170295373,
 'eval_loss': 1.0120338201522827,
 'eval_model_preparation_time': 0.0068,
 'eval_precision': 0.6412483673175235,
 'eval_recall': 0.6334288150192273,
 'eval_runtime': 4.434,
 'eval_samples_per_second': 225.53,
 'eval_steps_per_second': 14.208}
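Beyond the macro averages above, a per-class breakdown is often useful. A sketch using eval_trainer.predict together with scikit-learn (assumes scikit-learn is installed):

from sklearn.metrics import classification_report

pred_output = eval_trainer.predict(small_eval_dataset)
preds = np.argmax(pred_output.predictions, axis=-1)
print(classification_report(pred_output.label_ids, preds, digits=3))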

Fine-tuning with a native PyTorch training loop

1. Loading the pretrained model is exactly the same as with Trainer; we keep reusing the tokenized_datasets prepared above.
2. Data loading uses PyTorch's DataLoader, which controls the batch size and shuffling.
3. The training loop is written by hand; the standard reference is https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html

# Generic training/eval loop template from the PyTorch tutorial (kept for reference;
# the HF-style loop used below passes batches of dicts rather than (X, y) pairs).
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    # enable train mode; this matters for batch norm and dropout layers
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        # forward pass
        pred = model(X)
        # compute the loss
        loss = loss_fn(pred, y)
        # backpropagation
        loss.backward()
        # apply the gradient update
        optimizer.step()
        # reset gradients so they do not accumulate across batches
        optimizer.zero_grad()
        if batch % 100 == 0:
            loss, current = loss.item(), batch * dataloader.batch_size + len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

def test_loop(dataloader, model, loss_fn):
    # switch to eval mode
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    # no gradients are needed during evaluation
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
tokenized_datasets = tokenized_datasets.remove_columns(["text"])
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch")
tokenized_datasets
DatasetDict({
    train: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 650000
    })
    test: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 50000
    })
})
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(5000))
small_eval_dataset  = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
import torch
from torch.utils.data import DataLoader

traindataloader = DataLoader(small_train_dataset, batch_size=16, shuffle=True)
testdataloader  = DataLoader(small_eval_dataset, batch_size=8)
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased", num_labels=5)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.





BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermediate_act_fn): GELUActivation()
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (activation): Tanh()
    )
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (classifier): Linear(in_features=768, out_features=5, bias=True)
)
from transformers import get_scheduler
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=5e-5)
num_epochs = 3
num_training_steps = num_epochs * len(traindataloader)

lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)
from tqdm.auto import tqdm

progress_bar = tqdm(range(num_training_steps))

# enable train mode
model.train()
for epoch in range(num_epochs):
    total_loss = 0

    for batch in traindataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)

        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        total_loss += loss.item()
        progress_bar.update(1)

    avg_train_loss = total_loss / len(traindataloader)
    print(f"Epoch {epoch+1} | Avg Loss: {avg_train_loss:.6f}")
Epoch 1 | Avg Loss: 1.149697
Epoch 2 | Avg Loss: 0.798091
Epoch 3 | Avg Loss: 0.504651
import evaluate

metric = evaluate.load("accuracy")

# switch to eval mode
model.eval()
for batch in testdataloader:
    b = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**b)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=batch["labels"])

metric.compute()
{'accuracy': 0.611}
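To get the same macro precision/recall/F1 as the Trainer-based evaluation, we can collect all predictions first and reuse the metric objects loaded earlier; a sketch:

all_preds, all_labels = [], []
model.eval()
for batch in testdataloader:
    b = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        logits = model(**b).logits
    all_preds.extend(logits.argmax(dim=-1).cpu().tolist())
    all_labels.extend(batch["labels"].tolist())

print("precision:", precision_metric.compute(predictions=all_preds, references=all_labels, average="macro"))
print("recall:", recall_metric.compute(predictions=all_preds, references=all_labels, average="macro"))
print("f1:", f1_metric.compute(predictions=all_preds, references=all_labels, average="macro"))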

Reference

[1]. https://medium.com/codex/fine-tune-bert-for-text-classification-cef7a1d6cdf1


Please credit goldandrabbit.github.io when reposting.