Llama2 Embedding Training Using DPO


Motivation

In machine learning, many retrieval problems require embedding both a query and a set of items so that the items relevant to the query can be identified. In search engines, for instance, the goal is to find the most pertinent documents (=items) for a given search text (=query). In recommendation systems, the goal is to identify the most relevant products (=items) for a specific user (=query). And in Large Language Models (LLMs), the aim is to produce human-like responses (=items) for a given prompt (=query).

In this blog post, we fine-tune the Llama 2 7B model so that its embeddings identify the most relevant items for a given query. We use QLoRA to fine-tune the model efficiently, and we train it with the Direct Preference Optimization (DPO) loss, repurposed as a ranking loss. To do so, we first prepare a dataset with the columns query | chosen | rejected. For each query, we want the model to produce embeddings such that a similarity metric, here cosine similarity, scores the chosen item higher than the rejected one.

Libraries

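In our example, we use Hugging Face Transformers, Datasets, TRL, and PEFT, together with bitsandbytes for 4-bit quantization and Faiss for similarity search. Note that the DPOTrainer API has changed considerably across TRL releases. The trainer code below follows the interface of early TRL versions (where beta and tokenizer are passed directly to DPOTrainer and concatenated_forward returns four tensors), so you may need to pin an older TRL release instead of upgrading to the latest one.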

!pip install datasets transformers peft trl bitsandbytes faiss-gpu --upgrade --quiet

Import

from typing import Dict, List, Tuple, Union

import torch
from torch import nn

from datasets import Dataset, DatasetDict, load_dataset
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training, PeftModel
from transformers import LlamaTokenizer, LlamaModel, LlamaPreTrainedModel, BitsAndBytesConfig, TrainingArguments, default_data_collator
from trl import DPOTrainer

Dataset Preparation and Tokenization

Our original dataset, smangrul/amazon_esci, has the columns query | product_title | relevance_label (plus product_id, esci_label, and split). We need the format query | chosen | rejected instead, so we implement a preprocess_function that walks through the rows of each query and pairs a product with relevance_label=1 ('chosen') with a product with relevance_label=0 ('rejected').

dataset = load_dataset("smangrul/amazon_esci")

A record of the dataset:

{'query': '!awnmower tires without rims',
 'product_title': 'RamPro 10" All Purpose Utility Air Tires/Wheels with a 5/8" Diameter Hole with Double Sealed Bearings (Pack of 2)',
 'product_id': 'B075SCHMPY',
 'esci_label': 'I',
 'split': 'train',
 'relevance_label': 0}
def preprocess_function(examples):
    # Assumes rows belonging to the same query are adjacent, as in this dataset
    result_q, result_c, result_r = [], [], []
    selected_query, selected_chosen, selected_rejected = None, None, None

    for q, p, r in zip(examples["query"], examples["product_title"], examples["relevance_label"]):
        # A new query starts: drop any half-built pair left over from the previous query
        if q != selected_query:
            selected_query, selected_chosen, selected_rejected = q, None, None

        # Relevant products become 'chosen', irrelevant ones 'rejected'
        if r == 1:
            selected_chosen = p
        else:
            selected_rejected = p

        # Emit a (query, chosen, rejected) triple once both sides are available
        if selected_chosen is not None and selected_rejected is not None:
            result_q.append(selected_query)
            result_c.append(selected_chosen)
            result_r.append(selected_rejected)
            selected_chosen, selected_rejected = None, None

    return {"query": result_q, "chosen": result_c, "rejected": result_r}

processed_datasets = dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=dataset["train"].column_names,
    )

A record of the processed dataset:

{'query': '# 2 pencils not sharpened',
 'chosen': 'YOTINO Pre-sharpened Wood Cased #2 HB Pencils - Box of 100',
 'rejected': 'Emraw Pre Sharpened Triangular Primary Size No 2 Jumbo Pencils for Preschoolers, Elementary Kids - Pack of 6 Fat Pencils with Bonus Sharpener'}

Besides restructuring the dataset into the query | chosen | rejected format, we need to tokenize the three text columns so that Llama 2 can consume them. The tokenized_function below tokenizes each column separately, padding or truncating it to 50 tokens, and prefixes the resulting keys with query_, chosen_, or rejected_.

tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
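tokenizer.pad_token = '[PAD]'  # not in the Llama vocabulary, so it is encoded as id 0 (the unk token), as the padded ids below show; padded positions are masked out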

def tokenized_function(examples):
    prompt = examples["query"]
    result = tokenizer(prompt, padding="max_length", max_length=50, truncation=True)
    result = {f"query_{k}": v for k, v in result.items()}

    chosen = examples["chosen"]
    result_chosen = tokenizer(chosen, padding="max_length", max_length=50, truncation=True)
    for k, v in result_chosen.items():
        result[f"chosen_{k}"] = v

    rejected = examples["rejected"]
    result_rejected = tokenizer(rejected, padding="max_length", max_length=50, truncation=True)
    for k, v in result_rejected.items():
        result[f"rejected_{k}"] = v

    return result

tokenized_datasets = processed_datasets.map(
            tokenized_function,
            batched=True,
            remove_columns=processed_datasets["train"].column_names,
        )

A record of the tokenized dataset:

{'query_input_ids': [1, 1738, 18101, 29885, 1680, 260, 2658, 1728, 364, 9893, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'query_attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'chosen_input_ids': [1, 5918, 12300, 29871, 29906, 29899, 16638, 29871, 29896, 29941, 29916, 29945, 29889, 29900, 29900, 29899, 29953, 29871, 29906, 7390, 29979, 5383, 29888, 341, 1680, 323, 28891, 323, 533, 411, 612, 4743, 390, 326, 29892, 313, 29941, 29908, 7817, 287, 14533, 29892, 29871, 29941, 29914, 29946, 29908, 24715, 886, 1723], 'chosen_attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'rejected_input_ids': [1, 14693, 23328, 29949, 29871, 29906, 29900, 29953, 29900, 29896, 29909, 29871, 29896, 29946, 29889, 29945, 297, 305, 2443, 295, 323, 533, 24674, 265, 951, 369, 20492, 21704, 26240, 891, 20031, 323, 533, 678, 9776, 21704, 363, 16843, 23090, 29892, 360, 2728, 3457, 446, 29892, 997, 1233, 341, 1680, 891], 'rejected_attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
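The prefixes let a single collated batch carry three separate model inputs. The sketch below uses a hypothetical row, truncated to three tokens per field, to show how the query_ keys can be split back into the input_ids/attention_mask arguments the model's forward expects; split_by_prefix is an illustrative helper, and the trainer later does the same thing for all three prefixes.

```python
def split_by_prefix(batch, prefix):
    """Select keys starting with `prefix` and strip it, e.g. 'query_input_ids' -> 'input_ids'."""
    return {k[len(prefix):]: v for k, v in batch.items() if k.startswith(prefix)}

# Hypothetical tokenized row, truncated to 3 tokens per field for readability
row = {
    "query_input_ids": [1, 1738, 0], "query_attention_mask": [1, 1, 0],
    "chosen_input_ids": [1, 5918, 12300], "chosen_attention_mask": [1, 1, 1],
    "rejected_input_ids": [1, 14693, 23328], "rejected_attention_mask": [1, 1, 1],
}

model_inputs = {field: split_by_prefix(row, f"{field}_") for field in ("query", "chosen", "rejected")}
print(model_inputs["query"])  # {'input_ids': [1, 1738, 0], 'attention_mask': [1, 1, 0]}
```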

Modelling

Next, we develop the Llama2ModelForSentenceEmbedding class, which inherits from LlamaPreTrainedModel and turns a batch of tokens into one embedding per sequence. Inheriting from LlamaPreTrainedModel matters for two reasons: it gives us from_pretrained, which can load the weights in quantized form (essential for the QLoRA training later on), and it makes the model compatible with the Trainer class of the transformers library, whose features we rely on.

class Llama2ModelForSentenceEmbedding(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)  # decoder stack without the language-modeling head
        self.score = nn.Linear(10, 1, bias=False)  # dummy layer, never used in forward()
        self.post_init()

    def forward(self, **kwargs):
        model_output = self.model(**kwargs)
        embeddings = self.mean_pooling(model_output, kwargs["attention_mask"])
        # L2-normalize so that the dot product of two embeddings equals their cosine similarity
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return embeddings

    def mean_pooling(self, model_output, attention_mask):
        token_embeddings = model_output[0]  # last hidden state: (batch, seq_len, hidden)
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # Average over real tokens only; the clamp guards against division by zero
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

Here's a breakdown of the class:

Constructor (__init__ method):

  • The constructor takes a config argument and initializes the superclass (LlamaPreTrainedModel) with this configuration.
  • self.num_labels stores the number of labels specified in the configuration.
  • self.model is an instance of LlamaModel initialized with the provided configuration.
  • self.score is a dummy linear layer (input size 10, output size 1) that is never used in forward; it exists only to pass checks in the transformers library.
  • The constructor then calls post_init, which transformers runs at the end of model initialization to initialize the weights and apply final setup.

Forward Pass (forward method):

  • The forward method takes variable keyword arguments (kwargs) as input.
  • It passes the input through the self.model, obtaining model_output.
  • The mean_pooling method is called with model_output and the attention mask from kwargs to perform mean pooling on the token embeddings.
  • The resulting embeddings are L2-normalized with torch.nn.functional.normalize, so the dot product of two embeddings equals their cosine similarity.
  • The normalized embeddings are returned as the output of the forward pass.

Mean Pooling Method (mean_pooling method):

  • This method takes model_output (which contains all token embeddings) and attention_mask as inputs.
  • The attention mask is expanded to the shape of the token embeddings so that padding positions contribute zero.
  • Token embeddings are multiplied by the expanded mask and summed along the token dimension.
  • The sum is divided by the number of real tokens (the mask sum, clamped to at least 1e-9 to avoid division by zero), giving the mean-pooled embedding.
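To see the effect of the mask, here is a standalone sketch of the same masked mean pooling on a hand-made tensor (a hypothetical batch of one sequence whose last position is padding), followed by the L2 normalization applied in forward. It uses broadcasting instead of expand, which gives the same result.

```python
import torch
import torch.nn.functional as F

def masked_mean_pool(token_embeddings, attention_mask):
    # (batch, seq_len) -> (batch, seq_len, 1) so the mask broadcasts over the hidden dimension
    mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts

# batch=1, seq_len=3, hidden=2; the last position is padding with arbitrary values
tokens = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
mask = torch.tensor([[1, 1, 0]])

pooled = masked_mean_pool(tokens, mask)       # tensor([[2., 3.]]): the padding row is ignored
embedding = F.normalize(pooled, p=2, dim=1)   # unit length, as in forward()
print(pooled, embedding.norm(dim=1))
```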

Now we want to add LoRA layers to the model with peft for training. Before that, we must load the quantized version of the model. Because our custom class inherits from LlamaPreTrainedModel, we can load it in 4-bit precision by passing a BitsAndBytesConfig to from_pretrained. The loading procedure is as follows:

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = Llama2ModelForSentenceEmbedding.from_pretrained("meta-llama/Llama-2-7b-chat-hf",
                                                        device_map="auto",
                                                        torch_dtype=torch.bfloat16,
                                                        quantization_config=bnb_config)

In this code snippet, BitsAndBytesConfig loads the weights in 4-bit NF4 precision with double quantization, while computations run in bfloat16. Next, we configure the LoRA layers: the quantized base weights stay frozen, and only the small LoRA matrices attached to the q_proj and v_proj projections are trained.

# freeze
# model = prepare_model_for_kbit_training(
#     model, use_gradient_checkpointing=True
# )
# lora config
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.FEATURE_EXTRACTION,
    target_modules=["q_proj", "v_proj"],
)
# model = get_peft_model(model, peft_config)
# model.print_trainable_parameters()

Loading Llama 2 7B in 4-bit precision takes roughly 4 GB of GPU memory. The LoRA configuration above adds 4,194,304 trainable parameters, about 0.06 percent of the total model size, so we get the memory savings of quantization while still being able to fine-tune through the LoRA layers. NOTE: The code that freezes the model and adds the LoRA layers is commented out here because the trainer does this automatically when it receives peft_config.
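These figures can be checked with a back-of-the-envelope calculation. The sketch assumes the published Llama 2 7B dimensions (hidden size 4096, intermediate size 11008, 32 layers, vocabulary 32000), that a rank-r LoRA adapter on a d_in x d_out projection adds r * (d_in + d_out) parameters, and that only the linear layers are stored in 4 bits while embeddings and norms stay in bfloat16. Quantization constants, activations, and CUDA overhead are ignored.

```python
# Llama 2 7B dimensions (assumed from the model's config)
hidden, intermediate, layers, vocab = 4096, 11008, 32, 32000

# Parameters of LlamaModel (no LM head, which the embedding model does not use)
attn_linear = 4 * hidden * hidden               # q_proj, k_proj, v_proj, o_proj
mlp_linear = 3 * hidden * intermediate          # gate_proj, up_proj, down_proj
linear_params = layers * (attn_linear + mlp_linear)
other_params = vocab * hidden + layers * 2 * hidden + hidden   # embeddings + RMSNorm weights
total_params = linear_params + other_params

# LoRA with r=8 on q_proj and v_proj (4096 x 4096 each) in every layer
r = 8
lora_params = layers * 2 * r * (hidden + hidden)

# 4-bit linear layers (0.5 byte/param) + bfloat16 for everything else (2 bytes/param)
approx_bytes = linear_params * 0.5 + other_params * 2

print(f"total params:  {total_params:,}")
print(f"LoRA params:   {lora_params:,} ({100 * lora_params / total_params:.3f}%)")
print(f"4-bit weights: ~{approx_bytes / 1e9:.2f} GB")
```

The roughly 3.5 GB of weights, plus quantization constants and runtime overhead, is in line with the ~4 GB figure above.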

Trainer

In this section, we use the DPOTrainer from the trl library to train Llama2ModelForSentenceEmbedding. The objective is to embed chosen inputs closer to the query than rejected inputs, as measured by cosine similarity. To plug this into DPO, we need a probability model \(p(x \mid \text{query})\), where \(x\) is a text (chosen or rejected). We define it as follows:

\[ p(x | \text{query}) = \max(\text{cosine}(M[x], M[\text{query}]), 0) \]

Here, \(M\) is the Llama2ModelForSentenceEmbedding model, which provides the embeddings, and negative similarities are clamped to zero. Strictly speaking, this is a score in \([0, 1]\) rather than a normalized distribution over items, but it plays the role of the policy probability in the DPO loss.
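For reference, this is how the probability model enters the standard DPO loss (the sigmoid form that DPOTrainer uses by default). Write \(M_\theta\) for the model with trainable LoRA layers, \(M_{\text{ref}}\) for the frozen reference model, \(c\) and \(r\) for the chosen and rejected texts, \(\sigma\) for the logistic sigmoid, and \(\beta\) for the beta hyperparameter set below:

\[ p_\theta(x \mid \text{query}) = \max(\text{cosine}(M_\theta[x], M_\theta[\text{query}]), 0), \qquad p_{\text{ref}}(x \mid \text{query}) = \max(\text{cosine}(M_{\text{ref}}[x], M_{\text{ref}}[\text{query}]), 0) \]

\[ \mathcal{L}_{\text{DPO}} = -\log \sigma\left(\beta \left[\log \frac{p_\theta(c \mid \text{query})}{p_{\text{ref}}(c \mid \text{query})} - \log \frac{p_\theta(r \mid \text{query})}{p_{\text{ref}}(r \mid \text{query})}\right]\right) \]

Minimizing this loss raises the chosen text's cosine similarity to the query relative to the reference model and lowers the rejected text's, with \(\beta\) scaling the margin inside the sigmoid.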

We also need to define the TrainingArguments. The arguments below work; note, however, that remove_unused_columns=False and label_names=['labels'] are required. Because our model's forward only accepts **kwargs, the Trainer would otherwise discard the query_*, chosen_*, and rejected_* columns as unused, and label_names=['labels'] is needed to satisfy some of the Trainer's input checks.

# training_args
training_args = TrainingArguments(
    output_dir="peft_adapter_weight_path",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    per_device_eval_batch_size=4,
    learning_rate=1e-4,
    num_train_epochs=1,
    weight_decay=0.05,
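    save_strategy="epoch",
    save_steps=500,  # ignored when save_strategy="epoch"
    logging_strategy="steps",
    logging_steps=100,
    evaluation_strategy="steps",  # renamed to eval_strategy in newer transformers releases
    eval_steps=3000,
    lr_scheduler_type="linear",
    warmup_steps=100,
    optim="paged_adamw_8bit",
    bf16=True,
    remove_unused_columns=False,  # keep the prefixed columns; forward() only takes **kwargs
    label_names=['labels'],
    save_total_limit=1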
)
# DPO args
beta = 0.1

DPOTrainer computes the log-probabilities of the chosen and rejected responses in its concatenated_forward method. To use the cosine-based probability model described above instead, we override concatenated_forward so that Llama2ModelForSentenceEmbedding embeds the query, chosen, and rejected texts and the clamped cosine similarities serve as probabilities. The resulting DPORankerTrainer looks like this:

class DPORankerTrainer(DPOTrainer):
    @staticmethod
    def get_cosine_scores(query_embs, product_embs):
        # Embeddings are L2-normalized, so the row-wise dot product is the cosine similarity
        cosine_score = torch.sum(query_embs * product_embs, dim=1)
        # Clamp negatives as in p(x | query); a tiny epsilon instead of 0 keeps log() finite
        return torch.clamp(cosine_score, min=1e-8)

    def concatenated_forward(
            self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        # Strip the column prefix so each sub-batch matches the model's forward kwargs
        prompt_embs = model(**{k.replace("query_", ""): v for k, v in batch.items() if k.startswith("query_")})
        chosen_embs = model(**{k.replace("chosen_", ""): v for k, v in batch.items() if k.startswith("chosen_")})
        rejected_embs = model(**{k.replace("rejected_", ""): v for k, v in batch.items() if k.startswith("rejected_")})

        chosen_logits = self.get_cosine_scores(prompt_embs, chosen_embs)
        chosen_logps = chosen_logits.log()

        rejected_logits = self.get_cosine_scores(prompt_embs, rejected_embs)
        rejected_logps = rejected_logits.log()

        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
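To see what DPOTrainer does with the values returned by concatenated_forward, here is a minimal re-implementation of the standard sigmoid DPO loss on random unit vectors. It is a sketch, not TRL's code: cosine_logps and dpo_sigmoid_loss are hypothetical helpers, and the reference log-probabilities stand in for those computed by the frozen reference model.

```python
import torch
import torch.nn.functional as F

def cosine_logps(query_embs, item_embs, eps=1e-8):
    # Embeddings are assumed L2-normalized, so the row-wise dot product is the cosine
    cos = (query_embs * item_embs).sum(dim=1)
    return cos.clamp(min=eps).log()

def dpo_sigmoid_loss(pi_chosen, pi_rejected, ref_chosen, ref_rejected, beta=0.1):
    # -log sigmoid(beta * (policy log-ratio - reference log-ratio)), averaged over the batch
    logits = (pi_chosen - pi_rejected) - (ref_chosen - ref_rejected)
    return -F.logsigmoid(beta * logits).mean()

torch.manual_seed(0)
q_emb, c_emb, r_emb = (F.normalize(torch.randn(4, 16), dim=1) for _ in range(3))
ref_c, ref_r = cosine_logps(q_emb, c_emb), cosine_logps(q_emb, r_emb)

# Policy identical to the reference: the logits are 0 and the loss is log(2)
loss_start = dpo_sigmoid_loss(ref_c, ref_r, ref_c, ref_r)

# Moving chosen items toward the query and rejected items away from it lowers the loss
c_closer = F.normalize(c_emb + q_emb, dim=1)
r_farther = F.normalize(r_emb - q_emb, dim=1)
loss_after = dpo_sigmoid_loss(cosine_logps(q_emb, c_closer), cosine_logps(q_emb, r_farther), ref_c, ref_r)
print(loss_start.item(), loss_after.item())
```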

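Finally, we create an instance of DPORankerTrainer and pass it the training_args, the tokenized train and validation splits, and peft_config. We also need a data collator; since all sequences are already padded to the same length, the default_data_collator from the Transformers library is sufficient. Because no ref_model is passed and the model is wrapped with LoRA adapters, the trainer computes the reference log-probabilities with the adapters disabled, i.e. with the original Llama 2 weights.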

dpo_trainer = DPORankerTrainer(
    model,
    args=training_args,
    beta=beta,
    data_collator=default_data_collator,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    tokenizer=tokenizer,
    peft_config=peft_config
)

# call train
dpo_trainer.train()

# save peft adapter model
dpo_trainer.save_model("peft_adapter_weight_path")

Merging LoRA Layers

So far, we have trained the quantized Llama2ModelForSentenceEmbedding, which produces embeddings based on Llama 2 7B. The output of training is the LoRA adapter weights, saved with dpo_trainer.save_model("peft_adapter_weight_path"). To use the adapter conveniently, we merge the LoRA layers into the base model to obtain a single model. Merging into quantized weights is not suitable, so we load the whole Llama 2 model unquantized; in float16 the weights alone take about 13 GB, which is why the merge is carried out entirely on the CPU.

For further discussion of the challenges and strategies involved in merging LoRA layers, see the GitHub issue discussions on merging LoRA layers.

model = Llama2ModelForSentenceEmbedding.from_pretrained("meta-llama/Llama-2-7b-chat-hf",
                                                        device_map={"": "cpu"},
                                                        torch_dtype=torch.float16)

model = PeftModel.from_pretrained(model,
                                  "peft_adapter_weight_path",
                                  torch_dtype=torch.float16,
                                  device_map={"": "cpu"})

model.eval()
model = model.merge_and_unload()
model.save_pretrained("merged_model_path")
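Conceptually, merge_and_unload folds each adapter into its base weight, W' = W + (lora_alpha / r) * B A, after which the LoRA branch can be removed. The torch sketch below checks this identity on a hypothetical small layer with the r=8, lora_alpha=16 settings used above; it illustrates the math only, not PEFT's implementation, which also handles dropout, quantization, and other details.

```python
import torch

torch.manual_seed(0)
d_in, d_out, r, lora_alpha = 64, 64, 8, 16
scaling = lora_alpha / r                       # 2.0 for r=8, lora_alpha=16

W = torch.randn(d_out, d_in)                   # frozen base weight (nn.Linear layout)
A = torch.randn(r, d_in) * 0.01                # LoRA down-projection
B = torch.randn(d_out, r) * 0.01               # LoRA up-projection (non-zero after training)
x = torch.randn(5, d_in)

# Adapter form used during training: base output plus the scaled low-rank update
y_adapter = x @ W.T + scaling * (x @ A.T @ B.T)

# Merged form: a single weight matrix, no extra computation at inference time
W_merged = W + scaling * (B @ A)
y_merged = x @ W_merged.T
print((y_adapter - y_merged).abs().max().item())
```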

In summary, we split the challenge of working with Llama 2, a large model that does not fit on a standard GPU at full precision, into three parts:

  • Training: Training Llama 2 directly in pure 8-bit or 4-bit precision is not feasible. With Parameter-Efficient Fine-Tuning (PEFT), we instead train small adapters on top of the frozen quantized model; prepare_model_for_kbit_training prepares the quantized model for this.

  • Merging: We merge the trained LoRA layers into the unquantized Llama 2 weights, which can be done on the CPU, as described in the previous section.

  • Inference: At inference time, we simply load the model from the merged-model path, in 8-bit, 4-bit, or any other precision we need. We discuss this in the next section.

Inference

We can now load the merged model from merged_model_path with whatever loading options suit our needs. For example, to load it in 4-bit precision with FlashAttention-2 enabled (this requires the flash-attn package):

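model = Llama2ModelForSentenceEmbedding.from_pretrained("merged_model_path",
                                                        device_map="auto",
                                                        quantization_config=BitsAndBytesConfig(load_in_4bit=True,
                                                                                               bnb_4bit_compute_dtype=torch.bfloat16),
                                                        torch_dtype=torch.bfloat16,
                                                        use_flash_attention_2=True)  # newer transformers: attn_implementation="flash_attention_2"
model.eval()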

The following code removes duplicate products from the training split based on the product_id column. It keeps only the product_id and product_title columns, converts the dataset to a pandas DataFrame, drops duplicates, and converts the result back to a Dataset with one entry per unique product.

# convert dataset to pandas to remove duplicates and then back to dataset
dataset_pd = dataset['train'].select_columns(['product_id', 'product_title']).to_pandas()
dataset_pd.drop_duplicates(subset='product_id', inplace=True)
unique_dataset = Dataset.from_pandas(dataset_pd, preserve_index=False)

Next, we implement embedding_function, which adds an embeddings column to the dataset. It tokenizes each product_title and embeds it with our fine-tuned, merged model loaded in 4-bit precision.

def embedding_function(examples):
    tokenized_examples = tokenizer(examples["product_title"], padding="max_length", max_length=50,
                                   truncation=True, return_tensors="pt").to(model.device)  # inputs must live on the model's device
    with torch.no_grad():
        with torch.amp.autocast(dtype=torch.bfloat16, device_type='cuda'):
            embeddings = model(**tokenized_examples).float().cpu()
    examples['embeddings'] = embeddings.numpy()
    return examples

unique_dataset_with_embeddings = unique_dataset.map(
    embedding_function,
    batch_size=16,
    batched=True
)

unique_dataset_with_embeddings now contains one entry per unique product, each with an embeddings column holding the embedding of its product_title.

  • add_faiss_index() is a method used to add a Faiss index to a specific column of the dataset. Faiss is a library for efficient similarity search and clustering of dense vectors. In this case, the 'embeddings' column, which contains numerical vectors representing the embeddings of 'product_title' entries, is chosen as the target column for creating the Faiss index.

  • The purpose of the index is to make similarity search fast: given a query vector, Faiss can efficiently find its nearest neighbors among a large set of high-dimensional vectors. Called with only a column name, add_faiss_index builds an exact flat index, which by default ranks by Euclidean (L2) distance; since our embeddings are L2-normalized, this gives the same ranking as cosine similarity.

unique_dataset_with_embeddings.add_faiss_index(column='embeddings')
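Because every embedding is L2-normalized, Euclidean distance and cosine similarity rank items identically: for unit vectors, ||a - b||^2 = 2 - 2 * cos(a, b). The numpy sketch below checks this on hypothetical random unit vectors; it explains why an index that ranks by Euclidean distance returns the same products as a cosine-similarity search, and does not call Faiss itself.

```python
import numpy as np

rng = np.random.default_rng(0)

def unit(x):
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

# Hypothetical unit-length embeddings: 100 "products" and one "query", 8 dimensions
items = unit(rng.normal(size=(100, 8)))
query = unit(rng.normal(size=8))

cosine = items @ query                          # higher = more similar
sq_l2 = ((items - query) ** 2).sum(axis=1)      # lower = more similar

# For unit vectors ||a - b||^2 = 2 - 2 cos(a, b), so both top-10 lists agree
top_by_cosine = np.argsort(-cosine)[:10]
top_by_l2 = np.argsort(sq_l2)[:10]
print(top_by_cosine, top_by_l2)
```

If you would rather have the index return cosine similarities directly as scores, Faiss also supports an inner-product metric (faiss.METRIC_INNER_PRODUCT), which add_faiss_index accepts through its metric_type argument.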


Now, to find the most similar items for a query, we embed the query with the same model and search unique_dataset_with_embeddings for the nearest product embeddings. Since all embeddings are L2-normalized, the nearest neighbors are the products with the highest cosine similarity to the query.

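query = "Tires of different cars?"

tokenized_query = tokenizer(query, padding="max_length", max_length=50, truncation=True,
                            return_tensors="pt").to(model.device)
with torch.no_grad():
    with torch.amp.autocast(dtype=torch.bfloat16, device_type='cuda'):
        query_embedding = model(**tokenized_query).float().cpu()[0].numpy()

# scores are distances from the index (lower = closer with an L2 index); retrieved holds the matching rows
scores, retrieved = unique_dataset_with_embeddings.get_nearest_examples('embeddings', query_embedding, k=10)
for score, title in zip(scores, retrieved['product_title']):
    print(f"{score:.4f}  {title}")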

The result is a list of product titles related to the query, ranked by the semantic similarity learned through DPO (Direct Preference Optimization) training on our dataset. The titles whose embeddings are closest to the query's embedding, i.e. those with the highest cosine similarity, are presented as the most relevant items.