Exploring Dependency Injection in Python: A Machine Learning Engineer’s Perspective
Recently, there's been a lot of buzz around dependency injection and its potential to make code more flexible and robust. Intrigued, I decided to dive deeper into this concept and its practical applications. During my research, I discovered a mature and powerful framework for Python aptly named dependency-injector.
In this post, I won't delve into what dependency injection is or explain the framework itself; the official documentation and tutorials by the developers do an excellent job of that (you can check them out here and here). Instead, as a machine learning engineer, I explored how this concept could benefit my work. After reviewing my code on GitHub from the past few years, I identified several compelling use cases, which I will illustrate with the following example.
```python
# Standard ML Workflow
from typing import Callable

import pytorch_lightning as pl
import torch
from torch.utils.data import Dataset


class PictureDataset(Dataset):
    def __init__(self, data_dir: str, train=True, transform=None):
        ...  # load (image, label) pairs from data_dir into self.dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]


class SimpleCNN(pl.LightningModule):
    def __init__(
        self, input_channels: int, output_size: int, num_layers: int, loss_fn: Callable
    ):
        super().__init__()
        ...  # store loss_fn and build the layers

    def forward(self, x): ...

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer
```
Managing Configuration: The framework can load configuration from YAML, JSON, and INI files as well as from environment variables. This makes it easy to maintain separate configurations for different environments. In the example below, I used a YAML file to read in the configuration.
Dynamically Selecting a Class: Traditionally, a selector function that creates an object based on some criterion doesn't fit naturally into an object-oriented design and often ends up as a standalone helper. With dependency injection, the selection is declared in the container next to the other dependencies, which keeps the code cleaner and easier to maintain. The example below uses a selector provider to choose both the loss function and the model.
```python
import torch.nn as nn
from dependency_injector import containers, providers
from torchvision.transforms import transforms


class Container(containers.DeclarativeContainer):
    config = providers.Configuration()

    # One shared dataset instance, built from config values
    dataset = providers.Singleton(
        PictureDataset,
        data_dir=config.dataset.data_dir,
        transform=transforms.ToTensor(),
    )

    # Pick the loss class by the string stored in config.training.loss
    loss_fn = providers.Selector(
        config.training.loss,
        ce=providers.Factory(nn.CrossEntropyLoss),
        bce=providers.Factory(nn.BCELoss),
        kldiv=providers.Factory(nn.KLDivLoss),
    )

    # Pick the model by config.training.model; the selected loss is injected into it
    # (AnotherCNN is defined in the full mock-up at the end of the post)
    model = providers.Selector(
        config.training.model,
        simpleCNN=providers.Singleton(
            SimpleCNN,
            input_channels=config.model.scnn.in_channels,
            output_size=config.model.scnn.out_size,
            num_layers=config.model.scnn.num_layers,
            loss_fn=loss_fn,
        ),
        anotherCNN=providers.Singleton(
            AnotherCNN,
            input_channels=config.model.scnn.in_channels,
            output_size=config.model.scnn.out_size,
            num_layers=config.model.scnn.num_layers,
            loss_fn=loss_fn,
        ),
    )
```
Testing and Ablation Studies: Dependency injection makes it straightforward to swap a real dataset for a mock one, whether to test a model or to run an ablation study, without touching the code that consumes the dataset. For example, in the code below the model is run on mock data so that its output can be checked.
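```python
from unittest import mock

import pytest
import torch


@pytest.fixture
def container():
    container = Container()
    container.config.from_yaml("config.yml")
    return container


def test_model(container):
    # MagicMock (unlike Mock) lets us configure magic methods such as __getitem__
    mock_dataset = mock.MagicMock()
    # One (image, label) sample in the CHW tensor layout that ToTensor() produces
    mock_dataset.__getitem__.return_value = (torch.randn(3, 32, 32), 0)
    with container.dataset.override(mock_dataset):
        model = container.model()
        dataset = container.dataset()
        image, _ = dataset[0]
        # Add a batch dimension and pass only the image, not the (image, label) tuple
        output = model(image.unsqueeze(0))
    print(output)
```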
What do you think about these use cases? I'm sure there are many other applications I've missed. Feel free to share your thoughts in the comments!
I have noticed a few caveats:
Despite these challenges, I believe the benefits of using dependency injection in the right scenarios can outweigh the drawbacks. I'm excited to continue exploring this framework and integrating it into my projects.
Here's the full code mock-up in case anyone needs it :)
```python
from typing import Callable

import pytorch_lightning as pl
import torch
import torch.nn as nn
from dependency_injector import containers, providers
from dependency_injector.wiring import Provide, inject
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import transforms


class PictureDataset(Dataset):
    def __init__(self, data_dir: str, train=True, transform=None):
        ...  # load (image, label) pairs from data_dir into self.dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]


class SimpleCNN(pl.LightningModule):
    def __init__(
        self, input_channels: int, output_size: int, num_layers: int, loss_fn: Callable
    ):
        super().__init__()
        self.loss_fn = loss_fn  # the loss chosen by the container's Selector
        ...  # build the layers

    def forward(self, x): ...

    def training_step(self, batch, batch_idx):
        x, y = batch
        return self.loss_fn(self(x), y)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer


class AnotherCNN(pl.LightningModule):
    def __init__(
        self, input_channels: int, output_size: int, num_layers: int, loss_fn: Callable
    ):
        super().__init__()
        self.loss_fn = loss_fn  # the loss chosen by the container's Selector
        ...  # build the layers

    def forward(self, x): ...

    def training_step(self, batch, batch_idx):
        x, y = batch
        return self.loss_fn(self(x), y)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer


class Container(containers.DeclarativeContainer):
    config = providers.Configuration()

    dataset = providers.Singleton(
        PictureDataset,
        data_dir=config.dataset.data_dir,
        transform=transforms.ToTensor(),
    )

    loss_fn = providers.Selector(
        config.training.loss,
        ce=providers.Factory(nn.CrossEntropyLoss),
        bce=providers.Factory(nn.BCELoss),
        kldiv=providers.Factory(nn.KLDivLoss),
    )

    model = providers.Selector(
        config.training.model,
        simpleCNN=providers.Singleton(
            SimpleCNN,
            input_channels=config.model.scnn.in_channels,
            output_size=config.model.scnn.out_size,
            num_layers=config.model.scnn.num_layers,
            loss_fn=loss_fn,
        ),
        anotherCNN=providers.Singleton(
            AnotherCNN,
            input_channels=config.model.scnn.in_channels,
            output_size=config.model.scnn.out_size,
            num_layers=config.model.scnn.num_layers,
            loss_fn=loss_fn,
        ),
    )


@inject
def train(
    # Inject a single config value: Provide[Container.config] would inject
    # a plain dict, which does not support config.training.batch_size access
    batch_size: int = Provide[Container.config.training.batch_size],
    dataset: Dataset = Provide[Container.dataset],
    model: pl.LightningModule = Provide[Container.model],
):
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    print("Start training...")
    trainer = pl.Trainer()
    trainer.fit(model, data_loader)
    print("Finished training...")


if __name__ == "__main__":
    container = Container()
    container.config.from_yaml("config.yml")
    # Fill the Provide[...] markers in this module from the container
    container.wire(modules=[__name__])
    train()
```