Finetuning Llama-3 using ReFT (Representation Fine-Tuning) Technique (2024)

A complete guide on how to fine-tune Llama-3-8B using the ReFT technique


Parameter-efficient fine-tuning (PEFT) approaches aim to adapt large language models through updates to a small number of weights. However, a large body of interpretability research has shown that representations encode rich semantic information, which suggests that editing representations might be a more effective option. This is where Representation Finetuning (ReFT) methods come in. In fact, LoReFT (part of the ReFT family) is a drop-in replacement for existing PEFTs and learns interventions that are 10x–50x more parameter-efficient than prior state-of-the-art PEFTs.

In this regard, the paper by Zhengxuan Wu et al. points out that a hallmark of current state-of-the-art PEFTs is that they modify weights rather than representations. Their work shows that editing representations can be a more powerful and efficient alternative to weight updates.

Before jumping into ReFT, we first need to understand what is meant by representations. A language model produces contextualised representations of a sequence of tokens. Given a sequence of $n$ input tokens $x = (x_1, \ldots, x_n)$, the model first embeds them into a list of representations $\mathbf{h}^{(0)} = (h^{(0)}_1, \ldots, h^{(0)}_n)$. Then its layers successively compute the $j$-th list of hidden representations $\mathbf{h}^{(j)}$ as a function of the previous list $\mathbf{h}^{(j-1)}$. Each hidden representation is a vector. The LM uses the final hidden representations to produce its predictions.
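To make the notation concrete, here is a toy NumPy sketch. It is not Llama-3's architecture: each made-up "layer" averages over the token prefix (a crude stand-in for causal attention) and adds a small residual update. It only illustrates the bookkeeping: $\mathbf{h}^{(0)}$ comes from the embeddings, each $\mathbf{h}^{(j)}$ is computed from $\mathbf{h}^{(j-1)}$, and every entry is one vector per token.

```python
import numpy as np

def toy_forward(token_ids, embedding, layer_weights):
    """Return [h^(0), ..., h^(m)]; each h^(j) is an (n, d) array, one d-dim vector per token."""
    h = embedding[token_ids]              # h^(0): look up one embedding vector per token
    hidden_states = [h]
    n = len(token_ids)
    # Row i averages positions 0..i: a crude, causal stand-in for attention (toy assumption).
    mix = np.tril(np.ones((n, n))) / np.arange(1, n + 1)[:, None]
    for W in layer_weights:
        context = mix @ h                 # contextualise each position with its prefix
        h = h + np.tanh(context @ W)      # h^(j) computed from h^(j-1) via a residual update
        hidden_states.append(h)
    return hidden_states

rng = np.random.default_rng(0)
vocab_size, d, num_layers = 50, 8, 3
embedding = rng.normal(size=(vocab_size, d))
layers = [rng.normal(scale=0.1, size=(d, d)) for _ in range(num_layers)]

hs = toy_forward([3, 14, 15, 9], embedding, layers)
print(len(hs), hs[-1].shape)  # 4 (4, 8)
```

Because the mixing is causal, the representation at the first position never depends on later tokens, while the last position sees the whole sequence.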

ReFT’s motivation originates from the concept of intervention-based model interpretability, which stresses altering representations rather than weights. This concept is based on the linear representation hypothesis, which states that concepts are encoded in linear subspaces of neural network representations.

In the paper, the authors build on the distributed interchange intervention (DII) operation to create a new parameter-efficient method for adapting language models to downstream tasks. If you want to dive deeper into the concept, feel free to refer to their paper.
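The paper's two key formulas are easy to state in code. With a low-rank projection $R \in \mathbb{R}^{r \times d}$ whose rows are orthonormal, a distributed interchange intervention overwrites the part of a base representation $b$ that lies in the row space of $R$ with the corresponding part of a source representation $s$: $\mathrm{DII}(b, s, R) = b + R^\top(Rs - Rb)$. LoReFT replaces the source projection $Rs$ with a learned linear target: $\Phi_{\mathrm{LoReFT}}(h) = h + R^\top(Wh + b - Rh)$, where $W \in \mathbb{R}^{r \times d}$ and here $b \in \mathbb{R}^r$ is a learned bias. Below is a minimal NumPy sketch of both, with random matrices standing in for trained parameters; it illustrates the math only and is not pyreft's implementation.

```python
import numpy as np

def orthonormal_rows(r, d, rng):
    """Random R of shape (r, d) with orthonormal rows, i.e. R @ R.T == I_r."""
    q, _ = np.linalg.qr(rng.normal(size=(d, r)))  # q has shape (d, r) with orthonormal columns
    return q.T

def dii(base, source, R):
    """Distributed interchange intervention: b + R^T (R s - R b)."""
    return base + R.T @ (R @ source - R @ base)

def loreft(h, R, W, bias):
    """LoReFT: h + R^T (W h + bias - R h)."""
    return h + R.T @ (W @ h + bias - R @ h)

rng = np.random.default_rng(0)
d, r = 16, 4                     # toy sizes; Llama-3-8B would have d = 4096
R = orthonormal_rows(r, d, rng)
W = rng.normal(size=(r, d))      # random stand-ins for trained parameters
bias = rng.normal(size=r)
h = rng.normal(size=d)           # a base representation
s = rng.normal(size=d)           # a source representation (used by DII only)

edited = loreft(h, R, W, bias)
print(np.allclose(R @ edited, W @ h + bias))   # True: the subspace now holds W h + b
print(np.allclose(R @ dii(h, s, R), R @ s))    # True: the subspace now holds s's projection
```

Inside the subspace, the edited representation reads exactly as $Wh + b$; components orthogonal to the rows of $R$ pass through untouched, which is why only $R$, $W$, and $b$ need training.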

PyReFT is a representation fine-tuning (ReFT) library that supports adapting internal language model representations via trainable interventions. With fewer fine-tuning parameters and more robust performance, PyReFT can boost fine-tuning efficiency and decrease fine-tuning cost, while opening the door to studying the interpretability of the adapted parameters.

Pyreft supports:

  • Fine-tuning any pretrained LM on Hugging Face with ReFT
  • Setting ReFT hyperparameters via configs
  • Sharing the fine-tuned results easily on Hugging Face

I fine-tuned Llama-3-8B for 1 epoch on a 10k subset of the teknium/OpenHermes-2.5 dataset due to limited compute resources. Feel free to try it out on the full dataset.

Install Dependencies

The first step is to install the pyreft library. The snippet below imports pyreft if it is already installed and installs it from GitHub otherwise.

```python
# Notebook cell: the `!` prefix runs a shell command in Jupyter/Colab
try:
    import pyreft
except ModuleNotFoundError:
    !pip install git+https://github.com/stanfordnlp/pyreft.git
    import pyreft
```

Install the latest version of transformers (or any version that supports Llama-3). You will also need the bitsandbytes library, which the paged AdamW optimizer used during training relies on.

```python
!pip install -q git+https://github.com/huggingface/transformers
!pip install -q bitsandbytes
```

Make sure your Hugging Face account has been granted access to the gated Llama-3 models and that you are logged in. Use the code snippet below.

```python
from huggingface_hub import notebook_login
notebook_login()
```

Load Model and Tokenizer

The next step is to set the prompt template for training. Since we are using the base model, we need to add the special tokens so that the model learns to stop instead of generating text indefinitely. Load the model and the tokenizer using the code snippet below.

```python
import torch, transformers, pyreft
device = "cuda"

# Llama-3 chat-style prompt; %s is replaced by the user instruction
prompt_no_input_template = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""

model_name_or_path = "meta-llama/Meta-Llama-3-8B"
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16, device_map=device, trust_remote_code=True)

# get tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path, model_max_length=2048,
    padding_side="right", use_fast=False)
# the base model has no pad token, so reuse the EOS token for padding
tokenizer.pad_token = tokenizer.eos_token
```
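To see what the model actually receives, fill the template with a sample instruction (plain string formatting, no model needed; the instruction is a made-up example). The prompt ends with the assistant header, so the "last prompt token" that the intervention later targets is that header's closing `<|end_header_id|>`, assuming the tokenizer keeps it as a single special token.

```python
# Same template as in the loading code; the instruction is a made-up example.
prompt_no_input_template = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""

instruction = "Name three primary colors."
prompt = prompt_no_input_template % instruction
print(prompt)
# <|begin_of_text|><|start_header_id|>user<|end_header_id|>Name three primary colors.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```

This template has no blank line after each header, unlike Meta's published Llama-3 chat format. That is fine for a base model trained on this exact string, as long as the same template is used at inference.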

Prepare Model for ReFT Fine-Tuning

Set up the pyreft configuration for the model and then use the pyreft.get_reft_model() method to get the model ready for representation fine-tuning. In this config, we apply a single rank-4 LoReFT intervention at layer 8, on the residual stream (the block output) of the last prompt token.

```python
# get reft model
reft_config = pyreft.ReftConfig(representations={
    "layer": 8, "component": "block_output",
    "low_rank_dimension": 4,
    "intervention": pyreft.LoreftIntervention(embed_dim=model.config.hidden_size,
                                              low_rank_dimension=4)})
reft_model = pyreft.get_reft_model(model, reft_config)
reft_model.set_device("cuda")
reft_model.print_trainable_parameters()
```
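As a sanity check on what print_trainable_parameters() reports, here is the back-of-the-envelope count for one rank-4 LoReFT intervention on Llama-3-8B. It assumes the intervention consists exactly of the $R$, $W$ and $b$ from the LoReFT formula; pyreft's own count may differ slightly depending on how it stores them.

```python
hidden_size = 4096       # Llama-3-8B's model.config.hidden_size
rank = 4                 # low_rank_dimension in the config above
num_interventions = 1    # a single intervention at a single layer

# Assumption: one LoReFT intervention holds R (rank x d), W (rank x d) and a bias b (rank).
params_per_intervention = 2 * rank * hidden_size + rank
trainable = num_interventions * params_per_intervention
print(trainable)  # 32772

base_params = 8.03e9     # approximate size of Llama-3-8B
print(f"{100 * trainable / base_params:.4f}% of the base model")
```

That is roughly 33k trainable parameters against roughly 8 billion frozen ones, about 0.0004% of the base model.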

Prepare Dataset

Prepare the dataset for fine-tuning. I used a 10k subset of the OpenHermes-2.5 dataset. Since the ReFT trainer expects the data in a specific format, we use pyreft.make_last_position_supervised_data_module() to prepare it.

```python
from datasets import load_dataset

dataset_name = "teknium/OpenHermes-2.5"
dataset = load_dataset(dataset_name, split="train")
dataset = dataset.select(range(10_000))  # first 10k rows only

# inputs: first turn of each conversation wrapped in the template; targets: second turn
data_module = pyreft.make_last_position_supervised_data_module(
    tokenizer, model,
    [prompt_no_input_template % row["conversations"][0]["value"] for row in dataset],
    [row["conversations"][1]["value"] for row in dataset])
```
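One caveat with indexing turns by position: if a row starts with a system message (OpenHermes-2.5 merges several sources, and some rows may include one), `conversations[0]` is the system prompt rather than the user's question. Below is a minimal sketch of a more defensive extraction. `first_exchange` is a hypothetical helper, and it assumes ShareGPT-style turns with `from`/`value` keys and `human`/`gpt` roles; only the `value` key is confirmed by the code above.

```python
def first_exchange(conversations):
    """Hypothetical helper: return (user_text, assistant_text) for the first human turn
    that is directly followed by a gpt turn, or None if there is no such pair.
    Assumes ShareGPT-style turns: {"from": "system" | "human" | "gpt", "value": str}."""
    for turn, reply in zip(conversations, conversations[1:]):
        if turn["from"] == "human" and reply["from"] == "gpt":
            return turn["value"], reply["value"]
    return None

# Toy rows in the assumed format; the second one starts with a system message.
rows = [
    {"conversations": [{"from": "human", "value": "Hi"},
                       {"from": "gpt", "value": "Hello!"}]},
    {"conversations": [{"from": "system", "value": "Be brief."},
                       {"from": "human", "value": "What is 2 + 2?"},
                       {"from": "gpt", "value": "4"}]},
]
pairs = [p for p in (first_exchange(row["conversations"]) for row in rows) if p is not None]
print(pairs)  # [('Hi', 'Hello!'), ('What is 2 + 2?', '4')]
```

With the real dataset, you would build the two lists passed to make_last_position_supervised_data_module from `pairs`, wrapping each user text in prompt_no_input_template, instead of indexing `[0]` and `[1]` directly.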

Start Training

Now set up the training arguments for pyreft.ReftTrainerForCausalLM. Feel free to change them according to your own use case and compute resources. I trained the model for only 1 epoch. I tried to integrate wandb, but there seems to be an issue with the wandb integration at the moment, so reporting is disabled with report_to=[].

```python
# train
training_args = transformers.TrainingArguments(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    warmup_steps=100,
    num_train_epochs=1,
    learning_rate=5e-4,
    bf16=True,
    logging_steps=1,
    optim="paged_adamw_32bit",   # paged optimizer backed by bitsandbytes
    weight_decay=0.0,
    lr_scheduler_type="cosine",
    output_dir="outputs",
    report_to=[],                # disable wandb and other loggers
)

trainer = pyreft.ReftTrainerForCausalLM(
    model=reft_model, tokenizer=tokenizer, args=training_args, **data_module)

_ = trainer.train()
```
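It helps to know how many optimizer updates these settings produce, because warmup_steps and the cosine schedule are measured in updates, not examples. The arithmetic below assumes a single GPU and the 10k-example subset; the exact step count can be off by one depending on how your transformers version handles the last partial accumulation. The schedule function mirrors the shape of the linear-warmup-then-cosine-decay schedule that lr_scheduler_type="cosine" selects, but it is a standalone sketch, not the library code.

```python
import math

num_examples = 10_000
per_device_batch = 4
grad_accum = 8
num_gpus = 1  # assumption: single-GPU run

effective_batch = per_device_batch * grad_accum * num_gpus
batches_per_epoch = math.ceil(num_examples / (per_device_batch * num_gpus))
optimizer_steps = batches_per_epoch // grad_accum  # may differ by one across versions
print(effective_batch, optimizer_steps)  # 32 312

def cosine_with_warmup(step, warmup=100, total=optimizer_steps, peak_lr=5e-4):
    """Linear warmup from 0 to peak_lr, then a half-cosine decay down to 0 at `total`."""
    if step < warmup:
        return peak_lr * step / max(1, warmup)
    progress = (step - warmup) / max(1, total - warmup)
    return peak_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
```

With about 312 updates, warmup_steps=100 spends roughly the first third of the epoch ramping up to the peak learning rate. On the full OpenHermes-2.5 dataset the same warmup would be a far smaller fraction, so consider scaling it (or using warmup_ratio) if you change the data size.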

Once training is done, save the trained intervention to a directory.

```python
reft_model.save(
    save_directory="./reft_to_share",
)
```

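For inference, load the base model and attach the trained intervention to it with pyreft.ReftModel.load. The intervention is applied to hidden states at specific positions at run time rather than merged into the base weights. Then move the ReFT model to the GPU with set_device("cuda"); otherwise inference will not be possible.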

```python
import torch, transformers, pyreft
device = "cuda"

model_name_or_path = "meta-llama/Meta-Llama-3-8B"
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16, device_map=device)

# attach the trained intervention (pulled from the Hugging Face Hub) to the base model
reft_model = pyreft.ReftModel.load(
    "Syed-Hasan-8503/Llama-3-openhermes-reft", model, from_huggingface_hub=True
)

# move the intervention to the same device as the base model
reft_model.set_device("cuda")
```

Set an instruction of your choice and run the code snippet below. Congrats! You just ran inference with your first ReFT fine-tuned model.

```python
# reuses `tokenizer` and `prompt_no_input_template` from the model-loading step earlier
instruction = "A rectangular garden has a length of 25 feet and a width of 15 feet. If you want to build a fence around the entire garden, how many feet of fencing will you need?"

# tokenize and prepare the input
prompt = prompt_no_input_template % instruction
prompt = tokenizer(prompt, return_tensors="pt").to(device)

# intervene at the last prompt token, the same position used during training
base_unit_location = prompt["input_ids"].shape[-1] - 1  # last position
_, reft_response = reft_model.generate(
    prompt, unit_locations={"sources->base": (None, [[[base_unit_location]]])},
    intervene_on_prompt=True, max_new_tokens=512, do_sample=True,
    eos_token_id=tokenizer.eos_token_id, early_stopping=True
)
print(tokenizer.decode(reft_response[0], skip_special_tokens=True))
```

At the pace this field of AI is moving, there will surely be better fine-tuning techniques in the future, but the ones that last will be those that save compute resources while delivering state-of-the-art performance compared to other techniques. ReFT is a clear step in that direction.

Exciting times ahead!

Feel free to reach out to me with any questions.

Connect with me on

LinkedIn 👋 : Syed Hasan

Hugging Face 🤗: Syed-Hasan-8503
