`torch.autograd.Function.forward()`の代替手段:ベクトル化と独立メソッドでコードをもっと柔軟に


PyTorchは、機械学習タスクに広く利用されている強力なオープンソースのライブラリです。その中でも、自動微分は、ニューラルネットワークの学習において重要な役割を果たします。この機能は、torch.autograd.Functionクラスを用いて実装されており、その中でもforward()メソッドは、計算グラフの構築と出力計算に深く関わっています。

自動微分とは?

自動微分は、数値計算で関数の微分値を自動的に求める手法です。従来の数値微分と異なり、手動で微分式を導出する必要がなく、コードを簡潔に記述できます。

torch.autograd.Functionとは?

torch.autograd.Functionは、PyTorchにおける自動微分の基盤となるクラスです。このクラスを継承することで、カスタムな積和演算活性化関数などを定義することができます。

forward()メソッドとは?

forward()メソッドは、torch.autograd.Functionクラスの重要なメソッドであり、以下の役割を果たします。

  1. 計算の実行: 入力テンソルを受け取り、計算を実行して出力を返します。
  2. 計算グラフの構築: 計算過程を記録し、勾配計算に必要な計算グラフを構築します。

forward()メソッドの例

以下は、torch.autograd.Functionを継承した単純な加算関数の例です。

class AddFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        """加算関数のforwardメソッド

        Args:
            ctx: コンテキストオブジェクト
            a: 入力テンソル1
            b: 入力テンソル2

        Returns:
            c: 加算結果のテンソル
        """
        c = a + b
        ctx.save_for_backward(a, b)  # 勾配計算に必要な入力を保存
        return c

    @staticmethod
    def backward(ctx, grad_output):
        """加算関数のbackwardメソッド

        Args:
            grad_output: 出力の勾配

        Returns:
            grad_a: 入力aに対する勾配
            grad_b: 入力bに対する勾配
        """
        a, b = ctx.saved_tensors  # 保存した入力を取得
        grad_a = grad_output
        grad_b = grad_output
        return grad_a, grad_b

この例では、forward()メソッドは入力テンソル ab を加算し、結果をテンソル c として返します。また、ctx.save_for_backward() を用いて、勾配計算に必要な ab を保存します。

backward()メソッドは、出力テンソルの勾配 grad_output を受け取り、入力テンソル ab に対する勾配を計算して返します。



単純な加算関数

この例は、前述の解説で紹介した加算関数のコードをより詳細に説明します。

import torch

class AddFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        """加算関数のforwardメソッド

        Args:
            ctx: コンテキストオブジェクト
            a: 入力テンソル1
            b: 入力テンソル2

        Returns:
            c: 加算結果のテンソル
        """
        c = a + b
        ctx.save_for_backward(a, b)  # 勾配計算に必要な入力を保存
        return c

    @staticmethod
    def backward(ctx, grad_output):
        """加算関数のbackwardメソッド

        Args:
            grad_output: 出力の勾配

        Returns:
            grad_a: 入力aに対する勾配
            grad_b: 入力bに対する勾配
        """
        a, b = ctx.saved_tensors  # 保存した入力を取得
        grad_a = grad_output
        grad_b = grad_output
        return grad_a, grad_b

# 加算関数の使用例
x = torch.tensor(2.0)
y = torch.tensor(3.0)

z = AddFunction.apply(x, y)  # forward()メソッドを呼び出し
print(z)  # 出力: tensor(5.0)

# 勾配計算
z.backward()
print(x.grad)  # 出力: tensor(1.)
print(y.grad)  # 出力: tensor(1.)

ReLU関数の実装

この例では、活性化関数の一つであるReLU関数をtorch.autograd.Functionを用いて実装します。

import torch

class ReluFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        """ReLU関数のforwardメソッド

        Args:
            ctx: コンテキストオブジェクト
            input: 入力テンソル

        Returns:
            output: ReLU関数で処理された出力テンソル
        """
        output = input.clamp(min=0)
        ctx.save_for_backward(input)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """ReLU関数のbackwardメソッド

        Args:
            grad_output: 出力の勾配

        Returns:
            grad_input: 入力に対する勾配
        """
        input = ctx.saved_tensors[0]
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input

# ReLU関数の使用例
x = torch.tensor([-1.0, 0.0, 2.0])

relu = ReluFunction.apply(x)
print(relu)  # 出力: tensor([0., 0., 2.])

# 勾配計算
relu.backward()
print(x.grad)  # 出力: tensor([0., 1., 1.])

カスタム演算の実装

この例では、2つのテンソルの内積を計算するカスタム演算をtorch.autograd.Functionを用いて実装します。

import torch

class DotProductFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        """内積関数のforwardメソッド

        Args:
            ctx: コンテキストオブジェクト
            a: 入力テンソル1
            b: 入力テンソル2

        Returns:
            inner_product: 内積の結果のスカラー値
        """
        inner_product = torch.dot(a, b)
        ctx.save_for_backward(a, b)
        return inner_product

    @staticmethod
    def backward(ctx, grad_output):
        """内積関数のbackwardメソッド

        Args:
            grad_output: 出力の勾配 (スカラー値)

        Returns:
            grad_a: 入力aに対する勾配
            grad_b: 入力bに対する勾配
        """
        a, b = ctx.saved_tensors
        grad_a


独立した forward() メソッドと setup_context() メソッド

従来の forward() メソッドは、コンテキストオブジェクト ctx を引数として受け取りましたが、PyTorch 2.0以降では ctx を受け取らず、独立したメソッドとして定義する必要があります。

さらに、計算グラフの構築とコンテキストオブジェクトの操作を行う setup_context() メソッドを別途定義する必要があります。

import torch

class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(self, a, b):
        # 計算を実行し、出力を返す
        output = a + b
        return output

    @staticmethod
    def setup_context(self, a, b):
        # 計算グラフを構築し、コンテキストオブジェクトを操作
        self.save_for_backward(a, b)

# MyFunction の使用例
x = torch.tensor(2.0)
y = torch.tensor(3.0)

f = MyFunction()
z = f(x, y)
print(z)  # 出力: tensor(5.0)

# 勾配計算
z.backward()
print(x.grad)  # 出力: tensor(1.)
print(y.grad)  # 出力: tensor(1.)

torch.func.vmap() を用いたベクトル化

torch.func.vmap() 関数を使用すると、torch.autograd.Function をベクトル化し、テンソルではなくベクトルや多次元配列に対して計算を実行することができます。

この方法では、forward() メソッドと backward() メソッドを従来通り定義し、torch.func.vmap() でラップすることでベクトル化を行います。

import torch
import torch.func as F

class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(self, a, b):
        # 計算を実行し、出力を返す
        output = a + b
        return output

    @staticmethod
    def backward(self, grad_output):
        # 勾配を計算し、返す
        grad_a = grad_output
        grad_b = grad_output
        return grad_a, grad_b

# MyFunction のベクトル化
vmapped_func = F.vmap(MyFunction)

# ベクトル化された関数の使用例
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([2.0, 3.0, 4.0])

z = vmapped_func(x, y)
print(z)  # 出力: tensor([3., 5., 7.])

# 勾配計算
z.backward()
print(x.grad)  # 出力: tensor([1., 1., 1.])
print(y.grad)  # 出力: tensor([1., 1., 1.])
  • ベクトリ化や多次元配列に対する計算が必要な場合は、torch.func.vmap() を用いる方が簡潔で効率的です。
  • 従来の torch.autograd.Function の使い勝手に慣れている場合は、独立した forward() メソッドと setup_context() メソッド を使用する方が良いでしょう。