PyTorch Tensorの次元数を知る:ndimension()メソッド徹底解説


torch.Tensor.ndimension() は、PyTorch の Tensor オブジェクトの次元数を取得するためのメソッドです。Tensor は多次元配列を表しており、ndimension() メソッドはその次元数を整数值として返します。

このメソッドは、テンソルの形状やサイズを理解する際に役立ちます。例えば、テンソルの操作や可視化を行う前に、その次元数を確認する必要がある場合があります。

使用方法

torch.Tensor.ndimension() メソッドは、引数なしで呼び出します。構文は以下の通りです。

tensor_ndim = tensor.ndimension()

ここで、

  • tensor は、次元数を取得したい Tensor オブジェクトです。
  • tensor_ndim は、Tensor オブジェクトの次元数を格納する整数変数です。

以下の例は、torch.Tensor.ndimension() メソッドの使用方法を示します。

import torch

# 1 次元テンソルを作成
x = torch.tensor([1, 2, 3])

# 次元数を取得
x_ndim = x.ndimension()

# 結果を出力
print(f"テンソルの次元数: {x_ndim}")

このコードを実行すると、以下の出力が得られます。

テンソルの次元数: 1

上記の例では、x は 1 次元テンソルなので、ndimension() メソッドは 1 を返します。

  • 関連するメソッドとして、torch.Tensor.size() メソッドがあります。このメソッドは、各次元のサイズを含むタプルを返します。
  • torch.Tensor.ndimension() メソッドは、torch.Tensor.ndim 属性と同じ値を返します。


import torch

# 2 次元テンソルを作成
x = torch.randn(3, 4)

# 次元数とサイズを取得
x_ndim = x.ndimension()
x_size = x.size()

# 結果を出力
print(f"テンソルの次元数: {x_ndim}")
print(f"テンソルのサイズ: {x_size}")
テンソルの次元数: 2
テンソルのサイズ: (3, 4)

この例では、x は 2 次元テンソルなので、ndimension() メソッドは 2 を返し、size() メソッドは (3, 4) というタプルを返します。タプル内の要素は、各次元のサイズを表します。

以下のコードは、torch.Tensor.ndimension() メソッドの様々な使用方法を示しています。

import torch

# 0 次元テンソルを作成
scalar = torch.tensor(5)

# 次元数を取得
scalar_ndim = scalar.ndimension()

# 結果を出力
print(f"スカラーの次元数: {scalar_ndim}")

# 3 次元テンソルを作成
y = torch.randn(5, 6, 7)

# 次元数を取得
y_ndim = y.ndimension()

# 結果を出力
print(f"テンソルの次元数: {y_ndim}")
スカラーの次元数: 0
テンソルの次元数: 3

この例では、scalar は 0 次元テンソル(スカラー)なので、ndimension() メソッドは 0 を返します。y は 3 次元テンソルなので、ndimension() メソッドは 3 を返します。



torch.Tensor.ndim 属性

torch.Tensor.ndim 属性は、torch.Tensor.ndimension() メソッドと同じ値を返す属性です。構文は以下の通りです。

tensor_ndim = tensor.ndim

この方法は、torch.Tensor.ndimension() メソッドよりも簡潔に記述できます。

len() 関数

len() 関数は、リストやタプルの長さを取得する関数ですが、Tensor オブジェクトにも適用できます。Tensor オブジェクトの場合、len() 関数は次元数を返します。構文は以下の通りです。

tensor_ndim = len(tensor)

この方法は、シンプルで分かりやすい方法ですが、テンソル以外のオブジェクトに対して使用するとエラーが発生する可能性があることに注意する必要があります。

属性検査

以下のコードのように、属性検査を使用して次元数を取得することもできます。

if hasattr(tensor, "ndim"):
    tensor_ndim = tensor.ndim
else:
    tensor_ndim = 0

この方法は、古いバージョンの PyTorch で動作させる必要がある場合に役立ちます。

方法説明長所短所
torch.Tensor.ndimension()メソッド形式で次元数を取得明確で分かりやすいやや冗長
torch.Tensor.ndim 属性属性形式で次元数を取得簡潔古いバージョンの PyTorch ではサポートされていない可能性がある
len() 関数関数形式で次元数を取得シンプルで分かりやすいテンソル以外のオブジェクトに対して使用するとエラーが発生する可能性がある
属性検査属性検査を使用して次元数を取得古いバージョンの PyTorch で動作可能やや冗長で分かりにくい