qlora_finetune_codegen350M
1. Fine-tune Salesforce/codegen-350M-mono with QLoRA. It is a code-generation model, model card: https://huggingface.co/Salesforce/codegen-350M-mono
2. Fine-tuning dataset: iamtarun/python_code_instructions_18k_alpaca, dataset card: https://huggingface.co/datasets/iamtarun/python_code_instructions_18k_alpaca
3. How do we configure LoRA fine-tuning? Compared with plain fine-tuning, we add one extra LoraConfig.
4. Evaluating the fine-tuning: compare the quality of code generated by base_model and lora_merged_model to see whether fine-tuning improves generation quality.
# Set up the environment
%pip install transformers datasets huggingface_hub accelerate bitsandbytes
%pip install trl peft loralib wandb evaluate scikit-learn
%pip install ipywidgets widgetsnbextension
Looking in indexes: http://mirrors.aliyun.com/pypi/simple
(pip output trimmed — all requirements already satisfied; key versions: transformers 4.51.3, datasets 3.5.1, accelerate 1.6.0, bitsandbytes 0.45.5, trl 0.17.0, peft 0.15.2, wandb 0.19.10, evaluate 0.4.3, torch 2.1.2+cu118)
Note: you may need to restart the kernel to use updated packages.
import subprocess
import os
import wandb
ON_AUTODL_ENV = True
if ON_AUTODL_ENV:
    result = subprocess.run('bash -c "source /etc/network_turbo && env | grep proxy"', shell=True, capture_output=True, text=True)
    output = result.stdout
    for line in output.splitlines():
        if '=' in line:
            var, value = line.split('=', 1)
            os.environ[var] = value
print(f"ON_AUTODL_ENV: {ON_AUTODL_ENV}")
TRAINER_EXP_NAME = "qlora_finetune_codegen350M"
RUN_NOTEBOOK_NAME = "qlora_finetune_codegen350M.ipynb"
from huggingface_hub import login
login(token="your keys")
os.environ["WANDB_NOTEBOOK_NAME"] = f"{RUN_NOTEBOOK_NAME}.ipynb" # 替换为实际文件名
wandb.init(project=f"{RUN_NOTEBOOK_NAME}")
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import load_dataset
if torch.backends.mps.is_available():
    print(f"torch.backends.mps.is_available(): {torch.backends.mps.is_available()}")
    device = torch.device("mps")
elif torch.cuda.is_available():
    print(f"torch.cuda.is_available(): {torch.cuda.is_available()}")
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
ON_AUTODL_ENV: True
wandb: Currently logged in as: goldandrabbit (goldandrabbit-g-r) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
Tracking run with wandb version 0.19.10
Run data is saved locally in /root/z_notebooks/wandb/run-20250505_140817-0q8dlrxv
Syncing run elegant-commander-6 to Weights & Biases (docs)
View project at https://wandb.ai/goldandrabbit-g-r/qlora_finetune_codegen350M.ipynb
View run at https://wandb.ai/goldandrabbit-g-r/qlora_finetune_codegen350M.ipynb/runs/0q8dlrxv
torch.cuda.is_available(): True
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model, AutoPeftModelForCausalLM
Fine-tuning dataset: python_code_instructions_18k_alpaca
iamtarun/python_code_instructions_18k_alpaca is a code-generation dataset,
dataset card: https://huggingface.co/datasets/iamtarun/python_code_instructions_18k_alpaca
instruction
Write a Python code to get the third largest element in a given row.
input
[12, 13, 13, 45, 22, 99]
output
def third_largest(lst):
    if len(lst) < 3:
        return
    distinct = []
    for i in lst:
        if i not in distinct:
            distinct.append(i)
    distinct.sort(reverse=True)
    return distinct[2]
prompt:
Below is an instruction that describes a task. Write a response that appropriately completes the request. ### Instruction: Write a Python code to get the third largest element in a given row. ### Input: [12, 13, 13, 45, 22, 99] ### Output: def third_largest(lst): if len(lst) < 3: return distinct = [] for i in lst: if i not in distinct: distinct.append(i) distinct.sort(reverse=True) return distinct[2]
dataset = load_dataset("iamtarun/python_code_instructions_18k_alpaca", split="train")
dataset = dataset.remove_columns(["prompt"])
dataset
Dataset({
features: ['instruction', 'input', 'output'],
num_rows: 18612
})
dataset[0]
{'instruction': 'Create a function to calculate the sum of a sequence of integers.',
'input': '[1, 2, 3, 4, 5]',
'output': '# Python code\ndef sum_sequence(sequence):\n sum = 0\n for num in sequence:\n sum += num\n return sum'}
dataset[1]
{'instruction': 'Generate a Python code for crawling a website for a specific type of data.',
'input': 'website: www.example.com \ndata to crawl: phone numbers',
'output': "import requests\nimport re\n\ndef crawl_website_for_phone_numbers(website):\n response = requests.get(website)\n phone_numbers = re.findall('\\d{3}-\\d{3}-\\d{4}', response.text)\n return phone_numbers\n \nif __name__ == '__main__':\n print(crawl_website_for_phone_numbers('www.example.com'))"}
dataset[15]
{'instruction': 'Collate a machine learning model in Python that distinguishes between cats and dogs.',
'input': 'A dataset of 800 images of cats and dogs',
'output': "import numpy as np\nimport keras\nfrom keras.layers import Dense, Conv2D, MaxPooling2D, Dropout, Flatten\nfrom keras.models import Sequential\nfrom keras.preprocessing.image import ImageDataGenerator\n\n# Create the neural network model\nmodel = Sequential()\n\n# Input layer\nmodel.add(Conv2D(32, (3, 3), input_shape = (64, 64, 3), activation = 'relu'))\n\n# Hidden layers\nmodel.add(MaxPooling2D(pool_size = (2, 2)))\nmodel.add(Dropout(0.3))\nmodel.add(Conv2D(64, (3, 3), activation = 'relu'))\nmodel.add(MaxPooling2D(pool_size = (2, 2)))\nmodel.add(Dropout(0.3))\nmodel.add(Conv2D(64, (3, 3), activation = 'relu'))\nmodel.add(MaxPooling2D(pool_size = (2, 2)))\nmodel.add(Dropout(0.3))\n\n# Output layer\nmodel.add(Flatten())\nmodel.add(Dense(units = 128, activation = 'relu'))\nmodel.add(Dense(units = 1, activation = 'sigmoid'))\n\n# Compile the model\nmodel.compile(loss = 'binary_crossentropy', optimizer = 'adam', metrics = ['accuracy'])\n\n# Create data generator\ndatagen = ImageDataGenerator(rescale = 1./255, shear_range = 0.2, zoom_range = 0.2, horizontal_flip = True)\n\n# Fit the model\ntrain_generator = datagen.flow_from_directory(directory = '/path/to/dataset', target_size = (64, 64), color_mode = 'rgb', class_mode = 'binary', batch_size = 32)\nmodel.fit_generator(generator = train_generator, steps_per_epoch = 800, epochs = 5, validation_data = test_generator, validation_steps = 200)"}
Building the QLoRA fine-tuning pipeline
1. Define the base_model: this is just loading the pretrained model with an extra quantization config wrapped around it.
4-bit quantization is enabled through the bnb_config below.
2. Define the lora_config: instantiate a LoraConfig()
(i). set the LoRA low-rank dimension r
(ii). set task_type; here we choose CAUSAL_LM
3. Tokenizer config: the usual setup.
4. Define the SFTTrainer, which comes from the trl library.
Because our task is code generation, the SFTTrainer is given a code-generation prompt template, so we do not have to preprocess the data to fit the template ourselves; then we start fine-tuning with train().
5. Merge the adapter and save the model.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True
)
from accelerate import Accelerator
# initialize the accelerator before the training code
accelerator = Accelerator()
print(f'accelerator.process_index: {accelerator.process_index}')
accelerator.process_index: 0
The model to fine-tune: Salesforce/codegen-350M-mono
CodeGen is a code-generation model family (350M/2B/6B/16B); we start with the 350M variant.
model_card: https://huggingface.co/Salesforce/codegen-350M-mono
model = AutoModelForCausalLM.from_pretrained(
    "Salesforce/codegen-350M-mono",
    quantization_config=bnb_config,
    # device_map="cuda" if torch.cuda.is_available() else "cpu",
    # device_map=device,
    device_map={"": accelerator.process_index},
    use_cache=False,
    trust_remote_code=True
)
model = prepare_model_for_kbit_training(model)
/root/miniconda3/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
return self.fget.__get__(instance, owner)()
Some weights of the model checkpoint at Salesforce/codegen-350M-mono were not used when initializing CodeGenForCausalLM: ['transformer.h.0.attn.causal_mask', 'transformer.h.1.attn.causal_mask', 'transformer.h.10.attn.causal_mask', 'transformer.h.11.attn.causal_mask', 'transformer.h.12.attn.causal_mask', 'transformer.h.13.attn.causal_mask', 'transformer.h.14.attn.causal_mask', 'transformer.h.15.attn.causal_mask', 'transformer.h.16.attn.causal_mask', 'transformer.h.17.attn.causal_mask', 'transformer.h.18.attn.causal_mask', 'transformer.h.19.attn.causal_mask', 'transformer.h.2.attn.causal_mask', 'transformer.h.3.attn.causal_mask', 'transformer.h.4.attn.causal_mask', 'transformer.h.5.attn.causal_mask', 'transformer.h.6.attn.causal_mask', 'transformer.h.7.attn.causal_mask', 'transformer.h.8.attn.causal_mask', 'transformer.h.9.attn.causal_mask']
- This IS expected if you are initializing CodeGenForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing CodeGenForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
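As a quick optional check that 4-bit loading actually shrinks the weights, transformers exposes get_memory_footprint() on the loaded model; a minimal sketch (the value covers parameters and buffers, in bytes):
# Optional: rough size of the quantized model in memory (bytes -> GiB).
print(f"model memory footprint: {model.get_memory_footprint() / 1024**3:.2f} GiB")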
model_id = 'Salesforce/codegen-350M-mono'
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
text = "def hello_world():"
input_ids = tokenizer(text, return_tensors="pt").to(device)
generated_ids = model.generate(**input_ids, max_length=128)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
def hello_world():
print("Hello World")
hello_world()
# 파이썬에서는 문자열을 반환하는 함수를 사용하여 문자열을 반환하는 함수를 사용하여 문자�
for name, module in model.named_modules():
    if "attn" in name or "attention" in name:
        print(f"Attention layer: {name}")
        for sub_name, sub_module in module.named_modules():
            print(f" - Sub-module: {sub_name}")
Attention layer: transformer.h.0.attn
- Sub-module:
- Sub-module: attn_dropout
- Sub-module: resid_dropout
- Sub-module: qkv_proj
- Sub-module: out_proj
Attention layer: transformer.h.0.attn.attn_dropout
- Sub-module:
Attention layer: transformer.h.0.attn.resid_dropout
- Sub-module:
Attention layer: transformer.h.0.attn.qkv_proj
- Sub-module:
Attention layer: transformer.h.0.attn.out_proj
- Sub-module:
(the same attn_dropout / resid_dropout / qkv_proj / out_proj sub-modules repeat for transformer.h.1 through transformer.h.19)
peft_config = LoraConfig(
    r=16,
    lora_alpha=16,
    # target_modules=['qkv_proj'],  # these names must match the sub-module names printed above
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM"
)
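If you prefer to state explicitly which projections receive adapters (instead of relying on peft's built-in default for this architecture), you can pass target_modules using the sub-module names printed above and inspect how few parameters become trainable. A minimal sketch for inspection only — the SFTTrainer below applies peft_config itself, so avoid wrapping the live model a second time; model_copy is a placeholder for a separately loaded copy of the base model:
# Hypothetical variant of the config above, explicitly targeting the attention projections.
explicit_lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["qkv_proj", "out_proj"],
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM"
)
# For a parameter-count check you could wrap a separate copy of the model:
# peft_model = get_peft_model(model_copy, explicit_lora_config)
# peft_model.print_trainable_parameters()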
Defining the prompt instruction template
We rewrite each training sample into the prompt template we want. One option is to transform the dataset directly; the other is to pass a formatting function to the Trainer. Here we hand the function below to the Trainer (the dataset-side alternative is sketched after the template check).
def prompt_instruction_format(sample):
    return f"""
### Instruction:
{sample['instruction']}
### Input:
{sample['input']}
### Output:
{sample['output']}
"""
# check that the template works
sample = dataset[0]
formatted_text = prompt_instruction_format(sample)
print(formatted_text)
### Instruction:
Create a function to calculate the sum of a sequence of integers.
### Input:
[1, 2, 3, 4, 5]
### Output:
# Python code
def sum_sequence(sequence):
    sum = 0
    for num in sequence:
        sum += num
    return sum
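For reference, the same template could instead be baked into the dataset itself with datasets.map — a minimal sketch of the dataset-side alternative mentioned above (not used in the rest of this notebook; the "text" column name is just an illustrative choice):
def add_text_column(sample):
    # Reuse the formatting function to materialize the prompt as a regular column.
    return {"text": prompt_instruction_format(sample)}

dataset_with_text = dataset.map(add_text_column)
print(dataset_with_text[0]["text"][:200])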
import trl
print(trl.__version__)
0.17.0
Defining the SFTTrainer
We instantiate the SFTTrainer from trl, which replaces the Trainer from transformers.
We pass in prompt_instruction_format as the formatting function that controls the input prompt.
LOCAL_PATH = "/root/autodl-tmp"
FT_MODEL_NAME=TRAINER_EXP_NAME
LOCAL_OUTPUT_PATH = f"{LOCAL_PATH}/{FT_MODEL_NAME}"
print(f"LOCAL_OUTPUT_PATH: {LOCAL_OUTPUT_PATH}")
from trl import SFTTrainer
training_args = TrainingArguments(
    output_dir=LOCAL_OUTPUT_PATH,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
    logging_steps=0.05,
    logging_strategy="steps",
    fp16=True,
    optim="paged_adamw_8bit",  # switch to the paged 8-bit optimizer
    use_cpu=False,
    save_strategy="epoch",
    num_train_epochs=3,
    logging_dir=LOCAL_OUTPUT_PATH + '/log'
)
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    formatting_func=prompt_instruction_format,
    args=training_args
)
# sanity-check which devices the model and accelerator are on
print(f"Model device: {model.device}")
print(f"Current device: {accelerator.device}")
LOCAL_OUTPUT_PATH: /root/autodl-tmp/qlora_finetune_codegen350M
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
Model device: cuda:0
Current device: cuda
trainer.train()
wandb: WARNING The `run_name` is currently set to the same value as `TrainingArguments.output_dir`. If this was not intended, please specify a different run name by setting the `TrainingArguments.run_name` parameter.
/root/miniconda3/lib/python3.10/site-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.
  warnings.warn(
[1743/1743 33:21, Epoch 2/3]
TrainOutput(global_step=1743, training_loss=0.7850063956474616, metrics={'train_runtime': 2002.9908, 'train_samples_per_second': 27.876, 'train_steps_per_second': 0.87, 'total_flos': 4.473852407945626e+16, 'train_loss': 0.7850063956474616})
trained_model = AutoPeftModelForCausalLM.from_pretrained(
    LOCAL_OUTPUT_PATH + '/checkpoint-1743',
    return_dict=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
Some weights of the model checkpoint at Salesforce/codegen-350M-mono were not used when initializing CodeGenForCausalLM: ['transformer.h.0.attn.causal_mask', 'transformer.h.1.attn.causal_mask', 'transformer.h.10.attn.causal_mask', 'transformer.h.11.attn.causal_mask', 'transformer.h.12.attn.causal_mask', 'transformer.h.13.attn.causal_mask', 'transformer.h.14.attn.causal_mask', 'transformer.h.15.attn.causal_mask', 'transformer.h.16.attn.causal_mask', 'transformer.h.17.attn.causal_mask', 'transformer.h.18.attn.causal_mask', 'transformer.h.19.attn.causal_mask', 'transformer.h.2.attn.causal_mask', 'transformer.h.3.attn.causal_mask', 'transformer.h.4.attn.causal_mask', 'transformer.h.5.attn.causal_mask', 'transformer.h.6.attn.causal_mask', 'transformer.h.7.attn.causal_mask', 'transformer.h.8.attn.causal_mask', 'transformer.h.9.attn.causal_mask']
- This IS expected if you are initializing CodeGenForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing CodeGenForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Finally, we merge the LoRA adapter into the base model by calling merge_and_unload().
lora_merged_model = trained_model.merge_and_unload()
LOCAL_SAVED_MODEL_PATH = f"{LOCAL_OUTPUT_PATH}/model"
UPLOAD_MODEL_NAME = f"goldandrabbit/{FT_MODEL_NAME}"
lora_merged_model.save_pretrained(LOCAL_SAVED_MODEL_PATH, safe_serialization=True)
tokenizer.save_pretrained(LOCAL_SAVED_MODEL_PATH)
lora_merged_model.push_to_hub(UPLOAD_MODEL_NAME)
tokenizer.push_to_hub(UPLOAD_MODEL_NAME)
No files have been modified since last commit. Skipping to prevent empty commit.
CommitInfo(commit_url='https://huggingface.co/goldandrabbit/qlora_finetune_codegen350M/commit/c8ffec6ead8298a99b2a7b3239bab86752212327', commit_message='Upload tokenizer', commit_description='', oid='c8ffec6ead8298a99b2a7b3239bab86752212327', pr_url=None, repo_url=RepoUrl('https://huggingface.co/goldandrabbit/qlora_finetune_codegen350M', endpoint='https://huggingface.co', repo_type='model', repo_id='goldandrabbit/qlora_finetune_codegen350M'), pr_revision=None, pr_num=None)
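Once the merged weights are on the Hub, they can be reloaded like any ordinary causal-LM checkpoint, with no PEFT machinery required — a minimal sketch, assuming the repo uploaded above is accessible from your environment:
# Pull the merged model back from the Hub for later inference.
reloaded_model = AutoModelForCausalLM.from_pretrained(
    UPLOAD_MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto",
)
reloaded_tokenizer = AutoTokenizer.from_pretrained(UPLOAD_MODEL_NAME)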
Evaluating the fine-tuning
Compare the quality of code generated by base_model and lora_merged_model to see whether fine-tuning improves generation quality (a rough quantitative check is sketched at the end of this section).
instruction="Collate a machine learning model in Python that distinguishes between cats and dogs"
input="A dataset of 800 images of cats and dogs"
prompt = f"""
### Instruction:
{instruction}
### Input:
{input}
### Output:
"""
input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids.cuda()
print(f"Before Training Response:")
output_before = model.generate(input_ids=input_ids, max_new_tokens=300, do_sample=True, top_p=0.9, temperature=0.6, max_length=512)
print(f"{tokenizer.decode(output_before[0], skip_special_tokens=True)}")
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Both `max_new_tokens` (=300) and `max_length`(=512) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Before Training Response:
### Instruction:
Collate a machine learning model in Python that distinguishes between cats and dogs
### Input:
A dataset of 800 images of cats and dogs
### Output:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix
# Load the dataset
data = pd.read_csv("dataset.csv")
# Split the data into training and testing data
X = data.iloc[:, :-1].values
y = data.iloc[:, -1].values
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
# Standardize the data
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
# Create the model
model = RandomForestClassifier(n_estimators=100)
# Train the model
model.fit(X_train, y_train)
# Predict the labels
y_pred = model.predict(X_test)
# Confusion matrix
cm = confusion_matrix(y_test, y_pred)
# Print the confusion matrix
print(
print(f"After Training Response:")
outputs = lora_merged_model.generate(input_ids=input_ids, max_new_tokens=300, do_sample=True, top_p=0.9, temperature=0.6, max_length=512)
print(f"{tokenizer.decode(outputs[0], skip_special_tokens=True)}")
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Both `max_new_tokens` (=300) and `max_length`(=512) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
After Training Response:
### Instruction:
Collate a machine learning model in Python that distinguishes between cats and dogs
### Input:
A dataset of 800 images of cats and dogs
### Output:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
# Read the data
data = np.load('../data/cats_and_dogs.npz')
X, y = data['arr_0'], data['arr_1']
# Split the data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
# Create the model
model = LogisticRegression()
model.fit(X_train, y_train)
# Evaluate the model
y_pred = model.predict(X_test)
acc = accuracy_score(y_test, y_pred)
print('Accuracy:', acc)
# Plot the training data
plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap='viridis', edgecolors='black')
plt.show()
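Beyond eyeballing single generations, a rougher but more systematic check is to compare the two models' average loss on a handful of samples formatted with the same template. A minimal sketch using the objects defined above — base_model_eval is a freshly loaded, untouched copy of the base model (the `model` object has already been wrapped and trained); ideally the evaluation samples would come from a split held out from training, and only 50 samples are used here, so treat the numbers as indicative only:
# Freshly load the untouched base model as the "before" reference.
base_model_eval = AutoModelForCausalLM.from_pretrained(
    "Salesforce/codegen-350M-mono",
    torch_dtype=torch.float16,
    device_map="auto",
)

def mean_loss(eval_model, samples):
    # Average causal-LM loss over the formatted samples (lower is better).
    eval_model.eval()
    losses = []
    with torch.no_grad():
        for sample in samples:
            text = prompt_instruction_format(sample)
            enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(eval_model.device)
            losses.append(eval_model(**enc, labels=enc["input_ids"]).loss.item())
    return sum(losses) / len(losses)

eval_samples = [dataset[i] for i in range(50)]
print(f"base model mean loss:   {mean_loss(base_model_eval, eval_samples):.4f}")
print(f"merged model mean loss: {mean_loss(lora_merged_model, eval_samples):.4f}")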
Reference
Please credit goldandrabbit.github.io when reposting.