TensorFlowのtf.concatを徹底解説!多次元テンソルの結合方法

2025-05-31

tf.concat は、TensorFlowにおいて、指定した次元に沿って複数のテンソルを結合(連結)するための関数です。直訳すると「テンソルを連結する」という意味になります。

基本的な使い方

import tensorflow as tf

# 結合したいテンソルを定義
tensor1 = tf.constant([[1, 2], [3, 4]])
tensor2 = tf.constant([[5, 6], [7, 8]])

# 0番目の軸(行方向)に沿って結合
concatenated_tensor_row = tf.concat([tensor1, tensor2], axis=0)
print(concatenated_tensor_row)
# 出力: tf.Tensor(
# [[1 2]
#  [3 4]
#  [5 6]
#  [7 8]], shape=(4, 2), dtype=int32)

# 1番目の軸(列方向)に沿って結合
concatenated_tensor_col = tf.concat([tensor1, tensor2], axis=1)
print(concatenated_tensor_col)
# 出力: tf.Tensor(
# [[1 2 5 6]
#  [3 4 7 8]], shape=(2, 4), dtype=int32)

重要なポイント

  • 入力
    結合したいテンソルのリストを最初の引数として渡します。
  • axis パラメータ
    どの次元に沿って結合するかを指定します。
    • axis=0: 最初の次元(通常は行またはバッチの次元)に沿って結合します。
    • axis=1: 二番目の次元(通常は列または特徴量の次元)に沿って結合します。
    • 多次元テンソルでは、axis=n は n+1番目の次元を指します。負の値を指定することもでき、その場合は後ろからのインデックスになります(例: axis=-1 は最後の次元)。
  • 結合するテンソルの形状
    結合する軸以外の次元のサイズは、すべてのテンソルで一致している必要があります。例えば、上記の例では、tensor1tensor2 はどちらも (2, 2) の形状を持っています。行方向に結合する場合 (axis=0)、列のサイズ (2) が一致している必要があります。列方向に結合する場合 (axis=1)、行のサイズ (2) が一致している必要があります。

どのような場面で使うか?

tf.concat は、以下のような様々な場面で役立ちます。

  • シーケンスデータの結合
    可変長のシーケンスデータを処理する際に、パディングされたデータを元のシーケンス長に戻すなどの処理を行う場合。
  • 特徴量の結合
    複数の異なる特徴量セットを一つのテンソルに結合して、モデルへの入力とする場合。


tf.concat を使う上で最も一般的なエラーは、結合しようとするテンソルの形状(shape)が非互換であることです。

よくあるエラー

    • 原因
      このエラーは、結合しようとしているテンソルの指定した軸以外の次元のサイズが一致していない場合に発生します。例えば、行方向に結合する場合(axis=0)、すべてのテンソルの列の数が同じでなければなりません。同様に、列方向に結合する場合(axis=1)、すべてのテンソルの行の数が同じである必要があります。
    • トラブルシューティング
      • 結合しようとしているすべてのテンソルの形状を tf.shape() などで確認し、結合する軸以外の次元のサイズがすべて一致しているかを確認してください。
      • 必要に応じて、tf.reshape()tf.expand_dims() などを使用して、テンソルの形状を調整してください。ただし、形状変更はデータの意味が変わらない範囲で行う必要があります。
  1. InvalidArgumentError: ConcatOp : Expected concatenating dimensions in the range [0, ...) but got ...

    • 原因
      axis パラメータに指定した値が、テンソルの次元数に対して無効な範囲にある場合に発生します。例えば、2次元のテンソルに対して axis=2 を指定したり、負のインデックスがテンソルの次元数を超えていたりする場合です。
    • トラブルシューティング
      • axis パラメータの値が、結合したいテンソルの次元数に対して正しい範囲内にあるか確認してください。
      • 負のインデックスを使用している場合は、それが意図した次元を指しているか確認してください(例:2次元テンソルでは、axis=-1 は最後の次元(1番目の軸)、axis=-2 は最初の次元(0番目の軸)です)。
  2. 結合するテンソルのデータ型(dtype)が異なる場合(警告またはエラー)

    • 原因
      tf.concat は異なるデータ型のテンソルを結合できますが、通常は結果のテンソルは入力テンソルのデータ型を保持するか、より一般的なデータ型にキャストされます。意図しないデータ型の変換が起こり、後の計算で問題が発生する可能性があります。古いバージョンのTensorFlowではエラーが発生する場合もありました。
    • トラブルシューティング
      • 結合する前に、すべてのテンソルのデータ型を tensor.dtype で確認してください。
      • 必要に応じて、tf.cast() を使用して、すべてのテンソルのデータ型を明示的に揃えてから結合してください。
  3. 結合するテンソルのリストが空の場合

    • 原因
      tf.concat の最初の引数として空のリスト [] を渡すと、予期しない動作やエラーが発生する可能性があります。
    • トラブルシューティング
      • 結合するテンソルが実際に存在するかどうかを確認してください。
      • 条件によっては結合するテンソルがない可能性がある場合は、その場合の処理を適切に記述してください(例えば、空のテンソルを生成するなど)。

トラブルシューティングの一般的なヒント

  • TensorFlowのバージョンを確認する
    TensorFlowのバージョンによって、挙動やエラーメッセージが異なる場合があります。使用しているバージョンを把握しておくと、情報検索や問題解決に役立ちます。
  • 小さな例で試す
    複雑な処理でエラーが発生した場合は、問題を切り分けるために、小さなテンソルで tf.concat の動作を試してみるのが有効です。
  • 形状とデータ型を常に意識する
    テンソルの形状とデータ型は、TensorFlowプログラミングにおいて非常に重要です。tf.shape()tensor.dtype を活用して、意図した通りになっているか確認しましょう。
  • エラーメッセージをよく読む
    TensorFlowのエラーメッセージは、問題の原因や場所に関する貴重な情報を提供してくれます。


基本的な結合の例

  1. 行方向への結合 (axis=0)

    import tensorflow as tf
    
    # 2x2 のテンソルを2つ作成
    tensor1 = tf.constant([[1, 2], [3, 4]])
    tensor2 = tf.constant([[5, 6], [7, 8]])
    
    # 行方向(最初の軸)に沿って結合
    concatenated_row = tf.concat([tensor1, tensor2], axis=0)
    print(f"行方向に結合:\n{concatenated_row}")
    # 出力:
    # 行方向に結合:
    # [[1 2]
    #  [3 4]
    #  [5 6]
    #  [7 8]]
    # 形状: (4, 2)
    

    この例では、2つの同じ列数を持つテンソルを行方向に連結しています。結果として、行数が増えた新しいテンソルが得られます。

  2. import tensorflow as tf
    
    # 2x2 のテンソルを2つ作成
    tensor1 = tf.constant([[1, 2], [3, 4]])
    tensor2 = tf.constant([[5, 6], [7, 8]])
    
    # 列方向(2番目の軸)に沿って結合
    concatenated_col = tf.concat([tensor1, tensor2], axis=1)
    print(f"列方向に結合:\n{concatenated_col}")
    # 出力:
    # 列方向に結合:
    # [[1 2 5 6]
    #  [3 4 7 8]]
    # 形状: (2, 4)
    

異なる形状のテンソルを結合する例 (形状が一致する必要がある次元に注意)

import tensorflow as tf

# 形状が (2, 3) と (2, 2) のテンソル
tensor3 = tf.constant([[10, 20, 30], [40, 50, 60]])
tensor4 = tf.constant([[70, 80], [90, 100]])

# 列方向 (axis=1) に結合 (行数は一致している必要がある)
concatenated_diff_col = tf.concat([tensor3, tensor4], axis=1)
print(f"異なる列数で列方向に結合:\n{concatenated_diff_col}")
# 出力:
# 異なる列数で列方向に結合:
# [[ 10  20  30  70  80]
#  [ 40  50  60  90 100]]
# 形状: (2, 5)

# 行方向 (axis=0) に結合 (列数は一致している必要がある)
tensor5 = tf.constant([[1, 2, 3]])
tensor6 = tf.constant([[4, 5, 6], [7, 8, 9]])

# これはエラーになります (列数が一致しないため)
# try:
#     concatenated_diff_row_error = tf.concat([tensor5, tensor6], axis=0)
#     print(concatenated_diff_row_error)
# except tf.errors.InvalidArgumentError as e:
#     print(f"エラー: {e}")
# 出力例:
# エラー: {{function_node __wrapped_ConcatV2_device_/job:localhost/replica:0/task:0/device:CPU:0}} ConcatOp : Dimensions of inputs should match: shape [1,3] vs. shape [2,3] at index 0 [Op:ConcatV2]

この例では、列方向への結合では行数が一致していれば異なる列数のテンソルを結合できることを示しています。一方、行方向への結合では列数が一致していないとエラーが発生します。

3次元以上のテンソルを結合する例

import tensorflow as tf

# 形状 (2, 3, 4) のテンソルを2つ作成
tensor7 = tf.ones((2, 3, 4))
tensor8 = tf.zeros((2, 3, 4))

# 最初の軸 (axis=0) に沿って結合
concatenated_axis0 = tf.concat([tensor7, tensor8], axis=0)
print(f"axis=0 で結合 (形状: {concatenated_axis0.shape}):\n{concatenated_axis0[0, 0, :]}\n{concatenated_axis0[2, 0, :]}")
# 出力例:
# axis=0 で結合 (形状: (4, 3, 4)):
# [1. 1. 1. 1.]
# [0. 0. 0. 0.]

# 2番目の軸 (axis=1) に沿って結合
concatenated_axis1 = tf.concat([tensor7, tensor8], axis=1)
print(f"axis=1 で結合 (形状: {concatenated_axis1.shape}):\n{concatenated_axis1[0, :, 0]}")
# 出力例:
# axis=1 で結合 (形状: (2, 6, 4)):
# [1. 1. 1. 0. 0. 0.]

# 3番目の軸 (axis=2) に沿って結合
concatenated_axis2 = tf.concat([tensor7, tensor8], axis=2)
print(f"axis=2 で結合 (形状: {concatenated_axis2.shape}):\n{concatenated_axis2[0, 0, :]}")
# 出力例:
# axis=2 で結合 (形状: (2, 3, 8)):
# [1. 1. 1. 1. 0. 0. 0. 0.]

3次元以上のテンソルでも、axis パラメータでどの次元に沿って結合するかを指定できます。結合する軸以外の次元のサイズは一致している必要があります。

tf.concat の応用例

  • 特徴量の結合
    画像データの特徴量とテキストデータの特徴量を結合して、マルチモーダルな入力を作成する場合などに使用できます。
  • シーケンスデータの結合
    可変長のシーケンスデータを処理する際に、パディングされた部分を取り除いた後に元のシーケンスを再構成する場合などに使用できます。
  • バッチ処理されたデータの結合
    異なるミニバッチの結果を結合して、完全な出力を得る場合などに使用できます。


tf.stack

  • 機能
    tf.stack は、指定した軸に沿って新しい次元を作成し、複数のテンソルを積み重ねます。結合とは異なり、次元数が増える点が特徴です。

リストへの追加と tf.stack / tf.concat の組み合わせ

  • 使いどころ
    動的に生成されるテンソルを効率的に結合する場合に便利です。

    import tensorflow as tf
    
    tensors_to_concat = []
    for i in range(3):
        tensor = tf.ones((2, 2)) * i
        tensors_to_concat.append(tensor)
    
    concatenated_dynamic = tf.concat(tensors_to_concat, axis=0)
    print(f"動的に生成したテンソルを tf.concat で結合:\n{concatenated_dynamic}")
    # 出力:
    # 動的に生成したテンソルを tf.concat で結合:
    # [[0. 0.]
    #  [0. 0.]
    #  [1. 1.]
    #  [1. 1.]
    #  [2. 2.]
    #  [2. 2.]]
    # 形状: (6, 2)
    
    stacked_dynamic = tf.stack(tensors_to_concat, axis=0)
    print(f"動的に生成したテンソルを tf.stack で結合:\n{stacked_dynamic}")
    # 出力:
    # 動的に生成したテンソルを tf.stack で結合:
    # [[[0. 0.]
    #   [0. 0.]]
    #
    #  [[1. 1.]
    #   [1. 1.]]
    #
    #  [[2. 2.]
    #   [2. 2.]]]
    # 形状: (3, 2, 2)
    
  • 機能
    複数のテンソルをループなどで生成する場合、一旦Pythonのリストに格納し、最後に tf.stack または tf.concat を用いて結合する方法です。

tf.reshape との組み合わせ

  • 使いどころ
    特に1次元テンソルを結合して特定の形状のテンソルを作成したい場合などに有効です。ただし、データの順序には注意が必要です。

    import tensorflow as tf
    
    tensor_flat1 = tf.constant([1, 2, 3])
    tensor_flat2 = tf.constant([4, 5, 6])
    
    # 結合してからreshape
    concatenated_reshaped = tf.reshape(tf.concat([tensor_flat1, tensor_flat2], axis=0), (2, 3))
    print(f"結合後に reshape:\n{concatenated_reshaped}")
    # 出力:
    # 結合後に reshape:
    # [[1 2 3]
    #  [4 5 6]]
    # 形状: (2, 3)
    
  • 機能
    形状を変更してから結合することで、間接的に tf.concat と同様の効果を得られる場合があります。

tf.pad とのスライス

  • 使いどころ
    可変長のデータをバッチ処理する際など。
  • 機能
    長さの異なるシーケンスデータを扱う場合、tf.pad で長さを揃えてから tf.stack することがあります。その後、必要に応じてスライス操作で元の長さに戻すなどの処理を行います。これは直接的な代替ではありませんが、可変長シーケンスの処理でよく用いられるパターンです。

特定のレイヤーの出力結合 (Keras API)

  • 使いどころ
    ニューラルネットワークのモデル構築において、複数のパスからの情報を統合する場合など。

    import tensorflow as tf
    from tensorflow.keras import layers
    
    input1 = tf.keras.Input(shape=(32,))
    dense1 = layers.Dense(8, activation='relu')(input1)
    input2 = tf.keras.Input(shape=(64,))
    dense2 = layers.Dense(8, activation='relu')(input2)
    
    concatenated = layers.Concatenate()([dense1, dense2])
    output = layers.Dense(1)(concatenated)
    model = tf.keras.Model(inputs=[input1, input2], outputs=output)
    print(model.summary())
    
  • 機能
    Keras APIを使用している場合、tf.keras.layers.Concatenate レイヤーを使って、複数のレイヤーの出力を結合できます。