Fine-Tuning LLaVA on a Custom Dataset


LLaVA, a multimodal language model introduced in April 2023, combines language and visual data in a single system. Unlike conventional models that specialize in either text or image processing, LLaVA handles both. This lets the model relate the visual components of an image to the text that describes or asks about them, which makes for more nuanced, context-aware AI interactions.

This tutorial will guide you through the process of fine-tuning this versatile model on your customized dataset!

LLaVa. Source


What Is LLaVA and How It Works

LLaVA’s architecture leverages the strengths of pre-trained language models (like Vicuna or LLaMA) and visual models (like CLIP’s visual encoder). This integration requires aligning the extracted visual features from images with the language model’s embeddings.
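To make the alignment step concrete, here is a dependency-free toy sketch. It assumes the projector is a two-layer MLP (linear, GELU, linear), which is what the `mlp2x_gelu` option in the training command later in this tutorial suggests. The dimensions, random weights, and function names are all made up for illustration; a real implementation would use PyTorch tensors and the actual CLIP and language-model sizes.

```python
import math
import random


def gelu(x: float) -> float:
    """Exact GELU activation: x * Phi(x)."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def make_linear(in_dim: int, out_dim: int, rng: random.Random):
    """Return (weight, bias); weight has shape out_dim x in_dim."""
    weight = [[rng.uniform(-0.1, 0.1) for _ in range(in_dim)] for _ in range(out_dim)]
    bias = [0.0] * out_dim
    return weight, bias


def linear(vec, weight, bias):
    """Apply y = W x + b to one feature vector."""
    return [sum(w * x for w, x in zip(row, vec)) + b for row, b in zip(weight, bias)]


def project_patches(patch_features, layer1, layer2):
    """Two-layer MLP (linear -> GELU -> linear) applied to every image patch."""
    projected = []
    for vec in patch_features:
        hidden = [gelu(h) for h in linear(vec, *layer1)]
        projected.append(linear(hidden, *layer2))
    return projected


rng = random.Random(0)
VISION_DIM, TEXT_DIM, NUM_PATCHES, NUM_TEXT_TOKENS = 6, 8, 4, 3  # toy sizes (assumptions)

# Stand-ins for the vision encoder output and the language model's token embeddings
patch_features = [[rng.gauss(0, 1) for _ in range(VISION_DIM)] for _ in range(NUM_PATCHES)]
text_embeddings = [[rng.gauss(0, 1) for _ in range(TEXT_DIM)] for _ in range(NUM_TEXT_TOKENS)]

layer1 = make_linear(VISION_DIM, TEXT_DIM, rng)
layer2 = make_linear(TEXT_DIM, TEXT_DIM, rng)
image_tokens = project_patches(patch_features, layer1, layer2)

# After projection, image tokens and text tokens have the same width
# and can be placed in one input sequence for the language model
sequence = image_tokens + text_embeddings
print(len(sequence), len(sequence[0]))
```

The point of the sketch is only the shapes: the projector changes the width of each patch feature from the vision encoder's dimension to the language model's embedding dimension, so both kinds of token can sit in the same sequence.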

The LLaVA authors also introduced a novel approach called visual instruction tuning for multimodal AI. They utilize the text-based model GPT-4 to generate instruction-following data that pairs language with images. This innovative method transforms image-text pairs into formats suitable for instruction-following tasks, effectively bridging the gap between visual data and language processing.

Essentially, they take existing image-text datasets and prompt GPT-4 to elaborate on the existing captions. GPT-4 generates questions and detailed descriptions from those captions, turning simple labels into a more complex, instruction-rich dataset.

This expansion goes beyond simply increasing the text volume; it enhances the information’s quality and depth. The language model delves deeper into each image by asking relevant questions and providing detailed descriptions, resulting in a richer dataset for training a multimodal AI model capable of nuanced understanding and response.

> LLaVA architecture (source)

Fine-Tuning LLaVA

The first step in the fine-tuning process is to select a dataset.
In this tutorial, we will use the OK-VQA dataset. It consists of image-text pairs designed to test a model’s ability to reason about and answer questions on the contents of an image. Unlike datasets that merely ask for image descriptions, OK-VQA poses specific questions that demand a deeper understanding of the image.

Dataset Preparation

Let’s load the dataset and prepare it for fine-tuning. The official LLaVA repository expects training data in a specific JSON format, so we will write a Python script that converts OK-VQA into it. Each record holds an ID, an image file name, and a conversation made of “human” and “gpt” turns:

[
    {
        "id": "unique_id",
        "image": "image_file.jpg",
        "conversations": [
            {
                "from": "human",
                "value": "What is shown in the image?"
            },
            {
                "from": "gpt",
                "value": "formatted_answers"
            }
        ]
    }
]
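Before training, it can help to check that every record really has this shape. The sketch below is a hypothetical validator, not part of the LLaVA repository; it assumes that turns alternate between "human" and "gpt", as in the example above.

```python
import json


def validate_record(record: dict) -> list:
    """Return a list of problems found in one record (an empty list means it looks fine)."""
    problems = []
    for key in ("id", "image", "conversations"):
        if key not in record:
            problems.append(f"missing key: {key}")
    turns = record.get("conversations", [])
    if not isinstance(turns, list) or len(turns) < 2:
        problems.append("conversations must hold at least one human/gpt pair")
        return problems
    for index, turn in enumerate(turns):
        # Assumption: even positions are questions, odd positions are answers
        expected = "human" if index % 2 == 0 else "gpt"
        if not isinstance(turn, dict) or turn.get("from") != expected:
            problems.append(f"turn {index} should come from {expected!r}")
        elif not isinstance(turn.get("value"), str) or not turn["value"].strip():
            problems.append(f"turn {index} has an empty value")
    return problems


sample = json.loads("""
[
    {
        "id": "unique_id",
        "image": "image_file.jpg",
        "conversations": [
            {"from": "human", "value": "What is shown in the image?"},
            {"from": "gpt", "value": "formatted_answers"}
        ]
    }
]
""")

for record in sample:
    print(record["id"], validate_record(record) or "ok")
```

Running the same function over a generated dataset.json is a cheap way to catch empty answers or missing fields before a long training run starts.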

That said, let’s start loading the dataset:

from datasets import load_dataset
from PIL import Image
from io import BytesIO
import requests
import os
import json
import uuid


def process_and_save(dataset, output_folder, subset_name):
    # Images from every subset share one folder; each subset gets its own JSON file
    subset_folder = os.path.join(output_folder, subset_name)
    image_subfolder = os.path.join(output_folder, 'images')
    os.makedirs(image_subfolder, exist_ok=True)
    os.makedirs(subset_folder, exist_ok=True)

    # Initialize list to hold all JSON data
    json_data_list = []

    # Process and save images and labels
    for item in dataset:
        # Download the image if it is given as a URL; otherwise assume a PIL.Image object
        if isinstance(item['image'], str):
            response = requests.get(item['image'], timeout=30)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
        else:
            image = item['image']

        # Create a unique ID for each image
        unique_id = str(uuid.uuid4())

        # Save the image as JPEG (JPEG cannot store alpha or palette modes, so convert to RGB)
        image_path = os.path.join(image_subfolder, f"{unique_id}.jpg")
        image.convert('RGB').save(image_path)

        # Remove duplicate answers while keeping their original order
        unique_answers = list(dict.fromkeys(item['answers']))
        formatted_answers = ", ".join(unique_answers)

        # Structure for LLaVA JSON
        json_data = {
            "id": unique_id,
            "image": f"{unique_id}.jpg",
            "conversations": [
                {
                    "from": "human",
                    "value": item['question']
                },
                {
                    "from": "gpt",
                    "value": formatted_answers
                }
            ]
        }

        # Append to list
        json_data_list.append(json_data)

    # Save the JSON data list to a file
    json_output_path = os.path.join(subset_folder, 'dataset.json')
    with open(json_output_path, 'w', encoding='utf-8') as json_file:
        json.dump(json_data_list, json_file, indent=4)


def save_dataset(dataset_name, output_folder, class_name, subset_name, val_samples=None):
    # Load the dataset from Hugging Face
    dataset = load_dataset(dataset_name, split=subset_name)

    # Keep only the items whose 'question_type' matches the requested class
    filtered_dataset = [item for item in dataset if item['question_type'] == class_name]

    # Hold out the first val_samples training items for validation.
    # Any other split is saved under its own name so it cannot overwrite 'train'.
    if val_samples is not None and subset_name == 'train':
        splits = [('train', filtered_dataset[val_samples:]),
                  ('validation', filtered_dataset[:val_samples])]
    else:
        splits = [(subset_name, filtered_dataset)]

    # Process and save the datasets
    for subset, data in splits:
        if data:
            process_and_save(data, output_folder, subset)


# Usage example
output_folder = 'dataset'
class_name = 'other'
val_samples = 300
save_dataset('Multimodal-F/OK-VQA_train', output_folder, class_name, 'train', val_samples)
save_dataset('Multimodal-F/OK-VQA_test', output_folder, class_name, 'test')

The script iterates through the dataset, processing each image and its corresponding question. Images are saved locally under unique identifiers, while questions and answers are consolidated into one JSON file per subset. In each conversation, the turn whose “from” field is “human” carries the question, and the turn whose “from” field is “gpt” carries the answer the model is trained to produce. This matches the input format the LLaVA training code expects.
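One detail worth checking: the examples in the LLaVA repository's custom-data guide start the first human turn with an `<image>` placeholder that marks where the image features go, and the conversion script does not add one. If your copy of the training code relies on that placeholder, a small post-processing pass could insert it. The helper below is a hypothetical sketch, and the token string is an assumption you should verify against the training code you use.

```python
import copy

IMAGE_TOKEN = "<image>"  # assumed to match the image placeholder used by the training code


def ensure_image_token(record: dict) -> dict:
    """Return a copy of the record whose first human turn contains the image placeholder."""
    updated = copy.deepcopy(record)
    first_turn = updated["conversations"][0]
    if IMAGE_TOKEN not in first_turn["value"]:
        first_turn["value"] = f"{IMAGE_TOKEN}\n{first_turn['value']}"
    return updated


record = {
    "id": "unique_id",
    "image": "image_file.jpg",
    "conversations": [
        {"from": "human", "value": "What sport is this?"},
        {"from": "gpt", "value": "tennis"},
    ],
}
fixed = ensure_image_token(record)
print(repr(fixed["conversations"][0]["value"]))
```

The function works on a copy and skips records that already contain the placeholder, so it is safe to run more than once over the same file.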

It’s important to note that this tutorial will not delve into the instruction tuning process outlined in the original paper. Our primary focus will be on training the model for single-response “complex reasoning” tasks based on an image and a query.

Training

Having formatted the dataset, we can proceed to training. We’ll use the original LLaVA repository as a foundation. Note, however, that the original repository has no built-in support for evaluating on a validation set between training epochs, which is valuable for detecting overfitting. The --validation_data_path argument in the training command below is therefore assumed to come from a version of the training code that has been extended with such evaluation.

To download the weights, you can use the following commands:

git lfs install
git clone https://huggingface.co/liuhaotian/llava-v1.5-7b

Using QLoRA for Better Training Efficiency

Training large language models (LLMs) involves a trade-off between computational cost and model performance: larger models are more capable but need far more compute and memory to train, while smaller models are cheaper but may perform worse. QLoRA helps bridge this gap. It keeps the pre-trained weights frozen and quantized to 4 bits, and trains only small low-rank adapter (LoRA) matrices on top of them, which sharply reduces the memory needed for fine-tuning. In the training command, this corresponds to the --lora_enable, --lora_r, --lora_alpha, and --bits 4 options.
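A little arithmetic shows where the savings come from. The sketch below is a back-of-the-envelope estimate, not a measurement: it assumes a hidden size of 4096 for a 7B LLaMA-style model, uses the rank of 128 from the training command, and ignores quantization overhead, activations, and optimizer state.

```python
def lora_trainable_params(d_out: int, d_in: int, rank: int) -> int:
    """Parameters in the two low-rank factors B (d_out x rank) and A (rank x d_in)."""
    return rank * (d_out + d_in)


def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Approximate storage for the weights alone, in gigabytes (10**9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


hidden = 4096  # assumed hidden size of a 7B LLaMA-style model
rank = 128     # value of --lora_r in the training command

full = hidden * hidden
adapter = lora_trainable_params(hidden, hidden, rank)
print(f"full matrix: {full:,} params, LoRA factors: {adapter:,} params ({adapter / full:.2%})")

for bits in (16, 4):
    print(f"7e9 weights at {bits}-bit: ~{weight_memory_gb(7e9, bits):.1f} GB")
```

Under these assumptions, the adapter for one square projection matrix has about 6% of that matrix's parameters, and storing the frozen base weights in 4 bits rather than 16 cuts their footprint to roughly a quarter.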

Start the training process

We won’t go through the training script line by line, but the command used to run it is worth showing in full. It’s generally a good idea to put such a command in a Bash script (a file with a .sh extension) rather than typing it into the terminal directly. This makes it easier to try different hyperparameters and helps avoid syntax errors.

Here’s the command used to launch the training script, train_mem.py:

#!/bin/bash


# Set the prompt and model versions directly in the command
deepspeed /root/LLaVA/llava/train/train_mem.py \
    --deepspeed /root/LLaVA/scripts/zero2.json \
    --lora_enable True \
    --lora_r 128 \
    --lora_alpha 256 \
    --mm_projector_lr 2e-5 \
    --bits 4 \
    --model_name_or_path /root/LLaVA/llava/llava-v1.5-7b \
    --version llava_llama_2 \
    --data_path /root/dataset/train/dataset.json \
    --validation_data_path /root/dataset/validation/dataset.json \
    --image_folder /root/dataset/images/ \
    --vision_tower openai/clip-vit-large-patch14-336 \
    --mm_projector_type mlp2x_gelu \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --image_aspect_ratio pad \
    --group_by_modality_length True \
    --bf16 True \
    --output_dir /root/LLaVA/llava/checkpoints/llama-2-7b-chat-task-qlora \
    --num_train_epochs 500 \
    --per_device_train_batch_size 32 \
    --per_device_eval_batch_size 32 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "epoch" \
    --save_strategy "steps" \
    --save_steps 50000 \
    --save_total_limit 1 \
    --learning_rate 2e-4 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --lazy_preprocess True

Once you’ve copied the command above into a file named “run.sh”, start training by running sh -x run.sh. The -x flag makes the shell print each command before executing it, which helps when debugging the script.

Adjusting the batch size within the training script might be necessary depending on your hardware configuration to avoid encountering memory errors.
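If you do lower the per-device batch size, you may want to raise gradient accumulation so the effective batch size stays the same. The sketch below shows the arithmetic; the helper names are hypothetical and a single GPU is assumed, so adjust `num_gpus` to your setup.

```python
import math


def effective_batch_size(per_device: int, grad_accum: int, num_gpus: int) -> int:
    """Number of samples that contribute to one optimizer step."""
    return per_device * grad_accum * num_gpus


def accumulation_steps_for(target: int, per_device: int, num_gpus: int) -> int:
    """Smallest accumulation count that reaches at least the target effective batch size."""
    return math.ceil(target / (per_device * num_gpus))


# Batch size and accumulation taken from the training command; num_gpus=1 is an assumption
target = effective_batch_size(per_device=32, grad_accum=1, num_gpus=1)

for per_device in (32, 16, 8, 4):
    steps = accumulation_steps_for(target, per_device, num_gpus=1)
    print(f"--per_device_train_batch_size {per_device} --gradient_accumulation_steps {steps}")
```

Halving the per-device batch size while doubling the accumulation steps trades speed for memory without changing how many samples feed each update.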

You can stop the training run early, after some epochs, once the losses on both the training and validation sets have stagnated.
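One way to make "stagnation" concrete is a simple patience rule: stop when the last few epochs have not improved on the best earlier loss by some minimum amount. The sketch below is one possible rule, not something the LLaVA training code does for you; the function name, the thresholds, and the loss values are all made up for illustration.

```python
def has_plateaued(losses, patience=3, min_delta=0.01):
    """True if the best loss of the last `patience` epochs failed to beat
    the best earlier loss by at least `min_delta`."""
    if len(losses) <= patience:
        return False
    best_before = min(losses[:-patience])
    best_recent = min(losses[-patience:])
    return best_before - best_recent < min_delta


# Made-up per-epoch validation losses, for illustration only
val_losses = [1.20, 0.95, 0.81, 0.74, 0.735, 0.738, 0.736]

for epoch in range(1, len(val_losses) + 1):
    if has_plateaued(val_losses[:epoch]):
        print(f"validation loss has plateaued by epoch {epoch}")
        break
```

Applying the same check to the training loss as well matches the advice above to wait until both curves have flattened.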

Wrapping-Up

This tutorial has explored the process of fine-tuning the LLaVA model, highlighting a significant advancement in the integration of language and visual data through AI. LLaVA’s novel technique of visual instruction tuning using GPT-4 not only bridges the gap between understanding text and visuals but also pushes the boundaries of multimodal AI capabilities. I hope you found this tutorial insightful. Feel free to leave any questions in the comments below!
