Fine-tune GPT-2 on WikiText
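1. Fine-tune GPT-2 on the WikiText dataset, then save it and push it to Hugging Face
2. Run inference locally with the fine-tuned model
3. Pull the model from Hugging Face and run inference
Set up the environment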
%pip install transformers datasets huggingface_hub accelerate bitsandbytes trl peft loralib wandb
%pip install ipywidgets widgetsnbextension
All packages were already installed in this environment (Python 3.10, index http://mirrors.aliyun.com/pypi/simple), so pip only reported "Requirement already satisfied". Main versions: torch 2.1.2+cu118, transformers 4.51.3, tokenizers 0.21.1, datasets 3.5.1, huggingface_hub 0.30.2, accelerate 1.6.0, bitsandbytes 0.45.5, trl 0.17.0, peft 0.15.2, loralib 0.1.2, wandb 0.19.10, ipywidgets 8.1.3, widgetsnbextension 4.0.11.
import subprocess
import os
ON_AUTODL_ENV = True
if ON_AUTODL_ENV:
    # AutoDL only: run /etc/network_turbo and copy the proxy variables it sets into this Python process
    result = subprocess.run('bash -c "source /etc/network_turbo && env | grep proxy"', shell=True, capture_output=True, text=True)
    output = result.stdout
    for line in output.splitlines():
        if '=' in line:
            var, value = line.split('=', 1)
            os.environ[var] = value
from huggingface_hub import login
login(token="your keys")  # replace with your own Hugging Face access token
import wandb
import os
os.environ["WANDB_NOTEBOOK_NAME"] = "finetune_gpt2_on_wiki.ipynb"  # replace with the actual notebook filename
wandb.init(project="finetune_gpt2_on_wiki")
wandb: Currently logged in as: goldandrabbit (goldandrabbit-g-r) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
Tracking run with wandb version 0.19.10
Run data is saved locally in /root/z_notebooks/wandb/run-20250503_102611-0ryscs1u
Syncing run glad-feather-2 to Weights & Biases
View project at https://wandb.ai/goldandrabbit-g-r/finetune_gpt2_on_wiki
View run at https://wandb.ai/goldandrabbit-g-r/finetune_gpt2_on_wiki/runs/0ryscs1u
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import load_dataset
# Pick a device: Apple MPS first, then CUDA, otherwise fall back to CPU
if torch.backends.mps.is_available():
    print(f"torch.backends.mps.is_available(): {torch.backends.mps.is_available()}")
    device = torch.device("mps")
elif torch.cuda.is_available():
    print(f"torch.cuda.is_available(): {torch.cuda.is_available()}")
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
dataset = load_dataset('Self-GRIT/wikitext-2-raw-v1-preprocessed')
torch.cuda.is_available(): True
Inspect the WikiText data
dataset
DatasetDict({
test: Dataset({
features: ['text'],
num_rows: 1835
})
train: Dataset({
features: ['text'],
num_rows: 15313
})
validation: Dataset({
features: ['text'],
num_rows: 1649
})
})
dataset['train']['text'][0:10]
[' Senjō no Valkyria 3 : Unrecorded Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . Employing the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " Calamaty Raven " . \n',
" The game began development in 2010 , carrying over a large portion of the work done on Valkyria Chronicles II . While it retained the standard features of the series , it also underwent multiple adjustments , such as making the game more forgiving for series newcomers . Character designer Raita Honjou and composer Hitoshi Sakimoto both returned from previous entries , along with Valkyria Chronicles II director Takeshi Ozawa . A large team of writers handled the script . The game 's opening theme was sung by May 'n . \n",
" It met with positive sales in Japan , and was praised by both Japanese and western critics . After release , it received downloadable content , along with an expanded edition in November of that year . It was also adapted into manga and an original video animation series . Due to low sales of Valkyria Chronicles II , Valkyria Chronicles III was not localized , but a fan translation compatible with the game 's expanded edition was released in 2014 . Media.Vision would return to the franchise with the development of Valkyria : Azure Revolution for the PlayStation 4 . \n",
" As with previous Valkyira Chronicles games , Valkyria Chronicles III is a tactical role @-@ playing game where players take control of a military unit and take part in missions against enemy forces . Stories are told through comic book @-@ like panels with animated character portraits , with characters speaking partially through voiced speech bubbles and partially through unvoiced text . The player progresses through a series of linear missions , gradually unlocked as maps that can be freely scanned through and replayed as they are unlocked . The route to each story location on the map varies depending on an individual player 's approach : when one option is selected , the other is sealed off to the player . Outside missions , the player characters rest in a camp , where units can be customized and character growth occurs . Alongside the main story missions are character @-@ specific sub missions relating to different squad members . After the game 's completion , additional episodes are unlocked , some of them having a higher difficulty than those found in the rest of the game . There are also love simulation elements related to the game 's two main heroines , although they take a very minor role . \n",
' The game \'s battle system , the BliTZ system , is carried over directly from Valkyira Chronicles . During missions , players select each unit using a top @-@ down perspective of the battlefield map : once a character is selected , the player moves the character around the battlefield in third @-@ person . A character can only act once per @-@ turn , but characters can be granted multiple turns at the expense of other characters \' turns . Each character has a field and distance of movement limited by their Action Gauge . Up to nine characters can be assigned to a single mission . During gameplay , characters will call out if something happens to them , such as their health points ( HP ) getting low or being knocked out by enemy attacks . Each character has specific " Potentials " , skills unique to each character . They are divided into " Personal Potential " , which are innate skills that remain unaltered unless otherwise dictated by the story and can either help or impede a character , and " Battle Potentials " , which are grown throughout the game and always grant boons to a character . To learn Battle Potentials , each character has a unique " Masters Table " , a grid @-@ based skill table that can be used to acquire and link different skills . Characters also have Special Abilities that grant them temporary boosts on the battlefield : Kurt can activate " Direct Command " and move around the battlefield without depleting his Action Point gauge , the character Reila can shift into her " Valkyria Form " and become invincible , while Imca can target multiple enemy units with her heavy weapon . \n',
" Troops are divided into five classes : Scouts , Shocktroopers , Engineers , Lancers and Armored Soldier . Troopers can switch classes by changing their assigned weapon . Changing class does not greatly affect the stats gained while in a previous class . With victory in battle , experience points are awarded to the squad , which are distributed into five different attributes shared by the entire squad , a feature differing from early games ' method of distributing to different unit types . \n",
' The game takes place during the Second Europan War . Gallian Army Squad 422 , also known as " The Nameless " , are a penal military unit composed of criminals , foreign deserters , and military offenders whose real names are erased from the records and thereon officially referred to by numbers . Ordered by the Gallian military to perform the most dangerous missions that the Regular Army and Militia will not do , they are nevertheless up to the task , exemplified by their motto , Altaha Abilia , meaning " Always Ready . " The three main characters are No.7 Kurt Irving , an army officer falsely accused of treason who wishes to redeem himself ; Ace No.1 Imca , a female Darcsen heavy weapons specialist who seeks revenge against the Valkyria who destroyed her home ; and No.13 Riela Marcellis , a seemingly jinxed young woman who is unknowingly a descendant of the Valkyria . Together with their fellow squad members , these three are tasked to fight against a mysterious Imperial unit known as Calamity Raven , consisting of mostly Darcsen soldiers . \n',
" As the Nameless officially do not exist , the upper echelons of the Gallian Army exploit the concept of plausible deniability in order to send them on missions that would otherwise make Gallia lose face in the war . While at times this works to their advantage , such as a successful incursion into Imperial territory , other orders cause certain members of the 422nd great distress . One such member , Gusurg , becomes so enraged that he abandons his post and defects into the ranks of Calamity Raven , attached to the ideal of Darcsen independence proposed by their leader , Dahau . At the same time , elements within Gallian Army Command move to erase the Nameless in order to protect their own interests . Hounded by both allies and enemies , and combined with the presence of a traitor within their ranks , the 422nd desperately move to keep themselves alive while at the same time fight to help the Gallian war effort . This continues until the Nameless 's commanding officer , Ramsey Crowe , who had been kept under house arrest , is escorted to the capital city of Randgriz in order to present evidence exonerating the weary soldiers and expose the real traitor , the Gallian General that had accused Kurt of Treason . \n",
" Partly due to these events , and partly due to the major losses in manpower Gallia suffers towards the end of the war with the Empire , the Nameless are offered a formal position as a squad in the Gallian Army rather than serve as an anonymous shadow force . This is short @-@ lived , however , as following Maximilian 's defeat , Dahau and Calamity Raven move to activate an ancient Valkyrian super weapon within the Empire , kept secret by their benefactor . Without the support of Maximilian or the chance to prove themselves in the war with Gallia , it is Dahau 's last trump card in creating a new Darcsen nation . As an armed Gallian force invading the Empire just following the two nations ' cease @-@ fire would certainly wreck their newfound peace , Kurt decides to once again make his squad the Nameless , asking Crowe to list himself and all under his command as killed @-@ in @-@ action . Now owing allegiance to none other than themselves , the 422nd confronts Dahau and destroys the Valkyrian weapon . Each member then goes their separate ways in order to begin their lives anew . \n",
' Concept work for Valkyria Chronicles III began after development finished on Valkyria Chronicles II in early 2010 , with full development beginning shortly after this . The director of Valkyria Chronicles II , Takeshi Ozawa , returned to that role for Valkyria Chronicles III . Development work took approximately one year . After the release of Valkyria Chronicles II , the staff took a look at both the popular response for the game and what they wanted to do next for the series . Like its predecessor , Valkyria Chronicles III was developed for PlayStation Portable : this was due to the team wanting to refine the mechanics created for Valkyria Chronicles II , and they had not come up with the " revolutionary " idea that would warrant a new entry for the PlayStation 3 . Speaking in an interview , it was stated that the development team considered Valkyria Chronicles III to be the series \' first true sequel : while Valkyria Chronicles II had required a large amount of trial and error during development due to the platform move , the third game gave them a chance to improve upon the best parts of Valkyria Chronicles II due to being on the same platform . In addition to Sega staff from the previous games , development work was also handled by Media.Vision. The original scenario was written Kazuki Yamanobe , while the script was written by Hiroyuki Fujii , Koichi Majima , Kishiko Miyagi , Seiki Nagakawa and Takayuki Shouji . Its story was darker and more somber than that of its predecessor . \n']
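tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2').to(device)
# GPT-2 has no pad token; reuse <|endoftext|> so batches can be padded
tokenizer.pad_token = tokenizer.eos_token
def tokenize_function(examples):
    # Truncate or pad every line to exactly 128 tokens
    inputs = tokenizer(
        examples['text'],
        truncation=True,
        padding='max_length',
        max_length=128
    )
    # Causal LM: labels are the input ids themselves; the model shifts them internally
    inputs['labels'] = inputs['input_ids'].copy()
    return inputs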
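Because `pad_token` is `<|endoftext|>` and `labels` is a plain copy of `input_ids`, the padded tail of every short line is also scored by the loss. A common alternative (not what this run used) is to set those label positions to -100, the ignore index of PyTorch's cross-entropy and of `ForCausalLMLoss`. Below is a minimal pure-Python sketch; `mask_padding_labels` is a hypothetical helper, and it keys off `attention_mask` rather than the pad id because here the pad id is the same as the real EOS id.

```python
def mask_padding_labels(input_ids, attention_mask, ignore_index=-100):
    """Hypothetical helper: copy input_ids into labels, hiding padded positions from the loss.

    Expects the batched lists returned by a tokenizer call with padding='max_length'.
    attention_mask is used instead of comparing against pad_token_id, because GPT-2's
    pad token is borrowed from <|endoftext|>, which can also be a real token.
    """
    return [
        [tok if keep == 1 else ignore_index for tok, keep in zip(ids, mask)]
        for ids, mask in zip(input_ids, attention_mask)
    ]

# Inside tokenize_function this would replace the plain copy:
# inputs['labels'] = mask_padding_labels(inputs['input_ids'], inputs['attention_mask'])
```

With this change the training loss reported later would average over real tokens only, so it would not be directly comparable to the numbers in this run.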
print(tokenizer)
GPT2TokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}
)
tokenizer.pad_token
'<|endoftext|>'
Inspect the pretrained model
model
GPT2LMHeadModel(
(transformer): GPT2Model(
(wte): Embedding(50257, 768)
(wpe): Embedding(1024, 768)
(drop): Dropout(p=0.1, inplace=False)
(h): ModuleList(
(0-11): 12 x GPT2Block(
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): GPT2Attention(
(c_attn): Conv1D(nf=2304, nx=768)
(c_proj): Conv1D(nf=768, nx=768)
(attn_dropout): Dropout(p=0.1, inplace=False)
(resid_dropout): Dropout(p=0.1, inplace=False)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): GPT2MLP(
(c_fc): Conv1D(nf=3072, nx=768)
(c_proj): Conv1D(nf=768, nx=3072)
(act): NewGELUActivation()
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
(ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(lm_head): Linear(in_features=768, out_features=50257, bias=False)
)
# Count all parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params / 1e6:.1f}M")
# Count trainable parameters (differs from the total only if some layers are frozen)
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable_params / 1e6:.1f}M")
Total parameters: 124.4M
Trainable parameters: 124.4M
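The 124.4M figure can be reproduced by hand from the layer shapes printed above. This sketch assumes, as in GPT-2, that `lm_head` shares its weight with `wte`, so `model.parameters()` yields that tensor only once; a `Conv1D(nf, nx)` holds an `nx × nf` weight plus an `nf` bias, and each LayerNorm holds a weight and a bias of size 768.

```python
# Hidden size, vocabulary, context length and depth, read off the printed model
d, vocab, n_ctx, n_layer = 768, 50257, 1024, 12

def conv1d(nx, nf):
    # Conv1D(nf, nx): weight of shape (nx, nf) plus a bias of size nf
    return nx * nf + nf

layernorm = 2 * d  # elementwise_affine=True: weight + bias

per_block = (
    layernorm              # ln_1
    + conv1d(d, 3 * d)     # attn.c_attn: q, k, v stacked (nf=2304)
    + conv1d(d, d)         # attn.c_proj
    + layernorm            # ln_2
    + conv1d(d, 4 * d)     # mlp.c_fc (nf=3072)
    + conv1d(4 * d, d)     # mlp.c_proj
)

total = (
    vocab * d              # wte (tied with lm_head, so counted once)
    + n_ctx * d            # wpe
    + n_layer * per_block  # 12 GPT2Blocks
    + layernorm            # ln_f
)
print(total, f"{total / 1e6:.1f}M")  # → 124439808 124.4M
```

Most of the budget sits in the 12 blocks (about 85.1M); the token embedding alone accounts for about 38.6M.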
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets['train'][0]['text']
' Senjō no Valkyria 3 : Unrecorded Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . Employing the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " Calamaty Raven " . \n'
tokenized_datasets['train'][0]['input_ids']
[2311, 73, 13090, 645, 569, 18354, 7496, 513, 1058, 791, 47398, 17740, 357, 4960, 1058, 10545,
 230, 99, 161, 254, 112, 5641, 44444, 9202, 25084, 24440, 12675, 11839, 18, 837, 6578, 764,
 569, 18354, 7496, 286, 262, 30193, 513, 1267, 837, 8811, 6412, 284, 355, 569, 18354, 7496,
 17740, 6711, 2354, 2869, 837, 318, 257, 16106, 2597, 2488, 12, 31, 2712, 2008, 983, 4166,
 416, 29490, 290, 6343, 13, 44206, 329, 262, 14047, 44685, 764, 28728, 287, 3269, 2813, 287,
 2869, 837, 340, 318, 262, 2368, 983, 287, 262, 569, 18354, 7496, 2168, 764, 12645, 278,
 262, 976, 21748, 286, 16106, 290, 1103, 2488, 12, 31, 640, 11327, 355, 663, 27677, 837,
 262, 1621, 4539, 10730, 284, 262, 717, 983, 290, 5679, 262, 366, 17871, 5321, 366, 837]
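The raw WikiText lines keep the corpus's own tokenization markers: ` @-@ ` for an intra-word hyphen (and ` @,@ ` / ` @.@ ` inside numbers), plus spaces around punctuation. The IDs above show the cost: "role @-@ playing" becomes `2597, 2488, 12, 31, 2712`, so the hyphen alone takes three tokens. This run trains on the text as-is, which is likely why the generated samples later imitate spaced punctuation such as "know ?". If more natural output were wanted, one option is a cleanup pass before tokenizing; `detokenize_wikitext` below is a hypothetical helper that only undoes the @x@ markers.

```python
# WikiText markers for characters that were split off inside words and numbers
WIKITEXT_MARKERS = {" @-@ ": "-", " @,@ ": ",", " @.@ ": "."}

def detokenize_wikitext(text: str) -> str:
    """Hypothetical helper: undo WikiText's @x@ markers in a raw line.

    Spaces around ordinary punctuation (" , ", " . ") are left untouched.
    """
    for marker, char in WIKITEXT_MARKERS.items():
        text = text.replace(marker, char)
    return text

print(detokenize_wikitext("a tactical role @-@ playing video game"))
# → a tactical role-playing video game

# Applying it to the whole DatasetDict before tokenization could look like:
# dataset = dataset.map(lambda ex: {"text": [detokenize_wikitext(t) for t in ex["text"]]}, batched=True)
```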
Fine-tune the model
exp_name = 'zxc_ft_gpt2_on_wikitext2'
if ON_AUTODL_ENV:
    output_dir = f"/root/autodl-tmp/{exp_name}"
else:
    output_dir = f"/Users/gold/repos/z_notebooks/tmp_file/{exp_name}"
print(f"output_dir: {output_dir}")
trainer = Trainer(
    model=model,
    args=TrainingArguments(
        run_name=exp_name,
        output_dir=output_dir,
        eval_strategy='epoch',
        num_train_epochs=1,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        warmup_steps=500,
        weight_decay=0.01,
        save_strategy="epoch",
        logging_dir=output_dir+'/log'
    ),
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
)
output_dir: /root/autodl-tmp/zxc_ft_gpt2_on_wikitext2
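The arguments above leave `learning_rate` and `lr_scheduler_type` unset, so they take the `TrainingArguments` defaults, which (assuming the transformers 4.51 defaults) are 5e-5 and a linear schedule. Combined with `warmup_steps=500` and the 1915 total steps reported by the training run, the learning rate ramps up linearly and then decays linearly to zero. A pure-Python sketch of that schedule; `linear_warmup_lr` is a hypothetical name and its default arguments are the assumed values above.

```python
def linear_warmup_lr(step, base_lr=5e-5, warmup_steps=500, total_steps=1915):
    """Sketch of linear warmup followed by linear decay (assumed defaults, see lead-in)."""
    if step < warmup_steps:
        # Warmup: ramp linearly from 0 up to base_lr
        return base_lr * step / max(1, warmup_steps)
    # Decay: fall linearly from base_lr to 0 at the final step
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

for s in (0, 250, 500, 1000, 1915):
    print(s, f"{linear_warmup_lr(s):.2e}")
```

With a single epoch of 1915 steps, about 26% of training (500 / 1915) is spent in warmup; for short runs like this one, a smaller `warmup_steps` (or `warmup_ratio`) is a reasonable knob to try.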
trainer.train()
`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
[1915/1915 01:40, Epoch 1/1]
TrainOutput(global_step=1915, training_loss=3.0091578232090406, metrics={'train_runtime': 101.1733, 'train_samples_per_second': 151.354, 'train_steps_per_second': 18.928, 'total_flos': 1000291221504000.0, 'train_loss': 3.0091578232090406, 'epoch': 1.0})
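Two quick sanity checks on these numbers. With 15,313 training rows and a per-device batch size of 8 on a single device, one epoch is ceil(15313 / 8) = 1915 optimizer steps, which matches the progress bar. And because `train_loss` is a mean cross-entropy in nats, exp(loss) gives a rough training perplexity. Treat that figure loosely: it averages over the whole epoch (warmup included), and since `labels` copies `input_ids`, it also scores the padded `<|endoftext|>` positions.

```python
import math

train_rows, batch_size = 15313, 8            # from the DatasetDict and TrainingArguments above
steps_per_epoch = math.ceil(train_rows / batch_size)
print(steps_per_epoch)                       # → 1915

train_loss = 3.0091578232090406              # training_loss from TrainOutput
print(f"{math.exp(train_loss):.1f}")         # → 20.3 (rough training perplexity)
```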
Save and push the model
model_output_dir = output_dir + '/model'
print(f'model_output_dir: {model_output_dir}')
model_output_dir: /root/autodl-tmp/zxc_ft_gpt2_on_wikitext2/model
# Save the fine-tuned model and tokenizer to the new directory
model.save_pretrained(model_output_dir)  # writes config.json, generation_config.json and model.safetensors
tokenizer.save_pretrained(model_output_dir)  # writes tokenizer_config.json, special_tokens_map.json, vocab.json, merges.txt, added_tokens.json, tokenizer.json
('/root/autodl-tmp/zxc_ft_gpt2_on_wikitext2/model/tokenizer_config.json',
'/root/autodl-tmp/zxc_ft_gpt2_on_wikitext2/model/special_tokens_map.json',
'/root/autodl-tmp/zxc_ft_gpt2_on_wikitext2/model/vocab.json',
'/root/autodl-tmp/zxc_ft_gpt2_on_wikitext2/model/merges.txt',
'/root/autodl-tmp/zxc_ft_gpt2_on_wikitext2/model/added_tokens.json',
'/root/autodl-tmp/zxc_ft_gpt2_on_wikitext2/model/tokenizer.json')
Push the model and tokenizer to the Hub; the argument sets the repo name
push_repo_id = exp_name
model.push_to_hub(push_repo_id)
tokenizer.push_to_hub(push_repo_id)
CommitInfo(commit_url='https://huggingface.co/goldandrabbit/zxc_ft_gpt2_on_wikitext2/commit/6e5baf535a881c04c5a8d31f732e2d5264c2a3e6', commit_message='Upload tokenizer', commit_description='', oid='6e5baf535a881c04c5a8d31f732e2d5264c2a3e6', pr_url=None, repo_url=RepoUrl('https://huggingface.co/goldandrabbit/zxc_ft_gpt2_on_wikitext2', endpoint='https://huggingface.co', repo_type='model', repo_id='goldandrabbit/zxc_ft_gpt2_on_wikitext2'), pr_revision=None, pr_num=None)
Test local inference: first load the model from the local directory
local_model_path = model_output_dir
tokenizer = AutoTokenizer.from_pretrained(local_model_path)
model = AutoModelForCausalLM.from_pretrained(local_model_path)
model
GPT2LMHeadModel(
(transformer): GPT2Model(
(wte): Embedding(50257, 768)
(wpe): Embedding(1024, 768)
(drop): Dropout(p=0.1, inplace=False)
(h): ModuleList(
(0-11): 12 x GPT2Block(
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): GPT2Attention(
(c_attn): Conv1D(nf=2304, nx=768)
(c_proj): Conv1D(nf=768, nx=768)
(attn_dropout): Dropout(p=0.1, inplace=False)
(resid_dropout): Dropout(p=0.1, inplace=False)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): GPT2MLP(
(c_fc): Conv1D(nf=3072, nx=768)
(c_proj): Conv1D(nf=768, nx=3072)
(act): NewGELUActivation()
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
(ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(lm_head): Linear(in_features=768, out_features=50257, bias=False)
)
query_str = "Can you tell me a story"
prompt = tokenizer(query_str, return_tensors='pt')
prompt
{'input_ids': tensor([[6090, 345, 1560, 502, 257, 1621]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}
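outputs = model.generate(
    **prompt,
    max_length=200,  # total length, prompt tokens included
    num_return_sequences=1,
    pad_token_id=tokenizer.eos_token_id,
    repetition_penalty=1.5,  # penalize tokens that already appeared (takes effect when > 1)
    temperature=0.9,  # values < 1 slightly sharpen the distribution; 0.7-1.0 balances diversity and coherence
    top_k=50,  # sample only from the 50 most likely tokens
    top_p=0.95,  # nucleus sampling: keep the smallest set of tokens whose cumulative probability reaches 95%
    do_sample=True  # sample instead of greedy decoding
)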
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
generated_text
'Can you tell me a story that I did not already know ? \n'
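To make the `generate` arguments concrete, here is a simplified pure-Python sketch of how one decoding step could turn next-token logits into the distribution that `do_sample=True` draws from: repetition penalty first, then temperature, top-k and top-p. It is meant to roughly follow the order of transformers' logits processors, but it is an illustration, not the library's implementation (edge cases such as `min_tokens_to_keep` and ties are ignored), and both function names are hypothetical.

```python
import math

def apply_repetition_penalty(logits, generated_ids, penalty=1.5):
    # For ids already in the sequence: divide positive logits, multiply negative ones
    out = list(logits)
    for i in set(generated_ids):
        out[i] = out[i] / penalty if out[i] > 0 else out[i] * penalty
    return out

def sampling_distribution(logits, temperature=0.9, top_k=50, top_p=0.95):
    # Temperature: T < 1 sharpens the distribution, T > 1 flattens it
    scaled = [x / temperature for x in logits]
    # Top-k: keep the k highest-scoring ids, most likely first
    kept = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:top_k]
    # Softmax over the kept ids (subtract the max for numerical stability)
    m = scaled[kept[0]]
    exps = {i: math.exp(scaled[i] - m) for i in kept}
    z = sum(exps.values())
    probs = {i: e / z for i, e in exps.items()}
    # Top-p (nucleus): shortest most-likely-first prefix whose mass reaches top_p
    nucleus, cum = [], 0.0
    for i in kept:
        nucleus.append(i)
        cum += probs[i]
        if cum >= top_p:
            break
    # Renormalize; every id outside the nucleus now has probability 0
    z = sum(probs[i] for i in nucleus)
    return {i: probs[i] / z for i in nucleus}

# Toy 4-token vocabulary in which token 0 was already generated
logits = apply_repetition_penalty([2.0, 1.0, 0.0, -1.0], generated_ids=[0])
print(sampling_distribution(logits, temperature=0.9, top_k=3, top_p=0.95))
```

In the toy call, token 0's logit drops from 2.0 to about 1.33, which is how `repetition_penalty=1.5` discourages the model from looping on the same phrase.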
Test inference with the model loaded from the Hugging Face Hub
hub_path = f"goldandrabbit/{exp_name}"
tokenizer = AutoTokenizer.from_pretrained(hub_path)
model = AutoModelForCausalLM.from_pretrained(hub_path)
another_query_str = "Can you tell me a story"
prompt = tokenizer(another_query_str, return_tensors='pt')
outputs = model.generate(
    **prompt,
    max_length=200,  # total length, prompt tokens included
    num_return_sequences=1,
    pad_token_id=tokenizer.eos_token_id,
    repetition_penalty=1.5,  # penalize tokens that already appeared (takes effect when > 1)
    temperature=0.9,  # values < 1 slightly sharpen the distribution; 0.7-1.0 balances diversity and coherence
    top_k=50,  # sample only from the 50 most likely tokens
    top_p=0.95,  # nucleus sampling: keep the smallest set of tokens whose cumulative probability reaches 95%
    do_sample=True  # sample instead of greedy decoding
)
generated_text = tokenizer.decode(
    outputs[0],
    skip_special_tokens=True
)
print(generated_text)
Can you tell me a story about being kicked out of football by two people who were clearly in love for one another and there was no reason why they should be upset . "
Please credit the source when reposting: goldandrabbit.github.io