BLAS.gemm!()だけじゃない!Juliaで行列積を計算する代替手法を徹底解説

2025-05-26

名前に付いている!は、Juliaの慣習で、関数が引数を変更する(破壊的な操作を行う)ことを示します。

gemm!() の機能

gemm は "General Matrix-Matrix Multiplication" の略です。一般的に、以下の形式の計算を実行します。

C:=α⋅op(A)⋅op(B)+β⋅C

ここで:

  • op(X):行列Xに対して転置(transpose)操作を行うかどうかを指定します。転置しない場合はXそのもの、転置する場合はXT(または複素共役転置XH)になります。
  • α, β:スカラー(数値)
  • A, B:入力行列
  • C:更新される行列(出力行列)

gemm!() の引数

基本的なgemm!()のシグネチャは以下のようになります。

gemm!(tA, tB, alpha, A, B, beta, C)

それぞれの引数の意味は以下の通りです。

  • C:出力行列。計算結果がこの行列に上書きされます。そのため、事前に適切なサイズで確保されている必要があります。

  • beta:スカラー値 β。既存の行列Cに乗算されます。Float64, Float32, Complex128, Complex64などの数値型を指定します。

  • B:入力行列B。

  • A:入力行列A。

  • alpha:スカラー値 α。行列op(A)⋅op(B)に乗算されます。Float64, Float32, Complex128, Complex64などの数値型を指定します。

  • tB:行列Bに対する操作を指定する文字。tAと同様。

  • tA:行列Aに対する操作を指定する文字。

    • 'N' (または 'n'):Aをそのまま使用(No transpose)
    • 'T' (または 't'):Aの転置を使用(Transpose)
    • 'C' (または 'c'):Aの共役転置を使用(Conjugate transpose、主に複素数行列用)

gemm!() を使うメリット

  1. パフォーマンス: BLASは、CPUの特性に最適化された低レベルな線形代数ルーチンを提供します。gemm!()を使用することで、Juliaの通常の行列乗算(例: C = A * B)よりも高速な計算が期待できます。特に大規模な行列の計算で顕著な差が出ます。
  2. メモリ効率: !が示す通り、結果を既存の行列Cに上書きするため、新しいメモリを割り当てる必要がありません。ループ内で繰り返し行列計算を行う場合など、頻繁なメモリ割り当てと解放(ガベージコレクション)によるオーバーヘッドを避けることができ、メモリ使用量と実行時間の両面で効率的です。
using LinearAlgebra

# 3x2行列 A を作成
A = [1.0 2.0; 3.0 4.0; 5.0 6.0]

# 2x3行列 B を作成
B = [7.0 8.0 9.0; 10.0 11.0 12.0]

# 3x3行列 C を作成(計算結果が格納される)
C = zeros(3, 3)

# C = 1.0 * A * B + 0.0 * C を計算(つまり C = A * B)
BLAS.gemm!('N', 'N', 1.0, A, B, 0.0, C)

println("A:")
println(A)
println("\nB:")
println(B)
println("\nC (A * B の結果):")
println(C)

# 別の例: C = 2.0 * A' * B + 1.0 * C (既存の C に加算)
# A' は 2x3行列
A_prime = A' # Aの転置
# B は 2x3行列
# C は 3x3行列だが、ここでは A' * B の結果は 3x3 にならないため、新しいCを用意
C_new = zeros(3,3)

# C_new = 2.0 * A' * B + 0.0 * C_new
BLAS.gemm!('T', 'N', 2.0, A, B, 0.0, C_new) # Aの転置を使うので 'T'

println("\nAの転置 (A'):")
println(A_prime)
println("\nC_new (2.0 * A' * B の結果):")
println(C_new)

この例では、'N'は「転置なし」、'T'は「転置あり」を意味します。alphabetaに0.0を指定することで、対応する項を無視できます(例: beta=0.0は既存のCの値を無視し、alpha * op(A) * op(B)の結果をCに直接書き込む)。



ディメンションの不一致 (Dimension Mismatch)

これは最も一般的なエラーです。行列の乗算には特定のディメンションのルールがあります。

  • 出力行列 C のサイズは、 op(A)⋅op(B) の結果行列のサイズと完全に一致する必要があります。
  • op(A)⋅op(B) の結果行列のサイズは、 op(A) の行数と op(B) の列数になります。
  • op(A) の列数と op(B) の行数は一致する必要があります。

エラーメッセージの例
DimensionMismatch("matrix A has dimensions (3,2) but matrix B has dimensions (4,3)") (これは直接gemm!()のエラーではないかもしれませんが、原因は同じです) ArgumentError: matrix C has dimensions (3,3) but expected (3,2)

トラブルシューティング

  1. 転置オプション (tA, tB) を確認する

    • tA = 'N' (転置なし): Aのサイズがそのまま使われます。
    • tA = 'T' (転置あり): Aのサイズが転置されて使われます(例: (m, n)(n, m)になる)。
    • tA = 'C' (共役転置あり): 複素数行列の場合に共役転置が適用されます。 これらのオプションを考慮した上で、行列ABの有効なディメンションを確認してください。

    例: A(m, k)B(k, n) の場合、gemm!('N', 'N', alpha, A, B, beta, C)C(m, n)になります。 もしA(k, m)tA='T'B(k, n)tB='N'の場合、A(m, k)として扱われ、B(k, n)として扱われるので、C(m, n)になります。

  2. C のサイズを事前に確認する
    Cは結果を格納するための行列であり、適切なサイズで事前に確保されている必要があります。zeros(rows, cols)similar(another_matrix)などで正確なサイズの行列を準備してください。

データ型 (Data Types) の不一致

BLAS.gemm!()は通常、Float64Float32ComplexF64ComplexF32などの標準的な浮動小数点型または複素数型に特化して最適化されています。異なる数値型(例: IntBigFloatなど)を使用しようとすると、MethodErrorや非効率なフォールバックが発生する可能性があります。

エラーメッセージの例
MethodError: no method matching gemm!(::Char, ::Char, ::Int64, ::Array{Int64,2}, ::Array{Int64,2}, ::Int64, ::Array{Int64,2})

トラブルシューティング

  1. 引数の型を確認する
    alpha, beta、および行列A, B, Cの要素型がBLASがサポートする型であることを確認してください。整数を渡したい場合でも、1.0のように浮動小数点数として指定します。

    # 悪い例 (Int型)
    A = [1 2; 3 4]
    B = [5 6; 7 8]
    C = zeros(Int, 2, 2) # CもInt型
    # BLAS.gemm!('N', 'N', 1, A, B, 0, C) # MethodError
    
    # 良い例 (Float64型)
    A = [1.0 2.0; 3.0 4.0]
    B = [5.0 6.0; 7.0 8.0]
    C = zeros(2, 2) # デフォルトでFloat64
    BLAS.gemm!('N', 'N', 1.0, A, B, 0.0, C)
    

メモリの連続性 (Memory Contiguity)

BLASルーチンは、メモリ上で連続的に配置されたデータ(Dense Array)に対して最適化されています。@viewなどを使用して、メモリが連続していない配列の一部(Strided Array)をgemm!()に渡そうとすると、エラーになったり、パフォーマンスが低下したりする可能性があります。

エラーメッセージの例
ERROR: matrix does not have contiguous columns

トラブルシューティング

  1. @view の使用に注意する
    BLAS.gemm!()に直接@viewを渡すのは避けるべきです。特に、列が連続していないようなビュー(例: A[1:2:end, :]のように行方向に飛び飛びでアクセスするビュー)はBLASにとって問題となることがあります。

  2. mul! を検討する
    JuliaのLinearAlgebra.mul!関数は、より高レベルなインターフェースを提供し、内部で適切にBLASを呼び出すか、あるいは非連続メモリに対しても効率的に処理を行います。多くの場合、BLAS.gemm!()を直接呼び出す代わりにmul!(C, A, B, alpha, beta)を使用する方が安全で推奨されます。mul!は必要に応じて一時的なメモリを確保してBLASが扱える形式に変換してくれます。

    using LinearAlgebra
    
    A = rand(10, 10)
    B = rand(10, 10)
    C = zeros(10, 10)
    
    # 悪い例 (gemm!はViewを直接扱えない場合がある)
    # BLAS.gemm!('N', 'N', 1.0, @view(A[1:2:end, :]), B, 0.0, C[1:5, :]) # エラーや非効率の原因に
    
    # 良い例 (mul!はViewを適切に処理する)
    mul!(C, A, B) # C = A * B と同じ
    mul!(C, A', B, 2.0, 1.0) # C = 2.0 * A' * B + 1.0 * C
    

稀に、Juliaが使用しているBLASライブラリ自体に問題がある場合があります(例: OpenBLASのバージョン問題やMKLなどのカスタムBLASライブラリの設定ミス)。

エラーメッセージの例
could not find function :zgemm_64_ in library libopenblas64_ undefined reference to 'cblas_ddot'

トラブルシューティング

  1. Juliaの再インストールまたはBLASの再構築
    通常のユーザーであれば、Juliaのインストール時にバンドルされているOpenBLASを使用しているはずです。もしJuliaのインストール自体に問題があったり、手動でBLASライブラリを入れ替えたりしている場合は、再インストールや設定の見直しが必要になることがあります。
  2. Juliaのバージョンアップ
    特定のBLASに関するバグは、新しいJuliaのバージョンで修正されている場合があります。
  3. Intel MKLなどのカスタムBLASを使用している場合
    using MKLを適切に呼び出しているか、環境変数が正しく設定されているかを確認します。

LinearAlgebra.BLAS.gemm!()は非常に強力で高性能なツールですが、その低レベルな性質上、引数の型、次元、メモリレイアウトに厳密な要件があります。 多くの場合、直接BLAS.gemm!()を呼び出す代わりに、JuliaのLinearAlgebraモジュールが提供する高レベルな関数(特にmul!)を使用することを強くお勧めします。mul!は内部で最適なBLASルーチンを選択し、エラーを自動的に処理したり、より柔軟な引数型をサポートしたりするため、開発者は低レベルな詳細を気にせずに済みます。



基本的な行列積の計算 (C=A⋅B)

最も基本的な使用例です。α=1.0, β=0.0 と設定することで、C の初期値に関わらず C=A⋅B が計算されます。

using LinearAlgebra

println("--- 例1: 基本的な行列積 (C = A * B) ---")

# 2x3行列 A
A = [1.0 2.0 3.0;
     4.0 5.0 6.0]

# 3x2行列 B
B = [7.0 8.0;
     9.0 10.0;
     11.0 12.0]

# 結果を格納する2x2行列 C を初期化
# BLAS.gemm! は結果を C に上書きするため、C は適切なサイズで事前に用意されている必要があります。
C = zeros(2, 2)

println("行列 A:\n", A)
println("行列 B:\n", B)
println("初期の行列 C:\n", C)

# C = 1.0 * A * B + 0.0 * C を実行
# 'N': 行列Aを転置しない (No transpose)
# 'N': 行列Bを転置しない (No transpose)
# 1.0: スカラー alpha
# A: 行列A
# B: 行列B
# 0.0: スカラー beta
# C: 結果を格納する行列C
BLAS.gemm!('N', 'N', 1.0, A, B, 0.0, C)

println("\n計算後の行列 C (A * B の結果):\n", C)

# 検証(通常の行列積と比較)
C_expected = A * B
println("\n検証 (A * B と直接計算):\n", C_expected)
@assert C ≈ C_expected
println("結果が一致しました。\n")

転置を使った行列積の計算 (C=AT⋅B)

tA または tB の引数を 'T' に設定することで、行列の転置を適用できます。

using LinearAlgebra

println("--- 例2: 転置を含む行列積 (C = A' * B) ---")

# 3x2行列 A
A = [1.0 2.0;
     3.0 4.0;
     5.0 6.0]

# 3x4行列 B
B = [1.0 2.0 3.0 4.0;
     5.0 6.0 7.0 8.0;
     9.0 10.0 11.0 12.0]

# A' は 2x3行列
# A' * B は 2x4行列になるため、C を 2x4 で初期化
C = zeros(2, 4)

println("行列 A:\n", A)
println("行列 B:\n", B)
println("初期の行列 C:\n", C)

# C = 1.0 * A' * B + 0.0 * C を実行
# 'T': 行列Aを転置する (Transpose)
# 'N': 行列Bを転置しない (No transpose)
BLAS.gemm!('T', 'N', 1.0, A, B, 0.0, C)

println("\n計算後の行列 C (A' * B の結果):\n", C)

# 検証
C_expected = A' * B
println("\n検証 (A' * B と直接計算):\n", C_expected)
@assert C ≈ C_expected
println("結果が一致しました。\n")

スカラー乗算と加算を含む計算 (C:=α⋅A⋅B+β⋅C)

alphabeta の値を 1.00.0 以外に設定することで、より一般的な線形変換を表現できます。

using LinearAlgebra

println("--- 例3: スカラー乗算と既存行列への加算 (C := alpha * A * B + beta * C) ---")

# 2x2行列 A
A = [1.0 2.0;
     3.0 4.0]

# 2x2行列 B
B = [5.0 6.0;
     7.0 8.0]

# 2x2行列 C (初期値を持つ)
C = [10.0 20.0;
     30.0 40.0]

alpha = 2.0
beta  = 0.5

println("行列 A:\n", A)
println("行列 B:\n", B)
println("初期の行列 C:\n", C)
println("alpha = ", alpha)
println("beta  = ", beta)

# C := alpha * A * B + beta * C を実行
BLAS.gemm!('N', 'N', alpha, A, B, beta, C)

println("\n計算後の行列 C (alpha * A * B + beta * C の結果):\n", C)

# 検証
C_expected = alpha * (A * B) + beta * [10.0 20.0; 30.0 40.0] # 初期Cを明示的に使う
println("\n検証 (手動計算):\n", C_expected)
@assert C ≈ C_expected
println("結果が一致しました。\n")

複素数行列の計算

BLAS.gemm!() は複素数行列もサポートします。'C' オプションは共役転置 (conjugate transpose) を意味します。

using LinearAlgebra

println("--- 例4: 複素数行列の計算 (C = A * B') ---")

# 2x2複素数行列 A
A = [1.0 + 1.0im  2.0 - 1.0im;
     3.0 + 2.0im  4.0 - 2.0im]

# 2x2複素数行列 B
B = [5.0 + 3.0im  6.0 - 3.0im;
     7.0 + 4.0im  8.0 - 4.0im]

# 結果を格納する2x2複素数行列 C
C = zeros(ComplexF64, 2, 2)

println("行列 A:\n", A)
println("行列 B:\n", B)
println("初期の行列 C:\n", C)

# C = 1.0 * A * B' + 0.0 * C を実行
# 'N': Aは転置しない
# 'C': Bは共役転置する
BLAS.gemm!('N', 'C', 1.0 + 0.0im, A, B, 0.0 + 0.0im, C) # スカラーも複素数型に合わせる

println("\n計算後の行列 C (A * B' の結果):\n", C)

# 検証
C_expected = A * B'
println("\n検証 (A * B' と直接計算):\n", C_expected)
@assert C ≈ C_expected
println("結果が一致しました。\n")

通常、Juliaで線形代数演算を行う場合は、LinearAlgebra モジュールが提供する高レベルな関数 (mul!, *, factorize など) を使用することが推奨されます。これらの関数は、内部で最適なBLASルーチンを呼び出し、引数の型チェックやメモリ管理などを自動的に行ってくれます。

mul!()BLAS.gemm!() の高レベルなラッパーであり、より柔軟で使いやすいです。

using LinearAlgebra

println("--- 例5: BLAS.gemm!() と mul!() の比較 ---")

A = rand(5, 3)
B = rand(3, 4)
C_gemm = zeros(5, 4)
C_mul = zeros(5, 4)

println("行列 A のサイズ: ", size(A))
println("行列 B のサイズ: ", size(B))
println("行列 C のサイズ: ", size(C_gemm))

# BLAS.gemm!() を使用
BLAS.gemm!('N', 'N', 1.0, A, B, 0.0, C_gemm)
println("\nBLAS.gemm! による結果の最初の行:\n", C_gemm[1,:])

# mul!() を使用
# mul!(C, A, B) は C = A * B を in-place で計算
mul!(C_mul, A, B)
println("\nmul! による結果の最初の行:\n", C_mul[1,:])

@assert C_gemm ≈ C_mul
println("BLAS.gemm! と mul! の結果が一致しました。\n")

# mul! は alpha, beta 引数も取れる
C_mul_weighted = zeros(5, 4)
mul!(C_mul_weighted, A, B, 2.0, 1.0) # C_mul_weighted = 2.0 * A * B + 1.0 * C_mul_weighted
println("\nmul! (重み付き) による結果の最初の行:\n", C_mul_weighted[1,:])

# BLAS.gemm! で同じ計算
C_gemm_weighted = zeros(5, 4)
BLAS.gemm!('N', 'N', 2.0, A, B, 1.0, C_gemm_weighted)
println("\nBLAS.gemm! (重み付き) による結果の最初の行:\n", C_gemm_weighted[1,:])

@assert C_mul_weighted ≈ C_gemm_weighted
println("mul! (重み付き) と BLAS.gemm! (重み付き) の結果が一致しました。\n")

LinearAlgebra.BLAS.gemm!() は、特にパフォーマンスが重視される線形代数計算において、低レベルな制御を可能にする強力なツールです。しかし、引数の型、次元、転置オプションの指定には細心の注意が必要です。



* 演算子 (通常の行列乗算)

最も一般的で直感的な方法です。A * B と書くだけで行列積が計算され、新しい行列が生成されます。

using LinearAlgebra

A = rand(3, 2)
B = rand(2, 4)

# C = A * B を計算
C = A * B

println("--- 1. * 演算子 ---")
println("A:\n", A)
println("B:\n", B)
println("C = A * B:\n", C)

# 転置も簡単
C_transposed = A' * B
println("C = A' * B:\n", C_transposed)

特徴

  • 内部最適化
    Juliaの*演算子は、内部でBLASのgemmルーチンを呼び出すなど、可能な限り最適化された方法で計算を実行します。
  • 新しいメモリ割り当て
    計算結果を格納するために新しいメモリが割り当てられます。ループ内で頻繁に呼び出す場合、ガベージコレクションのオーバーヘッドが生じる可能性があります。
  • シンプルさ
    最も読みやすく、書きやすい。

mul!() 関数 (in-place 行列乗算)

BLAS.gemm!() と同様に、結果を既存の行列に上書きする(in-place)形式の行列乗算です。gemm!よりも高レベルなインターフェースを提供し、より汎用的に設計されています。

using LinearAlgebra

A = rand(3, 2)
B = rand(2, 4)
C = zeros(3, 4) # 結果を格納する行列を事前に用意

println("\n--- 2. mul!() 関数 ---")
println("初期の C:\n", C)

# C = A * B を計算(in-place)
mul!(C, A, B)
println("mul!(C, A, B) 後:\n", C)

# スカラー乗算と加算もサポート
# C := alpha * A * B + beta * C
alpha = 2.0
beta = 0.5
C_initial = deepcopy(C) # 既存のCの値を残すためにコピー
mul!(C, A, B, alpha, beta)
println("\nmul!(C, A, B, ", alpha, ", ", beta, ") 後 (初期C:\n", C_initial, "):\n", C)

# 転置も直接指定可能
C_transposed_mul = zeros(2, 4)
mul!(C_transposed_mul, A', B) # A' は A の転置
println("\nmul!(C, A', B) 後:\n", C_transposed_mul)

特徴

  • 推奨される代替手段
    BLAS.gemm!()を直接使用する代わりに、ほとんどの場面でmul!()が推奨されます。
  • 自動最適化
    内部で最適なBLAS/LAPACKルーチンが選択されます。gemm!が適用できないような特殊なケース(例: ストライド配列など)でも適切に処理されます。
  • 高レベルなインターフェース
    BLAS.gemm!()よりも引数の指定が直感的で、エラーハンドリングも強化されています。例えば、転置はA'のように自然な構文で指定できます。
  • in-place計算
    新しいメモリ割り当てを避けるため、メモリ効率が良いです。特に大きな行列や、ループ内で繰り返し計算を行う場合に有利です。

Array のブロードキャスト演算 (Broadcasting)

行列積そのものではありませんが、要素ごとの積(Hadamard積)や、スカラーとの積、加算などはブロードキャストを使って効率的に行えます。これは gemm!() が提供する C:=α⋅op(A)⋅op(B)+β⋅C の後半部分 (β⋅C) に近い操作です。

using LinearAlgebra

C = rand(3, 3)
beta = 0.5

println("\n--- 3. Array のブロードキャスト演算 ---")
println("初期の C:\n", C)

# Cの各要素にbetaを乗算(in-place)
C .*= beta
println("C .*= beta 後:\n", C)

# 別途計算した行列 D を加算(in-place)
D = rand(3, 3) * 10
C .+= D
println("C .+= D 後 (D:\n", D, "):\n", C)

特徴

  • 効率的
    Juliaのブロードキャストは非常に最適化されており、ループと同等かそれ以上のパフォーマンスを発揮することが多いです。
  • 非常に柔軟
    複雑な要素ごとの計算を簡潔に記述できます。
  • 要素ごとの操作
    行列積とは異なる操作(要素ごとの演算)ですが、線形代数演算の一部として利用できます。

高度な線形代数パッケージ

Juliaのエコシステムには、より専門的な線形代数計算のためのパッケージも存在します。

  • GPUArrays.jl / CUDA.jl / AMDGPU.jl
    GPU上で線形代数計算を行うためのパッケージです。GPUBLAS (cuBLASなど) を利用して、非常に高速な行列積計算を実現します。
  • SuiteSparse.jl
    疎行列の様々なアルゴリズムを提供します。
  • SparseArrays.jl
    疎行列(多くの要素がゼロの行列)の計算に特化しています。疎行列の積は、通常のBLASのgemmとは異なるアルゴリズムで最適化されます。

例 (概念的)

using SparseArrays

println("\n--- 4. 高度な線形代数パッケージ (SparseArrays の例) ---")

S1 = sprand(100, 100, 0.01) # 疎行列を生成 (1%の要素が非ゼロ)
S2 = sprand(100, 100, 0.01)

# 疎行列の積は自動的に最適化される
S_result = S1 * S2
println("疎行列 S1 * S2 の非ゼロ要素数: ", nnz(S_result))
  • パフォーマンス
    適切な状況で使えば、通常のBLASよりも大幅に高速な計算が可能です。
  • 専門性
    特定の種類の行列(疎行列、特殊な構造を持つ行列など)や、特定のハードウェア(GPU)に最適化されています。
  • 特殊な行列(疎行列など)やGPU計算
    それぞれの目的に特化したパッケージを使用することを検討します。
  • 要素ごとの演算が必要な場合
    ブロードキャスト (.=.+) を使用します。
  • ほとんどの場合
    * 演算子(新しい行列を返す)または mul!()(in-placeで結果を上書きする)を使用します。
    • 計算結果を一時的にしか使わない場合や、新しいメモリ割り当てが問題にならない場合は * が便利です。
    • 大規模な行列を扱う場合、またはループ内で繰り返し行列積を計算する場合は、メモリ効率とパフォーマンスの観点から mul!() が強く推奨されます。