Generative AI is in the midst of a period of stunning growth. Increasingly capable foundation models are being released continuously, with large language models (LLMs) being one of the most visible model classes. LLMs are models composed of billions of parameters trained on extensive corpora of text, up to hundreds of billions or even a trillion tokens. These models have proven extremely effective for a wide range of text-based tasks, from question answering to sentiment analysis.
The power of LLMs comes from their capacity to learn and generalize from extensive and diverse training data. The initial training of these models is performed with a variety of objectives: supervised, unsupervised, or hybrid. Text completion or imputation is one of the most common unsupervised objectives: given a chunk of text, the model learns to accurately predict what comes next (for example, the next sentence). Models can also be trained in a supervised fashion using labeled data to accomplish a set of tasks (for example, deciding whether a movie review is positive, negative, or neutral). Whether the model is trained for text completion or some other task, that task is frequently not the one customers want to use the model for.
To improve the performance of a pre-trained LLM on a specific task, we can tune the model using examples of the target task in a process known as instruction fine-tuning. Instruction fine-tuning uses a set of labeled examples in the form of {prompt, response} pairs to further train the pre-trained model to predict the response given the prompt. This process modifies the weights of the model.
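As a rough illustration of what {prompt, response} pairs can look like, the following sketch turns labeled movie reviews into such pairs. The template wording and the names TEMPLATE, make_pair, and make_pairs are hypothetical and chosen only for this example; they are not the format required by any particular model or library.

```python
from typing import Dict, List, Tuple

# Hypothetical instruction template; the wording is an assumption for illustration.
TEMPLATE = (
    "Classify the sentiment of this movie review as positive, negative, or neutral.\n"
    "Review: {text}"
)


def make_pair(text: str, label: str) -> Dict[str, str]:
    """Turn one labeled example into a {prompt, response} pair."""
    return {"prompt": TEMPLATE.format(text=text.strip()), "response": label}


def make_pairs(examples: List[Tuple[str, str]]) -> List[Dict[str, str]]:
    """Format a list of (text, label) tuples as instruction fine-tuning pairs."""
    return [make_pair(text, label) for text, label in examples]


pairs = make_pairs(
    [
        ("A moving, beautifully shot film.", "positive"),
        ("Two hours I will never get back.", "negative"),
    ]
)
for pair in pairs:
    print(pair)
```

During fine-tuning, the model sees the prompt as input and is trained to produce the response as output.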
This post describes how to perform instruction fine-tuning of an LLM, namely FLAN T5 XL, using Amazon SageMaker Jumpstart. We demonstrate how to accomplish this using both the Jumpstart UI and a notebook in Amazon SageMaker Studio. You can find the accompanying notebook in the amazon-sagemaker-examples GitHub repository.
Solution overview
The target task in this post is as follows: given a chunk of text in the prompt, return questions that are related to the text but can’t be answered based on the information it contains. This task is useful for identifying missing information in a description or determining whether a query needs more information to be answered.
FLAN T5 models are instruction fine-tuned on a wide range of tasks to increase the zero-shot performance of these models on many common tasks[1]. Additional instruction fine-tuning for a particular customer task can further increase the accuracy of these models, especially if the target task wasn’t previously used to train a FLAN T5 model, as is the case for our task.
In our example task, we’re interested in generating relevant but unanswered questions. To this end, we use a subset of version 2 of the Stanford Question Answering Dataset (SQuAD2.0)[2] to fine-tune the model. This dataset contains questions posed by human annotators on a set of Wikipedia articles. In addition to questions with answers, SQuAD2.0 contains about 50,000 unanswerable questions. Such questions are plausible but can’t be directly answered from the articles’ content. We only use the unanswerable questions. Our data is structured as a JSON Lines file, with each line containing a context and a question.
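To make the data format concrete, the following sketch filters a SQuAD2.0-style dictionary down to its unanswerable questions and serializes them as JSON Lines. It relies on the published SQuAD2.0 layout (data, paragraphs, qas, is_impossible). The output key names context and question are an assumption based on the description above, and the sample record is invented for illustration; this is not necessarily how the post's dataset was prepared.

```python
import json
from typing import Dict, Iterable, Iterator


def unanswerable_records(squad: dict) -> Iterator[Dict[str, str]]:
    """Yield one {context, question} record per unanswerable SQuAD2.0 question."""
    for article in squad["data"]:
        for paragraph in article["paragraphs"]:
            for qa in paragraph["qas"]:
                # SQuAD2.0 flags unanswerable questions with is_impossible=True.
                if qa.get("is_impossible", False):
                    yield {
                        "context": paragraph["context"],
                        "question": qa["question"],
                    }


def to_jsonl(records: Iterable[Dict[str, str]]) -> str:
    """Serialize records as JSON Lines: one JSON object per line."""
    return "\n".join(json.dumps(record) for record in records)


# Tiny invented sample in the SQuAD2.0 layout.
sample = {
    "data": [
        {
            "title": "Example",
            "paragraphs": [
                {
                    "context": "The tower was completed in 1889.",
                    "qas": [
                        {
                            "question": "When was the tower completed?",
                            "is_impossible": False,
                            "answers": [{"text": "1889", "answer_start": 27}],
                        },
                        {
                            "question": "Who designed the tower?",
                            "is_impossible": True,
                            "answers": [],
                        },
                    ],
                }
            ],
        }
    ]
}

jsonl = to_jsonl(unanswerable_records(sample))
print(jsonl)
```

Only the second question survives the filter, because the context doesn't say who designed the tower.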
Prerequisites
To get started, all you need is an AWS account in which you can use Studio. You will need to create a user profile for Studio if you don’t already have one.
Fine-tune FLAN-T5 with the Jumpstart UI
To fine-tune the model with the Jumpstart UI, complete the following steps:
On the SageMaker console, open Studio.
Under SageMaker Jumpstart in the navigation pane, choose Models, notebooks, solutions.
You will see a list of foundation models, including FLAN T5 XL, which is marked as fine-tunable.
Choose View model.
Under Data source, you can provide the path to your training data. The source for the data used in this post is provided by default.
You can keep the default values for the deployment configuration (including instance type), security, and the hyperparameters, but you should increase the number of epochs to at least three to get good results.
Choose Train to train the model.
You can track the status of the training job in the UI.
When training is complete (after about 53 minutes in our case), choose Deploy to deploy the fine-tuned model.
After the endpoint is created (a few minutes), you can open a notebook and start using your fine-tuned model.
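Querying the deployed endpoint from a notebook could look roughly like the following sketch. The invoke_endpoint call and its EndpointName, ContentType, and Body parameters are part of the boto3 SageMaker runtime client. The prompt wording, the payload keys text_inputs and max_length, the JSON content type, and the helper names are assumptions for illustration; check the model's documentation for the exact request and response format.

```python
import json
from typing import Any

# Hypothetical prompt wording for the task; an assumption for illustration.
PROMPT_TEMPLATE = (
    "Ask a question which is related to the following text, "
    "but cannot be answered based on the text. Text: {context}"
)


def build_payload(context: str, max_length: int = 50) -> bytes:
    """Build a JSON request body. The key names are assumed, not confirmed."""
    prompt = PROMPT_TEMPLATE.format(context=context)
    body = {"text_inputs": prompt, "max_length": max_length}
    return json.dumps(body).encode("utf-8")


def query_endpoint(client: Any, endpoint_name: str, payload: bytes) -> Any:
    """Send the payload to a SageMaker endpoint and return the parsed JSON reply."""
    response = client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=payload,
    )
    # The response body is a stream; read it fully before parsing.
    return json.loads(response["Body"].read())


# Example usage (requires AWS credentials and a deployed endpoint):
#   import boto3
#   client = boto3.client("sagemaker-runtime")
#   result = query_endpoint(client, "<your-endpoint-name>", build_payload("Some text."))
#   print(result)
```

Passing the client in as an argument keeps the helper easy to exercise with a stub before pointing it at a real endpoint.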
Fine-tune FLAN-T5 using a Python notebook
Our example notebook shows how to use Jumpstart and SageMaker to programmatically fine-tune and deploy a FLAN T5 XL model. It can be run in Studio or locally.
In this section, we first walk through some general setup. Then you fine-tune the model using the SQuAD2.0 dataset. Next, you deploy the pre-trained version of the model behind a SageMaker endpoint, and do the same with the fine-tuned model. Finally, you query the endpoints and compare the output quality of the pre-trained and fine-tuned models. You will find that the output of the fine-tuned model is of much higher quality.
Set up prerequisites
Begin by installing and upgrading the necessary packages. Restart the kernel after running the following code:
!pip install nest-asyncio==1.5.5 --quiet
!pip install ipywidgets==8.0.4 --quiet
!pip install --upgrade sagemaker --quiet
Next, obtain the execution role associated with the current notebook instance:
import boto3
import sagemaker
# Get current region, role, and default bucket
aws_region = boto3.Session().region_name
aws_role = sagemaker.session.Session().get_caller_identity_arn()
output_bucket = sagemaker.Session().default_bucket()
# This will be useful for printing
newline, bold, unbold = "\n", "