カスタム関数で勾配計算を効率化: torch.autograd.Function.backward()徹底解説


PyTorch の自動微分は、勾配計算を効率的に行うための強力なツールです。torch.autograd.Function クラスは、この自動微分機能の中核を担うものであり、backward() メソッドは、計算グラフを遡って各テンソルの勾配を計算するために使用されます。

torch.autograd.Function とは

torch.autograd.Function は、PyTorch でカスタムな自動微分関数を作成するための基底クラスです。このクラスを継承することで、任意のテンソル操作をラップし、その操作に対する勾配を自動的に計算することができます。

backward() メソッド

backward() メソッドは、torch.autograd.Function のサブクラスで実装する必要があり、計算グラフを遡って各テンソルの勾配を計算します。このメソッドは、以下の引数を取ります。

  • grad: 出力テンソルの勾配

backward() メソッドは、以下の処理を行います。

  1. 計算グラフを遡り、各テンソルの入力テンソルに対する勾配を計算します。
  2. 計算された勾配を各入力テンソルの .grad 属性に格納します。

例:カスタム関数の勾配計算

以下は、torch.autograd.Function を継承したカスタム関数と、その backward() メソッドの例です。この関数は、2つの入力テンソルの総和を計算し、その出力テンソルに対する勾配を返します。

import torch
import torch.autograd as autograd

class AddFunction(autograd.Function):

    @staticmethod
    def forward(ctx, a, b):
        output = a + b
        ctx.save_for_backward(a, b)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_for_backward
        grad_a = grad_output
        grad_b = grad_output
        return grad_a, grad_b

# カスタム関数の使用例
x = torch.tensor(2.0)
y = torch.tensor(3.0)
z = AddFunction.forward(x, y)
z.backward()
print(x.grad)  # 1.0
print(y.grad)  # 1.0


カスタム関数を使用した2つのテンソルの総和

この例では、torch.autograd.Functionを継承したカスタム関数を作成し、2つの入力テンソルの総和を計算します。また、その出力テンソルに対する勾配も計算します。

import torch
import torch.autograd as autograd

class AddFunction(autograd.Function):

    @staticmethod
    def forward(ctx, a, b):
        output = a + b
        ctx.save_for_backward(a, b)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_for_backward
        grad_a = grad_output
        grad_b = grad_output
        return grad_a, grad_b

# カスタム関数の使用例
x = torch.tensor(2.0)
y = torch.tensor(3.0)
z = AddFunction.forward(x, y)
z.backward()
print(x.grad)  # 1.0
print(y.grad)  # 1.0

ReLU関数の実装

この例では、ReLU関数のカスタム実装と、そのbackward()メソッドを示します。

import torch
import torch.autograd as autograd

class ReLUFunction(autograd.Function):

    @staticmethod
    def forward(ctx, input):
        if input < 0:
            output = 0
        else:
            output = input
        ctx.save_for_backward(input)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input = ctx.saved_for_backward
        if input < 0:
            grad_input = 0
        else:
            grad_input = grad_output
        return grad_input

# ReLU関数の使用例
x = torch.tensor([-1.0, 0.0, 2.0])
relu = ReLUFunction.forward(x)
relu.backward()
print(x.grad)  # 0.0, 1.0, 1.0

畳み込み層の実装

この例では、簡単な畳み込み層のカスタム実装と、そのbackward()メソッドを示します。

import torch
import torch.autograd as autograd

class Conv2DFunction(autograd.Function):

    @staticmethod
    def forward(ctx, input, kernel):
        output = torch.nn.functional.conv2d(input, kernel)
        ctx.save_for_backward(input, kernel, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, kernel, output = ctx.saved_for_backward
        grad_input = torch.nn.functional.conv2d_transpose(grad_output, kernel)
        grad_kernel = torch.nn.functional.conv2d(input, grad_output)
        return grad_input, grad_kernel

# 畳み込み層の使用例
input = torch.randn(1, 1, 3, 3)
kernel = torch.randn(1, 1, 3, 3)
output = Conv2DFunction.forward(input, kernel)
output.backward()
print(input.grad)  # ...
print(kernel.grad)  # ...

これらの例は、PyTorchにおける自動微分とtorch.autograd.Function.backward()メソッドの使用方法を理解するための出発点として役立ちます。



書籍

ブログ記事とフォーラム

チュートリアルとビデオ

これらの追加リソースが、PyTorchにおける自動微分とtorch.autograd.Function.backward()メソッドを理解するのに役立つことを願っています。