Finetune GPT2 on wikitext

  1. Finetune GPT2 on wikitext
  2. Set up the environment
  3. Inspect the wiki data
  4. Inspect the pretrained model
  5. Fine-tune the model
  6. Save and push the model
  7. Test inference with the model loaded from hf_hub
  8. Reference

Finetune GPT2 on wikitext

1. Fine-tune GPT2 on the wikitext dataset, save the model, and push it to the Hugging Face Hub
2. Run local inference with the fine-tuned model
3. Pull the model from the Hugging Face Hub and run inference

Set up the environment

%pip install transformers datasets huggingface_hub accelerate bitsandbytes trl peft loralib wandb
%pip install ipywidgets widgetsnbextension
(pip output trimmed: all requirements already satisfied)
import subprocess
import os

ON_AUTODL_ENV = True
if ON_AUTODL_ENV:
    # on AutoDL instances, source the built-in proxy script and export the
    # proxy environment variables it sets, so hub traffic goes through it
    result = subprocess.run('bash -c "source /etc/network_turbo && env | grep proxy"', shell=True, capture_output=True, text=True)
    for line in result.stdout.splitlines():
        if '=' in line:
            var, value = line.split('=', 1)
            os.environ[var] = value
from huggingface_hub import login
login(token="your_hf_token")  # replace with your Hugging Face access token
import wandb
import os
os.environ["WANDB_NOTEBOOK_NAME"] = "finetune_gpt2_on_wiki.ipynb"  # 替换为实际文件名
wandb.init(project="finetune_gpt2_on_wiki")
wandb: Currently logged in as: goldandrabbit (goldandrabbit-g-r) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
Tracking run with wandb version 0.19.10
Run data is saved locally in /root/z_notebooks/wandb/run-20250503_102611-0ryscs1u
Syncing run glad-feather-2 to Weights & Biases
View project at https://wandb.ai/goldandrabbit-g-r/finetune_gpt2_on_wiki
View run at https://wandb.ai/goldandrabbit-g-r/finetune_gpt2_on_wiki/runs/0ryscs1u

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import load_dataset

if torch.backends.mps.is_available():
    print(f"torch.backends.mps.is_available(): {torch.backends.mps.is_available()}")
    device = torch.device("mps")
elif torch.cuda.is_available():
    print(f"torch.cuda.is_available(): {torch.cuda.is_available()}")
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

dataset = load_dataset('Self-GRIT/wikitext-2-raw-v1-preprocessed')
torch.cuda.is_available(): True

Inspect the wiki data

dataset
DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 1835
    })
    train: Dataset({
        features: ['text'],
        num_rows: 15313
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 1649
    })
})
dataset['train']['text'][0:10]
[' Senjō no Valkyria 3 : Unrecorded Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . Employing the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " Calamaty Raven " . \n',
 " The game began development in 2010 , carrying over a large portion of the work done on Valkyria Chronicles II . While it retained the standard features of the series , it also underwent multiple adjustments , such as making the game more forgiving for series newcomers . Character designer Raita Honjou and composer Hitoshi Sakimoto both returned from previous entries , along with Valkyria Chronicles II director Takeshi Ozawa . A large team of writers handled the script . The game 's opening theme was sung by May 'n . \n",
 " It met with positive sales in Japan , and was praised by both Japanese and western critics . After release , it received downloadable content , along with an expanded edition in November of that year . It was also adapted into manga and an original video animation series . Due to low sales of Valkyria Chronicles II , Valkyria Chronicles III was not localized , but a fan translation compatible with the game 's expanded edition was released in 2014 . Media.Vision would return to the franchise with the development of Valkyria : Azure Revolution for the PlayStation 4 . \n",
 " As with previous Valkyira Chronicles games , Valkyria Chronicles III is a tactical role @-@ playing game where players take control of a military unit and take part in missions against enemy forces . Stories are told through comic book @-@ like panels with animated character portraits , with characters speaking partially through voiced speech bubbles and partially through unvoiced text . The player progresses through a series of linear missions , gradually unlocked as maps that can be freely scanned through and replayed as they are unlocked . The route to each story location on the map varies depending on an individual player 's approach : when one option is selected , the other is sealed off to the player . Outside missions , the player characters rest in a camp , where units can be customized and character growth occurs . Alongside the main story missions are character @-@ specific sub missions relating to different squad members . After the game 's completion , additional episodes are unlocked , some of them having a higher difficulty than those found in the rest of the game . There are also love simulation elements related to the game 's two main heroines , although they take a very minor role . \n",
 ' The game \'s battle system , the BliTZ system , is carried over directly from Valkyira Chronicles . During missions , players select each unit using a top @-@ down perspective of the battlefield map : once a character is selected , the player moves the character around the battlefield in third @-@ person . A character can only act once per @-@ turn , but characters can be granted multiple turns at the expense of other characters \' turns . Each character has a field and distance of movement limited by their Action Gauge . Up to nine characters can be assigned to a single mission . During gameplay , characters will call out if something happens to them , such as their health points ( HP ) getting low or being knocked out by enemy attacks . Each character has specific " Potentials " , skills unique to each character . They are divided into " Personal Potential " , which are innate skills that remain unaltered unless otherwise dictated by the story and can either help or impede a character , and " Battle Potentials " , which are grown throughout the game and always grant boons to a character . To learn Battle Potentials , each character has a unique " Masters Table " , a grid @-@ based skill table that can be used to acquire and link different skills . Characters also have Special Abilities that grant them temporary boosts on the battlefield : Kurt can activate " Direct Command " and move around the battlefield without depleting his Action Point gauge , the character Reila can shift into her " Valkyria Form " and become invincible , while Imca can target multiple enemy units with her heavy weapon . \n',
 " Troops are divided into five classes : Scouts , Shocktroopers , Engineers , Lancers and Armored Soldier . Troopers can switch classes by changing their assigned weapon . Changing class does not greatly affect the stats gained while in a previous class . With victory in battle , experience points are awarded to the squad , which are distributed into five different attributes shared by the entire squad , a feature differing from early games ' method of distributing to different unit types . \n",
 ' The game takes place during the Second Europan War . Gallian Army Squad 422 , also known as " The Nameless " , are a penal military unit composed of criminals , foreign deserters , and military offenders whose real names are erased from the records and thereon officially referred to by numbers . Ordered by the Gallian military to perform the most dangerous missions that the Regular Army and Militia will not do , they are nevertheless up to the task , exemplified by their motto , Altaha Abilia , meaning " Always Ready . " The three main characters are No.7 Kurt Irving , an army officer falsely accused of treason who wishes to redeem himself ; Ace No.1 Imca , a female Darcsen heavy weapons specialist who seeks revenge against the Valkyria who destroyed her home ; and No.13 Riela Marcellis , a seemingly jinxed young woman who is unknowingly a descendant of the Valkyria . Together with their fellow squad members , these three are tasked to fight against a mysterious Imperial unit known as Calamity Raven , consisting of mostly Darcsen soldiers . \n',
 " As the Nameless officially do not exist , the upper echelons of the Gallian Army exploit the concept of plausible deniability in order to send them on missions that would otherwise make Gallia lose face in the war . While at times this works to their advantage , such as a successful incursion into Imperial territory , other orders cause certain members of the 422nd great distress . One such member , Gusurg , becomes so enraged that he abandons his post and defects into the ranks of Calamity Raven , attached to the ideal of Darcsen independence proposed by their leader , Dahau . At the same time , elements within Gallian Army Command move to erase the Nameless in order to protect their own interests . Hounded by both allies and enemies , and combined with the presence of a traitor within their ranks , the 422nd desperately move to keep themselves alive while at the same time fight to help the Gallian war effort . This continues until the Nameless 's commanding officer , Ramsey Crowe , who had been kept under house arrest , is escorted to the capital city of Randgriz in order to present evidence exonerating the weary soldiers and expose the real traitor , the Gallian General that had accused Kurt of Treason . \n",
 " Partly due to these events , and partly due to the major losses in manpower Gallia suffers towards the end of the war with the Empire , the Nameless are offered a formal position as a squad in the Gallian Army rather than serve as an anonymous shadow force . This is short @-@ lived , however , as following Maximilian 's defeat , Dahau and Calamity Raven move to activate an ancient Valkyrian super weapon within the Empire , kept secret by their benefactor . Without the support of Maximilian or the chance to prove themselves in the war with Gallia , it is Dahau 's last trump card in creating a new Darcsen nation . As an armed Gallian force invading the Empire just following the two nations ' cease @-@ fire would certainly wreck their newfound peace , Kurt decides to once again make his squad the Nameless , asking Crowe to list himself and all under his command as killed @-@ in @-@ action . Now owing allegiance to none other than themselves , the 422nd confronts Dahau and destroys the Valkyrian weapon . Each member then goes their separate ways in order to begin their lives anew . \n",
 ' Concept work for Valkyria Chronicles III began after development finished on Valkyria Chronicles II in early 2010 , with full development beginning shortly after this . The director of Valkyria Chronicles II , Takeshi Ozawa , returned to that role for Valkyria Chronicles III . Development work took approximately one year . After the release of Valkyria Chronicles II , the staff took a look at both the popular response for the game and what they wanted to do next for the series . Like its predecessor , Valkyria Chronicles III was developed for PlayStation Portable : this was due to the team wanting to refine the mechanics created for Valkyria Chronicles II , and they had not come up with the " revolutionary " idea that would warrant a new entry for the PlayStation 3 . Speaking in an interview , it was stated that the development team considered Valkyria Chronicles III to be the series \' first true sequel : while Valkyria Chronicles II had required a large amount of trial and error during development due to the platform move , the third game gave them a chance to improve upon the best parts of Valkyria Chronicles II due to being on the same platform . In addition to Sega staff from the previous games , development work was also handled by Media.Vision. The original scenario was written Kazuki Yamanobe , while the script was written by Hiroyuki Fujii , Koichi Majima , Kishiko Miyagi , Seiki Nagakawa and Takayuki Shouji . Its story was darker and more somber than that of its predecessor . \n']
tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2').to(device)
tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):
    inputs = tokenizer(
        examples['text'],
        truncation=True,
        padding='max_length',
        max_length=128
    )
    # for causal LM training the labels are the input ids themselves;
    # the model shifts them internally when computing the loss
    inputs['labels'] = inputs['input_ids'].copy()
    return inputs
print(tokenizer)
GPT2TokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
    50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}
)
tokenizer.pad_token
'<|endoftext|>'
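Note that tokenize_function copies input_ids verbatim into labels, so the loss is also computed over the <|endoftext|> padding positions. A minimal variant (my sketch, not part of the original run) masks those positions with -100, the ignore index of the causal LM loss:

def tokenize_function_masked(examples):
    inputs = tokenizer(
        examples['text'],
        truncation=True,
        padding='max_length',
        max_length=128
    )
    # set labels to -100 wherever attention_mask is 0, so padded
    # positions are ignored by the cross-entropy loss
    inputs['labels'] = [
        [tok if mask == 1 else -100 for tok, mask in zip(ids, attn)]
        for ids, attn in zip(inputs['input_ids'], inputs['attention_mask'])
    ]
    return inputs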

Inspect the pretrained model

model
GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)
# total parameter count
total_params = sum(p.numel() for p in model.parameters())
print(f"total params: {total_params / 1e6:.1f}M")

# trainable parameter count (differs if some layers are frozen)
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable params: {trainable_params / 1e6:.1f}M")
total params: 124.4M
trainable params: 124.4M
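The 124.4M total can be sanity-checked per module. A quick sketch (my addition, assuming the standard GPT2-small weight tying between lm_head and wte):

from collections import defaultdict

by_module = defaultdict(int)
for name, p in model.named_parameters():
    # group by the submodule directly under `transformer` (wte, wpe, h, ln_f)
    parts = name.split('.')
    key = parts[1] if parts[0] == 'transformer' else parts[0]
    by_module[key] += p.numel()
for key, n in by_module.items():
    print(f"{key}: {n / 1e6:.1f}M")
# roughly: wte 38.6M, wpe 0.8M, h 85.1M, ln_f 0.0M; lm_head is tied
# to wte, so it adds no extra parameters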
tokenized_datasets = dataset.map(tokenize_function, batched=True)
Map:   0%|          | 0/15313 [00:00<?, ? examples/s]
tokenized_datasets['train'][0]['text']
' Senjō no Valkyria 3 : Unrecorded Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . Employing the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " Calamaty Raven " . \n'
tokenized_datasets['train'][0]['input_ids']
[2311, 73, 13090, 645, 569, 18354, 7496, 513, 1058, 791, 47398, 17740, 357, 4960, 1058, 10545, 230, 99, 161, 254,
 112, 5641, 44444, 9202, 25084, 24440, 12675, 11839, 18, 837, 6578, 764, 569, 18354, 7496, 286, 262, 30193, 513, 1267,
 837, 8811, 6412, 284, 355, 569, 18354, 7496, 17740, 6711, 2354, 2869, 837, 318, 257, 16106, 2597, 2488, 12, 31,
 2712, 2008, 983, 4166, 416, 29490, 290, 6343, 13, 44206, 329, 262, 14047, 44685, 764, 28728, 287, 3269, 2813, 287,
 2869, 837, 340, 318, 262, 2368, 983, 287, 262, 569, 18354, 7496, 2168, 764, 12645, 278, 262, 976, 21748, 286,
 16106, 290, 1103, 2488, 12, 31, 640, 11327, 355, 663, 27677, 837, 262, 1621, 4539, 10730, 284, 262, 717, 983,
 290, 5679, 262, 366, 17871, 5321, 366, 837]
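As a quick round-trip check (my addition, not a cell from the original notebook), decoding the stored ids should reproduce the opening of the source text:

print(tokenizer.decode(tokenized_datasets['train'][0]['input_ids'][:30]))
# prints the start of the Senjō no Valkyria paragraph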

Fine-tune the model

exp_name = 'zxc_ft_gpt2_on_wikitext2'

if ON_AUTODL_ENV:
    output_dir = f"/root/autodl-tmp/{exp_name}"
else:
    output_dir = f"/Users/gold/repos/z_notebooks/tmp_file/{exp_name}"
print(f"output_dir: {output_dir}")

trainer = Trainer(
    model=model,
    args=TrainingArguments(
        run_name=exp_name,
        output_dir=output_dir,
        eval_strategy='epoch',
        num_train_epochs=1,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        warmup_steps=500,
        weight_decay=0.01,
        save_strategy="epoch",
        logging_dir=output_dir+'/log'
    ),
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
)
output_dir: /root/autodl-tmp/zxc_ft_gpt2_on_wikitext2
trainer.train()
`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.

[1915/1915 01:40, Epoch 1/1]

Epoch  Training Loss  Validation Loss
1      2.915400       2.805183

TrainOutput(global_step=1915, training_loss=3.0091578232090406, metrics={'train_runtime': 101.1733, 'train_samples_per_second': 151.354, 'train_steps_per_second': 18.928, 'total_flos': 1000291221504000.0, 'train_loss': 3.0091578232090406, 'epoch': 1.0})
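The validation loss is easier to interpret as perplexity; a one-line conversion (my addition):

import math

eval_loss = 2.805183  # validation loss from the table above
print(f"validation perplexity: {math.exp(eval_loss):.1f}")  # ~16.5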

Save and push the model

model_output_dir = output_dir + '/model'
# save the model to a dedicated subdirectory
print(f"model_output_dir: {model_output_dir}")
model_output_dir: /root/autodl-tmp/zxc_ft_gpt2_on_wikitext2/model
model.save_pretrained(model_output_dir)     # writes config.json, generation_config.json and model.safetensors
tokenizer.save_pretrained(model_output_dir) # writes merges.txt, special_tokens_map.json, tokenizer_config.json, tokenizer.json and vocab.json
('/root/autodl-tmp/zxc_ft_gpt2_on_wikitext2/model/tokenizer_config.json',
 '/root/autodl-tmp/zxc_ft_gpt2_on_wikitext2/model/special_tokens_map.json',
 '/root/autodl-tmp/zxc_ft_gpt2_on_wikitext2/model/vocab.json',
 '/root/autodl-tmp/zxc_ft_gpt2_on_wikitext2/model/merges.txt',
 '/root/autodl-tmp/zxc_ft_gpt2_on_wikitext2/model/added_tokens.json',
 '/root/autodl-tmp/zxc_ft_gpt2_on_wikitext2/model/tokenizer.json')

Push the model and tokenizer to the Hub; the argument specifies the repo name.

push_repo_id = exp_name
model.push_to_hub(push_repo_id)
tokenizer.push_to_hub(push_repo_id)
model.safetensors:   0%|          | 0.00/498M [00:00<?, ?B/s]
README.md:   0%|          | 0.00/5.17k [00:00<?, ?B/s]
CommitInfo(commit_url='https://huggingface.co/goldandrabbit/zxc_ft_gpt2_on_wikitext2/commit/6e5baf535a881c04c5a8d31f732e2d5264c2a3e6', commit_message='Upload tokenizer', commit_description='', oid='6e5baf535a881c04c5a8d31f732e2d5264c2a3e6', pr_url=None, repo_url=RepoUrl('https://huggingface.co/goldandrabbit/zxc_ft_gpt2_on_wikitext2', endpoint='https://huggingface.co', repo_type='model', repo_id='goldandrabbit/zxc_ft_gpt2_on_wikitext2'), pr_revision=None, pr_num=None)
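To confirm the upload, the repo contents can be listed; a sketch using huggingface_hub (not part of the original run):

from huggingface_hub import list_repo_files

# expect config.json, model.safetensors and the tokenizer files
print(list_repo_files(f"goldandrabbit/{push_repo_id}"))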

Test the local inference flow: first load the model from the local directory.

local_model_path = model_output_dir

tokenizer = AutoTokenizer.from_pretrained(local_model_path)
model     = AutoModelForCausalLM.from_pretrained(local_model_path)
model
GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)
query_str = "Can you tell me a story"

prompt = tokenizer(query_str, return_tensors='pt')
prompt
{'input_ids': tensor([[6090,  345, 1560,  502,  257, 1621]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}
outputs = model.generate(
    **prompt,
    max_length=200,
    num_return_sequences=1,
    pad_token_id=tokenizer.eos_token_id,
    repetition_penalty=1.5,    # penalize repeated tokens (values > 1 take effect)
    temperature=0.9,           # soften the distribution (0.7-1.0 balances diversity and coherence)
    top_k=50,                  # sample only from the 50 highest-probability tokens
    top_p=0.95,                # sample from the smallest token set reaching 95% cumulative probability
    do_sample=True             # enable sampling (disables greedy search)
)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
generated_text
'Can you tell me a story that I did not already know ? \n'

Test inference with the model loaded from hf_hub

hub_path = f"goldandrabbit/{exp_name}"

tokenizer = AutoTokenizer.from_pretrained(hub_path)
model = AutoModelForCausalLM.from_pretrained(hub_path)
tokenizer_config.json:   0%|          | 0.00/507 [00:00<?, ?B/s]
vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]
merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]
tokenizer.json:   0%|          | 0.00/3.56M [00:00<?, ?B/s]
special_tokens_map.json:   0%|          | 0.00/131 [00:00<?, ?B/s]
config.json:   0%|          | 0.00/880 [00:00<?, ?B/s]
model.safetensors:   0%|          | 0.00/498M [00:00<?, ?B/s]
generation_config.json:   0%|          | 0.00/119 [00:00<?, ?B/s]
another_query_str = "Can you tell me a story"

prompt = tokenizer(another_query_str, return_tensors='pt')
outputs = model.generate(
    **prompt,
    max_length=200,
    num_return_sequences=1,
    pad_token_id=tokenizer.eos_token_id,
    repetition_penalty=1.5,    # penalize repeated tokens (values > 1 take effect)
    temperature=0.9,           # soften the distribution (0.7-1.0 balances diversity and coherence)
    top_k=50,                  # sample only from the 50 highest-probability tokens
    top_p=0.95,                # sample from the smallest token set reaching 95% cumulative probability
    do_sample=True             # enable sampling (disables greedy search)
)
generated_text = tokenizer.decode(
    outputs[0],
    skip_special_tokens=True
)
print(generated_text)
Can you tell me a story about being kicked out of football by two people who were clearly in love for one another and there was no reason why they should be upset . " 
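The same Hub checkpoint also works through the high-level pipeline API; an equivalent sketch (my addition, same sampling settings):

from transformers import pipeline

generator = pipeline('text-generation', model=hub_path)
result = generator(
    "Can you tell me a story",
    max_length=200,
    do_sample=True,
    top_k=50,
    top_p=0.95,
    temperature=0.9,
    repetition_penalty=1.5,
)
print(result[0]['generated_text'])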

Reference

[1] https://medium.com/@prashanth.ramanathan/fine-tuning-a-pre-trained-gpt-2-model-and-performing-inference-a-hands-on-guide-57c097a3b810


Please credit the source when reposting: goldandrabbit.github.io