Exploring Dependency Injection in Python: A Machine Learning Engineer’s Perspective

Recently, there's been a lot of buzz around dependency injection and its potential to make code more flexible and robust. Intrigued, I decided to dive deeper into this concept and its practical applications. During my research, I discovered a mature and powerful framework for Python aptly named dependency-injector.

In this post, I won't delve into what dependency injection is or explain the framework itself; the official documentation and tutorials by the developers do an excellent job of that (you can check them out here and here). Instead, as a machine learning engineer, I explored how this concept could be beneficial in my own work. After reviewing my code on GitHub from the past few years, I identified several compelling use cases, which I will illustrate with the following example.

# Standard ML workflow: a dataset and a LightningModule (bodies elided)
class PictureDataset(Dataset):
    def __init__(self, data_dir: str, train=True, transform=None): ...

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

class SimpleCNN(pl.LightningModule):
    def __init__(
        self, input_channels: int, output_size: int, num_layers: int, loss_fn: Callable
    ):
        super().__init__()
        ...

    def forward(self, x): ...

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

Managing Configuration: The framework supports various configuration formats, including YAML, JSON, INI, and environment variables. This flexibility simplifies managing different configurations across diverse environments. In the example below, I use a YAML file to read in the configuration.
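
To make this concrete, here is a hypothetical config.yml with the keys the container below reads (the values are illustrative assumptions):

# config.yml (hypothetical values for illustration)
dataset:
  data_dir: data/pictures
training:
  loss: ce          # one of: ce, bce, kldiv
  model: simpleCNN  # one of: simpleCNN, anotherCNN
  batch_size: 64
model:
  scnn:
    in_channels: 3
    out_size: 10
    num_layers: 4
  acnn:
    in_channels: 3
    out_size: 10
    num_layers: 6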

Dynamically Selecting a Class: Traditionally, a selector function that creates an object based on some criterion does not integrate well with OOP principles and often ends up as a standalone helper. With dependency injection, this limitation is eliminated, promoting cleaner and more maintainable code. The example below showcases a selector provider used to pick both the desired loss function and the desired model.

from dependency_injector import containers, providers

class Container(containers.DeclarativeContainer):
    config = providers.Configuration()
    dataset = providers.Singleton(
        PictureDataset,
        data_dir=config.dataset.data_dir,
        transform=transforms.ToTensor(),
    )

    loss_fn = providers.Selector(
        config.training.loss,
        ce=providers.Factory(nn.CrossEntropyLoss),
        bce=providers.Factory(nn.BCELoss),
        kldiv=providers.Factory(nn.KLDivLoss),
    )

    model = providers.Selector(
        config.training.model,
        simpleCNN=providers.Singleton(
            SimpleCNN,
            input_channels=config.model.scnn.in_channels,
            output_size=config.model.scnn.out_size,
            num_layers=config.model.scnn.num_layers,
            loss_fn=loss_fn,
        ),
        anotherCNN=providers.Singleton(
            AnotherCNN,
            input_channels=config.model.acnn.in_channels,
            output_size=config.model.acnn.out_size,
            num_layers=config.model.acnn.num_layers,
            loss_fn=loss_fn,
        ),
    )
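
To see the selector in action, here is a minimal usage sketch (assuming the hypothetical config.yml above); switching models or losses becomes a pure configuration change:

# Minimal usage sketch, assuming the config.yml above
container = Container()
container.config.from_yaml("config.yml")

model = container.model()      # SimpleCNN, since training.model == "simpleCNN"
loss_fn = container.loss_fn()  # nn.CrossEntropyLoss, since training.loss == "ce"

# Overriding the config option swaps the implementation without code changes:
with container.config.training.model.override("anotherCNN"):
    model = container.model()  # now an AnotherCNN instance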

Testing and Ablation Studies: Dependency injection makes it straightforward to replace datasets with mock ones for testing models or for conducting ablation studies. This ease of substitution makes testing and experimentation more robust. For example, in the code below the model is fed mock data to check its output.

import pytest
import numpy as np
from unittest import mock


@pytest.fixture
def container():
    container = Container()
    container.config.from_yaml("config.yml")
    return container


def test_model(container):
    # MagicMock (rather than Mock) is required so magic methods like __getitem__ work
    mock_dataset = mock.MagicMock()
    mock_dataset.__getitem__.return_value = (np.random.randn(32, 32, 3), 0)

    with container.dataset.override(mock_dataset):
        model = container.model()
        dataset = container.dataset()

        image, _ = dataset[0]
        output = model(image)
        print(output)
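
The same override mechanism covers ablation studies. Below is a sketch, where AblatedPictureDataset is an assumed, hypothetical class (e.g. one that drops a subset of input features):

# Hypothetical ablation study: AblatedPictureDataset is an assumed class
ablated = providers.Singleton(
    AblatedPictureDataset,
    data_dir="data/pictures",
)

with container.dataset.override(ablated):
    dataset = container.dataset()  # resolves to the ablated variant
    # ... train/evaluate as usual; no other code changes needed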

What do you think about these use cases? I'm sure there are many other applications I might have missed. Feel free to share your thoughts in the comments!


I have noticed a few caveats:

  1. Learning Curve: There's a bit of a learning curve associated with understanding and implementing dependency injection.
  2. Boilerplate Code: It can introduce quite a bit of boilerplate code, which may feel cumbersome at times.
  3. Error Messages: The error messages can be somewhat cryptic and not always straightforward to debug.

Despite these challenges, I believe the benefits of using dependency injection in the right scenarios can outweigh the drawbacks. I'm excited to continue exploring this framework and integrating it into my projects.


Here's the full code mock-up in case anyone needs it :)

import torch
import torch.nn as nn
import pytorch_lightning as pl

from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from typing import Callable

from dependency_injector import containers, providers
from dependency_injector.wiring import Provide, inject


class PictureDataset(Dataset):
    def __init__(self, data_dir: str, train=True, transform=None): ...

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]


class SimpleCNN(pl.LightningModule):
    def __init__(
        self, input_channels: int, output_size: int, num_layers: int, loss_fn: Callable
    ):
        super().__init__()
        ...

    def forward(self, x): ...

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer


class AnotherCNN(pl.LightningModule):
    def __init__(
        self, input_channels: int, output_size: int, num_layers: int, loss_fn: Callable
    ):
        super().__init__()
        ...

    def forward(self, x): ...

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer



class Container(containers.DeclarativeContainer):
    config = providers.Configuration()
    dataset = providers.Singleton(
        PictureDataset,
        data_dir=config.dataset.data_dir,
        transform=transforms.ToTensor(),
    )

    loss_fn = providers.Selector(
        config.training.loss,
        ce=providers.Factory(nn.CrossEntropyLoss),
        bce=providers.Factory(nn.BCELoss),
        kldiv=providers.Factory(nn.KLDivLoss),
    )

    model = providers.Selector(
        config.training.model,
        simpleCNN=providers.Singleton(
            SimpleCNN,
            input_channels=config.model.scnn.in_channels,
            output_size=config.model.scnn.out_size,
            num_layers=config.model.scnn.num_layers,
            loss_fn=loss_fn,
        ),
        anotherCNN=providers.Singleton(
            AnotherCNN,
            input_channels=config.model.acnn.in_channels,
            output_size=config.model.acnn.out_size,
            num_layers=config.model.acnn.num_layers,
            loss_fn=loss_fn,
        ),
    )



@inject
def train(
    batch_size: int = Provide[Container.config.training.batch_size],
    dataset: Dataset = Provide[Container.dataset],
    model: pl.LightningModule = Provide[Container.model],
):
    # Inject the concrete config value: Provide[Container.config] would yield
    # a plain dict, which does not support attribute access like config.training.
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    print("Start training...")
    trainer = pl.Trainer()
    trainer.fit(model, data_loader)
    print("Finished training...")


if __name__ == "__main__":
    container = Container()
    container.config.from_yaml("config.yml")
    container.wire(modules=[__name__])

    train()
