Understanding torch.package.PackageImporter.load_pickle() for PyTorch Package Management
Functionality
- load_pickle() loads Python objects that were previously saved with the save_pickle() method of torch.package.PackageExporter.
- It is part of PyTorch's torch.package module, which creates self-contained packages (zip archives holding code and data) for deploying PyTorch models.
Process
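- You use torch.package.PackageExporter to create a package.
- Inside the exporter's context manager, you call save_pickle(package, resource, obj) to save a Python object (obj) under a resource name (resource) within a package (package).
- Modules from torch and the Python standard library are externed automatically. Any other module the object depends on (including a script's own __main__ module) must be matched by an intern(), extern(), or mock() pattern, declared before calling save_pickle(); otherwise the exporter raises a PackagingError when it closes.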
Package Loading
- To load the package, you create a PackageImporter instance, passing the path to the package file (or a binary file-like object).
- Then you call load_pickle(package, resource, map_location=None) to retrieve the previously saved object.
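To make the package/resource naming concrete, here is a minimal round trip that writes to an in-memory buffer instead of a file. This sketch assumes the archive is a standard zip file in which dots in the package name become directory separators; the names assets.configs and config.pkl are arbitrary choices for the example.

```python
import io
import zipfile

from torch.package import PackageExporter, PackageImporter

# Only built-in types, so no intern/extern rules are needed.
config = {"hidden_size": 128, "labels": ["cat", "dog"]}

# Save into an in-memory buffer instead of a file on disk.
buffer = io.BytesIO()
with PackageExporter(buffer) as exporter:
    exporter.save_pickle("assets.configs", "config.pkl", config)

# The package is a zip archive; list its entries to see where the resource went.
with zipfile.ZipFile(io.BytesIO(buffer.getvalue())) as archive:
    entries = archive.namelist()
config_entries = [name for name in entries if name.endswith("config.pkl")]

# Load the object back with the same package and resource names.
buffer.seek(0)
importer = PackageImporter(buffer)
loaded_config = importer.load_pickle("assets.configs", "config.pkl")

print(config_entries)
print(loaded_config)
```

The entry name may carry an archive-level prefix directory, which is why the sketch matches on the suffix rather than the full name.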
Arguments
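- package (str): The package name the object was saved under, written like a Python package name (dot-separated, e.g. "my_resources" or "assets.configs").
- resource (str): The resource name assigned to the object when it was saved (e.g. "model.pkl").
- map_location (optional): Where tensor storages are placed when loading, with the same meaning as torch.load's map_location (e.g. "cpu", a torch.device, or a dict mapping locations). Defaults to None, which restores tensors to the devices they were saved from.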
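As a sketch of map_location, the snippet below packages a tensor and asks load_pickle to place it on the CPU. On a CPU-only machine this is effectively a no-op, but the same call is how you would load a package saved on a GPU machine into a CPU-only environment. The package and resource names are made up for the example.

```python
import io

import torch
from torch.package import PackageExporter, PackageImporter

# Package a single tensor (torch modules are externed automatically).
buffer = io.BytesIO()
with PackageExporter(buffer) as exporter:
    exporter.save_pickle("weights", "bias.pkl", torch.arange(4.0))

# Remap every storage to the CPU while unpickling, as torch.load would.
buffer.seek(0)
importer = PackageImporter(buffer)
bias = importer.load_pickle("weights", "bias.pkl", map_location="cpu")

print(bias.device, bias)
```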
Example
```python
import torch
from torch.package import PackageExporter, PackageImporter

# Create a model
class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)

model = MyModel()

# Save the model in a package
with PackageExporter("my_model.pth") as exporter:
    # MyModel is defined in this script (module __main__), so that module needs
    # an action. extern means the loader imports it from its own environment,
    # so this only works if the loading script also defines MyModel.
    exporter.extern("__main__")
    exporter.save_pickle("MyModel", "model.pkl", model)

# Load the model from the package
importer = PackageImporter("my_model.pth")
loaded_model = importer.load_pickle("MyModel", "model.pkl")

# Use the loaded model
input_data = torch.randn(1, 10)
output = loaded_model(input_data)
print(output)
```
Key Points
- load_pickle() is designed for loading Python objects saved within PyTorch packages.
- It complements load_text() and load_binary(), which read resources written with save_text() and save_binary().
- Consider the security implications when loading objects from packages, especially ones from untrusted sources: like any pickle, a package can run arbitrary code when it is loaded.
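A short sketch of how the three loaders pair with their save_* counterparts in a single package; the package and resource names here are illustrative only.

```python
import io

from torch.package import PackageExporter, PackageImporter

buffer = io.BytesIO()
with PackageExporter(buffer) as exporter:
    # Each save_* method has a matching load_* method on PackageImporter.
    exporter.save_text("docs", "README.txt", "Toy model for testing.")
    exporter.save_binary("blobs", "magic.bin", b"\x00\x01\x02")
    exporter.save_pickle("meta", "info.pkl", {"version": 1})

buffer.seek(0)
importer = PackageImporter(buffer)
readme = importer.load_text("docs", "README.txt")    # str
magic = importer.load_binary("blobs", "magic.bin")  # bytes
info = importer.load_pickle("meta", "info.pkl")     # arbitrary Python object

print(readme, magic, info)
```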
Loading a Custom Object
```python
import torch
from torch.package import PackageExporter, PackageImporter

# Define a custom class
class DataProcessor:
    def __init__(self):
        self.vocab = {"hello": 1, "world": 2}

    def process(self, text):
        return [self.vocab[word] for word in text.split()]

# Create an instance and save it in a package
processor = DataProcessor()
with PackageExporter("data_processor.pth") as exporter:
    # DataProcessor lives in this script's __main__ module; extern it so the
    # loader resolves the class from the running script.
    exporter.extern("__main__")
    exporter.save_pickle("DataProcessor", "processor.pkl", processor)

# Load the processor from the package
importer = PackageImporter("data_processor.pth")
loaded_processor = importer.load_pickle("DataProcessor", "processor.pkl")

# Use the loaded processor
text = "hello world"
processed_text = loaded_processor.process(text)
print(processed_text)  # Output: [1, 2]
```
This example saves and loads an instance of a custom DataProcessor class within a PyTorch package.
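The examples here define their classes in the running script. A more typical torch.package layout keeps the class in its own module and interns it, so the module's source code is copied into the archive. The sketch below writes a hypothetical module named toy_processor to a temporary directory only so there is a real .py file to intern; in a real project that module would already be part of your code base.

```python
import io
import sys
import tempfile
import textwrap
from pathlib import Path

from torch.package import PackageExporter, PackageImporter

# Source of a hypothetical module; intern() needs a real .py file to copy from.
MODULE_SOURCE = textwrap.dedent(
    """
    class DataProcessor:
        def __init__(self):
            self.vocab = {"hello": 1, "world": 2}

        def process(self, text):
            return [self.vocab[word] for word in text.split()]
    """
)

buffer = io.BytesIO()
with tempfile.TemporaryDirectory() as tmp:
    Path(tmp, "toy_processor.py").write_text(MODULE_SOURCE)
    sys.path.insert(0, tmp)
    try:
        import toy_processor

        processor = toy_processor.DataProcessor()
        with PackageExporter(buffer) as exporter:
            # Copy toy_processor's source code into the package.
            exporter.intern("toy_processor")
            exporter.save_pickle("DataProcessor", "processor.pkl", processor)
    finally:
        sys.path.remove(tmp)

buffer.seek(0)
importer = PackageImporter(buffer)
loaded_processor = importer.load_pickle("DataProcessor", "processor.pkl")

print(loaded_processor.process("hello world"))  # [1, 2]
# The class comes from the package's private copy of the module,
# not from the toy_processor module imported above.
print(type(loaded_processor) is toy_processor.DataProcessor)  # False
```

The last line reflects a torch.package design choice: types loaded from a package are distinct from same-named types in the loading environment, so isinstance checks against locally imported classes will not match.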
Loading Multiple Objects
```python
import torch
from torch.package import PackageExporter, PackageImporter

# Create a model and optimizer
class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)

model = MyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Save both in the package. The optimizer is saved as a state dict because a
# separately pickled optimizer would not reference the loaded model's parameters.
with PackageExporter("model_and_optimizer.pth") as exporter:
    exporter.extern("__main__")  # MyModel is defined in this script
    exporter.save_pickle("MyModel", "model.pkl", model)
    exporter.save_pickle("Optimizer", "optimizer.pkl", optimizer.state_dict())

# Load both objects from the package
importer = PackageImporter("model_and_optimizer.pth")
loaded_model = importer.load_pickle("MyModel", "model.pkl")
optimizer_state = importer.load_pickle("Optimizer", "optimizer.pkl")

# Rebuild the optimizer around the loaded model's parameters
loaded_optimizer = torch.optim.SGD(loaded_model.parameters(), lr=0.01)
loaded_optimizer.load_state_dict(optimizer_state)

# Use the loaded model and optimizer for training
# ... (training loop here)
```
This example shows how to save and load both a model and its optimizer within the same package using separate save_pickle calls.
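A caveat when pickling a live optimizer object in its own save_pickle call: each pickle is unpickled independently, so the loaded optimizer holds its own copies of the parameters rather than the loaded model's. The sketch below uses a plain torch.nn.Linear (torch modules need no intern/extern rules) to check this directly.

```python
import io

import torch
from torch.package import PackageExporter, PackageImporter

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Pickle the model and the optimizer object in two separate calls.
buffer = io.BytesIO()
with PackageExporter(buffer) as exporter:
    exporter.save_pickle("MyModel", "model.pkl", model)
    exporter.save_pickle("Optimizer", "optimizer.pkl", optimizer)

buffer.seek(0)
importer = PackageImporter(buffer)
loaded_model = importer.load_pickle("MyModel", "model.pkl")
loaded_optimizer = importer.load_pickle("Optimizer", "optimizer.pkl")

# model.parameters() yields weight first, so compare it with the optimizer's
# first parameter: same values, but a different Python object.
model_param = loaded_model.weight
opt_param = loaded_optimizer.param_groups[0]["params"][0]
print(opt_param is model_param)  # False
```

Calling loaded_optimizer.step() after a backward pass through loaded_model would therefore not update loaded_model. A common way around this is to save optimizer.state_dict() instead and, after loading, build a fresh optimizer over loaded_model.parameters() and call load_state_dict() on it.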
- Be cautious when loading objects from untrusted sources due to potential security risks.
- Choose the appropriate loading method (load_pickle, load_text, load_binary) based on the data you're working with.
- These examples showcase the versatility of load_pickle() for various object types.
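The security warning is concrete: unpickling can call any importable function the pickle names. This standalone sketch uses the standard pickle module (load_pickle relies on the same unpickling mechanism) with a harmless print call standing in for a malicious payload.

```python
import pickle

class Payload:
    # __reduce__ tells pickle how to rebuild the object; here it asks the
    # unpickler to call print(...) instead of recreating a Payload.
    def __reduce__(self):
        return (print, ("this ran during unpickling",))

data = pickle.dumps(Payload())
result = pickle.loads(data)  # prints the message; no Payload is rebuilt

print(result)  # None, the return value of print
```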
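Alternative Loading Methods
torch.load
- This is the most common approach for loading PyTorch models or other serializable objects saved with torch.save.
- It uses Python's built-in pickle module for serialization.
- It is not designed for torch.package archives, but it works well for general loading tasks.
- Since PyTorch 2.6, torch.load defaults to weights_only=True, which only unpickles tensors and other basic types; loading a whole pickled model requires weights_only=False, which should only be used for files you trust.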
Example
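```python
import torch

# Load a full model saved with torch.save(model, "my_model.pt").
# weights_only=False is required for arbitrary pickled objects in PyTorch 2.6+;
# only use it on files you trust. The model's class must be importable here.
model = torch.load("my_model.pt", weights_only=False)
```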
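Where you control the saving side, a common safer pattern with torch.load is to save only the model's state_dict, load it with weights_only=True, and copy the weights into a model rebuilt from code. The sketch below uses an in-memory buffer and a plain torch.nn.Linear in place of a real model.

```python
import io

import torch

model = torch.nn.Linear(10, 10)

# Save only the tensors (the state dict) rather than the whole module object.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# weights_only=True restricts unpickling to tensors and primitive containers.
state_dict = torch.load(buffer, weights_only=True)

# Rebuild the module from code, then copy the saved weights into it.
restored = torch.nn.Linear(10, 10)
restored.load_state_dict(state_dict)
```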
Custom Serialization Libraries
- You can use alternative serialization libraries like cloudpickle or dill, which handle complex object structures more robustly than pickle (for example, lambdas and functions defined inside other functions).
- These libraries might be necessary for objects with custom classes or functions that are not natively picklable.
- However, using these libraries requires them to be installed on both the saving and loading machines.
Example (using cloudpickle)
```python
import cloudpickle

# Load a model saved with cloudpickle.dump
with open("my_model.pkl", "rb") as f:
    model = cloudpickle.load(f)
```
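To illustrate what "not natively picklable" means, this sketch tries to serialize a lambda with the standard pickle module and then with cloudpickle. It assumes cloudpickle is installed (pip install cloudpickle).

```python
import pickle

import cloudpickle

square = lambda x: x * x

# The standard pickle module stores functions by reference (module + name),
# which fails for a lambda because it has no importable name.
try:
    pickle.dumps(square)
    stdlib_ok = True
except (pickle.PicklingError, AttributeError, TypeError):
    stdlib_ok = False

# cloudpickle serializes the function's code by value instead.
restored = cloudpickle.loads(cloudpickle.dumps(square))

print(stdlib_ok)    # False
print(restored(4))  # 16
```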
Manual Loading (if applicable)
- For simple data structures (e.g., lists, dictionaries), you might be able to load them manually with functions like json.load (standard library) or yaml.safe_load (from the third-party PyYAML package), depending on the format used for saving.
- This approach is less flexible and may require additional processing, depending on the data format.
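A minimal sketch of manual loading with the standard library's json module; the file name vocab.json is only for illustration. The second part shows the kind of additional processing mentioned above.

```python
import json
import os
import tempfile

vocab = {"hello": 1, "world": 2}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "vocab.json")
    # JSON handles plain dicts, lists, strings, numbers, booleans, and None.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(vocab, f)
    with open(path, encoding="utf-8") as f:
        loaded_vocab = json.load(f)

# JSON object keys are always strings and tuples come back as lists,
# so some types need converting after loading.
round_tripped = json.loads(json.dumps({1: (2, 3)}))

print(loaded_vocab)   # {'hello': 1, 'world': 2}
print(round_tripped)  # {'1': [2, 3]}
```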
Choosing the Right Alternative
- If you're working with PyTorch packages, torch.package.PackageImporter.load_pickle is the appropriate choice; for files written with torch.save, use torch.load.
- For complex objects or scenarios with custom serialization, consider libraries like cloudpickle or dill.
- For simple data structures, manual loading might be an option.
- Compatibility: If you collaborate with others, make sure the chosen loading method (and any libraries it needs) is compatible with their systems.
- Security: Pickle-based loaders (load_pickle, torch.load with weights_only=False, cloudpickle, dill) can execute arbitrary code, so only load objects from sources you trust.