Deploy and fine-tune foundation models in Amazon SageMaker JumpStart with two lines of code


We are excited to announce a simplified version of the Amazon SageMaker JumpStart SDK that makes it straightforward to build, train, and deploy foundation models. The code for prediction is also simplified. In this post, we demonstrate how to use the simplified SageMaker JumpStart SDK to get started with foundation models in just a couple of lines of code.

For more information about the simplified SageMaker JumpStart SDK for deployment and training, refer to Low-code deployment with the JumpStartModel class and Low-code fine-tuning with the JumpStartEstimator class, respectively.

Solution overview

SageMaker JumpStart provides pre-trained, open-source models for a wide range of problem types to help you get started with machine learning (ML). You can incrementally train and fine-tune these models before deployment. JumpStart also provides solution templates that set up infrastructure for common use cases, and executable example notebooks for ML with Amazon SageMaker. You can access the pre-trained models, solution templates, and examples through the SageMaker JumpStart landing page in Amazon SageMaker Studio or through the SageMaker Python SDK.

To demonstrate the new features of the SageMaker JumpStart SDK, we show how to use the pre-trained Flan T5 XL model from Hugging Face to generate text for summarization tasks. We also show how, in just a few lines of code, you can fine-tune the Flan T5 XL model for summarization. You can use any other text generation model, such as Llama 2, Falcon, or Mistral AI.

You can find the notebook for this solution using Flan T5 XL in the GitHub repo.

Deploy and invoke the model

Foundation models hosted on SageMaker JumpStart have model IDs. For the full list of model IDs, refer to Built-in Algorithms with pre-trained Model Table. For this post, we use the model ID of the Flan T5 XL text generation model. We instantiate the model object and deploy it to a SageMaker endpoint by calling its deploy method. See the following code:

from sagemaker.jumpstart.model import JumpStartModel

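# This ID selects the smaller Flan T5 Base variant; for Flan T5 XL,
# use "huggingface-text2text-flan-t5-xl" instead.
model_id = "huggingface-text2text-flan-t5-base"
pretrained_model = JumpStartModel(model_id=model_id)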
pretrained_predictor = pretrained_model.deploy()

Next, we invoke the model to summarize the provided text. The new SDK interface makes invocation straightforward: pass the text to the predictor, and it returns the model's response as a Python dictionary.

text = """Summarize this content - Amazon Comprehend uses natural language processing (NLP) to extract insights about the content of documents. It develops insights by recognizing the entities, key phrases, language, sentiments, and other common elements in a document. Use Amazon Comprehend to create new products based on understanding the structure of documents. For example, using Amazon Comprehend you can search social networking feeds for mentions of products or scan an entire document repository for key phrases. 
You can access Amazon Comprehend document analysis capabilities using the Amazon Comprehend console or using the Amazon Comprehend APIs. You can run real-time analysis for small workloads or you can start asynchronous analysis jobs for large document sets. You can use the pre-trained models that Amazon Comprehend provides, or you can train your own custom models for classification and entity recognition. """
query_response = pretrained_predictor.predict(text)
print(query_response["generated_text"])

The following is the output of the summarization task:

Understand how Amazon Comprehend works. Use Amazon Comprehend to analyze documents.
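
Because the predictor accepts plain text and returns a dictionary, the call can be wrapped in a small helper. The following is a minimal sketch rather than part of the SDK: the summarize function and its instruction parameter are hypothetical names, and the sketch assumes the response carries a generated_text key, as in the preceding example.

```python
def summarize(predictor, document, instruction="Summarize this content - "):
    """Prefix the document with an instruction and return the generated text.

    `predictor` is any object with a predict(text) method that returns a
    dictionary, such as the predictor returned by deploy() earlier in this post.
    """
    # Strip surrounding whitespace so the prompt ends cleanly.
    response = predictor.predict(instruction + document.strip())
    # Assumption: plain-text requests return a single "generated_text" entry.
    return response["generated_text"]
```

With the endpoint from the previous step, a call such as summarize(pretrained_predictor, text) would return the summary string directly.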

Fine-tune and deploy the model

The SageMaker JumpStart SDK provides a new class, JumpStartEstimator, which simplifies fine-tuning. You provide the location of the fine-tuning data and can optionally pass validation datasets as well. After you fine-tune the model, use the deploy method of the estimator object to deploy the fine-tuned model:

from sagemaker.jumpstart.estimator import JumpStartEstimator

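# model_id is the JumpStart model ID, for example the one used for deployment.
estimator = JumpStartEstimator(
    model_id=model_id,
)
# Hyperparameter names are model-specific; confirm them against the model's defaults.
estimator.set_hyperparameters(instruction_tuned="True", epoch="3", max_input_length="1024")
# train_data_location is the Amazon S3 URI of your training data.
estimator.fit({"training": train_data_location})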
finetuned_predictor = estimator.deploy()
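
Hyperparameter names can differ from one model to another, so it can help to check your overrides against the model's defaults before calling set_hyperparameters. The following is a minimal sketch under stated assumptions: check_hyperparameters is a hypothetical helper, and the defaults dictionary is assumed to come from a call such as sagemaker.hyperparameters.retrieve_default(model_id=model_id, model_version=model_version).

```python
def check_hyperparameters(overrides, defaults):
    """Return the overrides unchanged if every key is a known hyperparameter.

    Raises ValueError naming any unknown keys, which usually indicates a typo
    or a name that belongs to a different model.
    """
    unknown = sorted(set(overrides) - set(defaults))
    if unknown:
        raise ValueError(
            f"Unknown hyperparameters: {unknown}; known names: {sorted(defaults)}"
        )
    return overrides
```

You could then pass the validated values on with estimator.set_hyperparameters(**check_hyperparameters(overrides, defaults)).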

Customize the new classes in the SageMaker SDK

The new SDK makes it straightforward to deploy and fine-tune JumpStart models by defaulting many parameters. You can still override the defaults and customize deployment and invocation to match your requirements. For example, you can customize the input payload format, instance type, VPC configuration, and more for your environment and use case.

The following code shows how to override the instance type while deploying your model:

finetuned_predictor = estimator.deploy(instance_type="ml.g5.2xlarge")

The SageMaker JumpStart SDK deploy function automatically selects a default content type and serializer for you. To change the format of the input payload, you can use the serializers and content_types modules to introspect the available options by passing the model_id of the model you are working with. In the following code, we set the input payload format to JSON by setting JSONSerializer as the serializer and application/json as the content_type:

from sagemaker import serializers
from sagemaker import content_types

# "*" requests the latest available version of the model.
model_version = "*"

# List the serializers and content types that this model supports.
serializer_options = serializers.retrieve_options(model_id=model_id, model_version=model_version)
content_type_options = content_types.retrieve_options(model_id=model_id, model_version=model_version)

# Send requests as JSON.
pretrained_predictor.serializer = serializers.JSONSerializer()
pretrained_predictor.content_type = "application/json"
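
The preceding snippet retrieves the available options but does not use them further. One possible use, sketched here, is to confirm that the model supports a format before switching to it. The pick_serializer helper is a hypothetical name, and the sketch assumes each retrieved option exposes a CONTENT_TYPE attribute, as the serializers in the SageMaker Python SDK do.

```python
def pick_serializer(serializer_options, content_type):
    """Return the first serializer whose CONTENT_TYPE matches, or None."""
    for candidate in serializer_options:
        # Assumption: each option exposes its MIME type as CONTENT_TYPE.
        if getattr(candidate, "CONTENT_TYPE", None) == content_type:
            return candidate
    return None


# Example usage with the objects from the preceding snippet:
# json_serializer = pick_serializer(serializer_options, "application/json")
# if json_serializer is not None:
#     pretrained_predictor.serializer = json_serializer
```

Returning None instead of raising lets the caller decide whether to fall back to the default serializer.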

Next, you can invoke the Flan T5 XL model for the summarization task with a JSON payload. In the following code, we also pass inference parameters in the JSON payload to control how the model generates the response:

from sagemaker import serializers

input_text= """Summarize this content - Amazon Comprehend uses natural language processing (NLP) to extract insights about the content of documents. It develops insights by recognizing the entities, key phrases, language, sentiments, and other common elements in a document. Use Amazon Comprehend to create new products based on understanding the structure of documents. For example, using Amazon Comprehend you can search social networking feeds for mentions of products or scan an entire document repository for key phrases.
You can access Amazon Comprehend document analysis capabilities using the Amazon Comprehend console or using the Amazon Comprehend APIs. You can run real-time analysis for small workloads or you can start asynchronous analysis jobs for large document sets. You can use the pre-trained models that Amazon Comprehend provides, or you can train your own custom models for classification and entity recognition. """

parameters = {
    "max_length": 600,
    "num_return_sequences": 1,
    "top_p": 0.01,
    "do_sample": False,
}

payload = {"text_inputs": input_text, **parameters}  # JSON input format

pretrained_predictor.serializer = serializers.JSONSerializer()
query_response = pretrained_predictor.predict(payload)
print(query_response["generated_texts"][0])
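
The payload in the preceding example is a flat dictionary: the input text sits under text_inputs next to the generation parameters. The following minimal sketch shows that structure in isolation. The build_payload and first_generated_text helpers are hypothetical names, the sample text is illustrative, and json.dumps is used only to approximate the request body that a JSON serializer would produce.

```python
import json


def build_payload(text, **generation_parameters):
    """Merge the input text and generation parameters into one flat dictionary."""
    if "text_inputs" in generation_parameters:
        raise ValueError("text_inputs is reserved for the input text")
    return {"text_inputs": text, **generation_parameters}


def first_generated_text(response):
    """Return the first sequence from a response with a generated_texts list."""
    return response["generated_texts"][0]


payload = build_payload(
    "Summarize this content - Amazon Comprehend uses NLP to extract insights.",
    max_length=600,
    num_return_sequences=1,
    do_sample=False,
)
# Approximate view of the JSON request body.
print(json.dumps(payload, indent=2))
```

With a deployed endpoint, first_generated_text(pretrained_predictor.predict(payload)) would return the same string that the preceding example prints.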

If you’re looking for more ways to customize the inputs and other options for hosting and fine-tuning, refer to the documentation for the JumpStartModel and JumpStartEstimator classes.

Conclusion

In this post, we showed how to use the simplified SageMaker JumpStart SDK to build, train, and deploy task-based and foundation models in just a few lines of code. We demonstrated the new JumpStartModel and JumpStartEstimator classes using the Hugging Face Flan T5 XL model as an example. You can use any of the other SageMaker JumpStart foundation models for use cases such as content writing, code generation, question answering, summarization, classification, information retrieval, and more. To see the whole list of models available with SageMaker JumpStart, refer to Built-in Algorithms with pre-trained Model Table. SageMaker JumpStart also supports task-specific models for many popular problem types.

We hope the simplified interface of the SageMaker JumpStart SDK will help you get started quickly and enable you to deliver faster. We look forward to hearing how you use the simplified SageMaker JumpStart SDK to create exciting applications!


About the authors

Evan Kravitz is a software engineer at Amazon Web Services, working on SageMaker JumpStart. He is interested in the confluence of machine learning with cloud computing. Evan received his undergraduate degree from Cornell University and master’s degree from the University of California, Berkeley. In 2021, he presented a paper on adversarial neural networks at the ICLR conference. In his free time, Evan enjoys cooking, traveling, and going on runs in New York City.

Rachna Chadha is a Principal Solution Architect AI/ML in Strategic Accounts at AWS. Rachna is an optimist who believes that ethical and responsible use of AI can improve society in the future and bring economic and social prosperity. In her spare time, Rachna likes spending time with her family, hiking, and listening to music.

Jonathan Guinegagne is a Senior Software Engineer with Amazon SageMaker JumpStart at AWS. He got his master’s degree from Columbia University. His interests span machine learning, distributed systems, and cloud computing, as well as democratizing the use of AI. Jonathan is originally from France and now lives in Brooklyn, NY.

Dr. Ashish Khetan is a Senior Applied Scientist with Amazon SageMaker built-in algorithms and helps develop machine learning algorithms. He got his PhD from University of Illinois Urbana-Champaign. He is an active researcher in machine learning and statistical inference, and has published many papers in NeurIPS, ICML, ICLR, JMLR, ACL, and EMNLP conferences.


