Checkpointing
Checkpointing is a technique used to ensure that your computational jobs can be resumed from a previously saved state in case of interruptions or failures. This guide outlines how to implement and use checkpointing effectively within your jobs using different applications.
Why checkpointing matters
Job Time Limits: On AI-LAB, jobs are restricted to a maximum runtime of 12 hours. Checkpointing allows you to save your model's progress periodically, ensuring that even if your job is terminated due to time limits, you can restart training from the last checkpoint rather than starting over.
Service Windows: There are times when the platform undergoes maintenance or updates, during which jobs cannot be run. Checkpointing enables you to pause training during these service windows and resume later without losing progress.
Platform Errors: Platform errors can occasionally occur, leading to job cancellations. Checkpointing mitigates this risk by saving your model's state at regular intervals, so you can recover and continue training from the point of interruption.
Python data checkpointing
The following Python script demonstrates a basic checkpointing mechanism using the standard Python module pickle to periodically save the data of a process to a file.
import pickle
import os
import time

def save_checkpoint(data, filename):
    """Save the checkpoint data to a file."""
    with open(filename, 'wb') as f:
        pickle.dump(data, f)

def load_checkpoint(filename):
    """Load the checkpoint data from a file."""
    with open(filename, 'rb') as f:
        return pickle.load(f)

# Check if there is a checkpoint file
if os.path.exists('checkpoint.pkl'):
    # If there is, load the checkpoint
    data = load_checkpoint('checkpoint.pkl')
    print("Resuming from checkpoint:")
else:
    # If there isn't, initialize the data
    data = {'counter': 0}

try:
    # Simulate some long-running process
    while True:
        data['counter'] += 1
        print("Current counter value:", data['counter'])
        # Save a checkpoint every 5 iterations
        if data['counter'] % 5 == 0:
            save_checkpoint(data, 'checkpoint.pkl')
        # Simulate some work - replace this with your actual process
        time.sleep(1)
except KeyboardInterrupt:
    # Save a checkpoint if the process is interrupted
    save_checkpoint(data, 'checkpoint.pkl')
    print("\nCheckpoint saved. Exiting...")
Breakdown of the key components:
First, the script checks whether a checkpoint file named checkpoint.pkl exists using os.path.exists(). If the file exists, it loads the checkpoint data with the load_checkpoint function and assigns it to data. If not, it initializes data as a dictionary with a single key, counter, set to 0.
Then it enters an infinite loop (simulating a long-running process) in which it increments the counter key of the data dictionary, prints the current counter value, and simulates some work (in this case, a 1-second delay using time.sleep(1)).
Every 5 iterations (if data['counter'] % 5 == 0), it saves a checkpoint by calling save_checkpoint. If the process is interrupted by a keyboard interrupt (Ctrl+C), it saves the current checkpoint and prints a message before exiting.
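Note that if a job is killed in the middle of a save, the checkpoint file can be left truncated. One possible refinement, not part of the script above, is to write the checkpoint to a temporary file first and then atomically rename it, so the previous checkpoint is never corrupted. A minimal sketch of this idea (the temporary file name is just an example):

import os
import pickle

def save_checkpoint_atomic(data, filename):
    """Write to a temporary file, then atomically replace the old checkpoint."""
    tmp_filename = filename + '.tmp'    # example temporary name
    with open(tmp_filename, 'wb') as f:
        pickle.dump(data, f)
    os.replace(tmp_filename, filename)  # atomic rename on POSIX file systems

With this pattern, checkpoint.pkl always points to the last fully written checkpoint, even if the process is terminated during a write.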
TensorFlow model checkpointing
TensorFlow provides native support for checkpointing during model training, allowing you to save the model's weights at specific intervals. More information about TensorFlow checkpointing can be found in the official TensorFlow documentation.
The following code example trains a simple neural network on the MNIST dataset using TensorFlow and Keras. The primary focus is on the lines that implement checkpointing with the ModelCheckpoint callback.
import os
import tensorflow as tf
from tensorflow import keras

# Get an example dataset - use the first 5,000 examples of MNIST:
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_labels = train_labels[:5000]
test_labels = test_labels[:5000]
train_images = train_images[:5000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:5000].reshape(-1, 28 * 28) / 255.0

# Total number of epochs to train for:
epoch_steps = 20

# Define a simple sequential model:
def create_model():
    model = tf.keras.models.Sequential([
        keras.layers.Dense(512, activation='relu', input_shape=(784,)),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10)
    ])
    model.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
    return model

# Create a new model instance
model = create_model()

# Include the epoch in the file name (uses `str.format`)
checkpoint_path = "checkpoints/{epoch:d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create a callback that saves the model's weights at the end of every epoch
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    verbose=1,
    save_weights_only=True,
    save_freq='epoch')

# Check if there are existing checkpoints and load the latest one
latest = tf.train.latest_checkpoint(checkpoint_dir) if os.path.isdir(checkpoint_dir) else None
if latest:
    # Load the previously saved weights
    model.load_weights(latest)
    # Re-evaluate the model
    loss, acc = model.evaluate(test_images, test_labels, verbose=2)
    print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
    # Get the epoch number from the latest checkpoint file name
    step = int(os.path.basename(latest).split('.')[0])
    print('Continuing calculation from epoch:', step)
    # Set the initial epoch to the last saved epoch
    initial_epoch = step
else:
    initial_epoch = 0
    # Save the weights for the initial epoch
    model.save_weights(checkpoint_path.format(epoch=0))

# Train the model with the checkpoint callback
model.fit(train_images,
          train_labels,
          epochs=epoch_steps,
          initial_epoch=initial_epoch,
          callbacks=[cp_callback],
          validation_data=(test_images, test_labels),
          verbose=1)
Breakdown of the key components:
checkpoint_path: Specifies the path where checkpoints will be saved. You can include dynamic elements such as the epoch number in the file name to differentiate between checkpoints, for example checkpoints/{epoch:d}.ckpt.
cp_callback: Creates a ModelCheckpoint callback, which saves the model's weights at specified intervals during training. You can customize parameters such as the file path, verbosity, and whether to save only the weights or the entire model.
model.load_weights(latest): Before starting training, the script checks whether there are existing checkpoints. If so, it loads the latest one so that training resumes from the last saved state, ensuring continuity even after an interruption.
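Recent TensorFlow releases also ship a tf.keras.callbacks.BackupAndRestore callback that handles much of this save-and-resume bookkeeping automatically. The sketch below shows roughly how it could replace the manual logic above; it is an illustration only, and the backup_dir path is just an example.

# A minimal sketch, assuming a recent TensorFlow version; 'backups/mnist' is an example path
backup_callback = tf.keras.callbacks.BackupAndRestore(backup_dir='backups/mnist')

model = create_model()
model.fit(train_images,
          train_labels,
          epochs=epoch_steps,
          callbacks=[backup_callback],
          validation_data=(test_images, test_labels),
          verbose=1)
# If the job is interrupted and later resubmitted, fit() resumes from the last completed epoch.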
PyTorch model checkpointing
Checkpointing in PyTorch saves the state of your model and optimizer at chosen points, enabling you to resume training from a specific epoch after an interruption or to fine-tune a model from a previously saved state. More information about PyTorch checkpointing can be found in the official PyTorch documentation.
The following script trains a simple feedforward neural network using PyTorch. The primary focus is on the lines that implement checkpointing.
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Define a simple feedforward neural network
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(784, 512)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# Load the MNIST dataset
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=64, shuffle=True)

# Define the model
model = SimpleNN()

# Define the optimizer and loss function
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

# Checkpoint directory
checkpoint_dir = 'checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Total number of epochs to train for
epoch_steps = 20

# Check if there are existing checkpoints
if os.listdir(checkpoint_dir):
    # If there are existing checkpoints, load the latest one
    latest_checkpoint = max([int(file.split('.')[0]) for file in os.listdir(checkpoint_dir)])
    checkpoint = torch.load(os.path.join(checkpoint_dir, f'{latest_checkpoint}.pt'))
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = latest_checkpoint + 1
else:
    start_epoch = 0

# Training loop
for epoch in range(start_epoch, epoch_steps):
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        data = data.view(data.size(0), -1)
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
    print(f'Epoch {epoch}: Loss {loss.item()}')
    # Save a checkpoint at the end of every epoch
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss
    }, os.path.join(checkpoint_dir, f'{epoch}.pt'))
Breakdown of the key components:
Checkpoint Directory Setup: os.makedirs(checkpoint_dir, exist_ok=True) creates a directory for storing checkpoints if it does not already exist.
Checking for Existing Checkpoints: If the checkpoint directory is not empty, the script finds the checkpoint with the highest epoch number, restores the model and optimizer states from it, and sets start_epoch so training continues from the following epoch.
Saving Checkpoints: At the end of each epoch, torch.save() writes the model's state, the optimizer's state, and the current loss to a file named after the epoch number.
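The same checkpoint files can later be reloaded outside the training loop, for example to evaluate or fine-tune the model. A minimal sketch, assuming a checkpoint such as checkpoints/19.pt exists from the run above:

# Recreate the model and optimizer, then restore their states from a saved checkpoint
model = SimpleNN()
optimizer = optim.Adam(model.parameters())

checkpoint = torch.load(os.path.join('checkpoints', '19.pt'))  # example: last epoch of the run above
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
print(f"Loaded checkpoint from epoch {checkpoint['epoch']}, loss {checkpoint['loss']}")

model.eval()    # switch to evaluation mode for inference
# model.train() # or switch back to training mode to continue fine-tuning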