SINQ (SInkhorn Normalized Quantization) とは、Huaweiが2025年に発表した新しいオープンソース行列量子化アルゴリズムです。大規模言語モデル(LLM)の重みを低精度に変換しつつ、性能をほとんど落とさずにメモリ使用量を大幅に削減できるのが特徴です。

特徴 編集

  • デュアルスケーリング:行方向と列方向に別々のスケールを導入し、外れ値の影響を分散。
  • Sinkhorn型正規化:行・列を交互に正規化して分散を均一化し、量子化誤差を抑制。
  • 低ビット量子化:3〜4bitでも高精度を維持可能。
  • 効果:メモリ使用量を60〜70%削減し、安価なGPUや低性能ハードでも大規模モデルを動作可能に。

擬似コード 編集

人工知能以外でも使えそうなので調べてみた。

// SINQ量子化の擬似コード(C#風)

class SinqQuantizer
{
    public int BitWidth { get; set; } = 4;   // 量子化ビット幅(例: 3bit, 4bit)
    public int SinkhornIters { get; set; } = 20; // Sinkhorn反復回数
    public float Eps { get; set; } = 1e-6f;  // 数値安定用

    // 重み行列Wを量子化
    public QuantizedMatrix Quantize(float[,] W)
    {
        // 1. 行・列スケールをSinkhorn正規化で計算
        var (rowScale, colScale) = ComputeDualScaling(W);

        // 2. スケーリングを適用
        float[,] S = ApplyScaling(W, rowScale, colScale);

        // 3. ブロックごとに量子化(例: 行単位)
        var blocks = PartitionByRow(S);
        var qBlocks = new List<QuantizedBlock>();

        foreach (var block in blocks)
        {
            var stats = ComputeRange(block);       // min/maxやscaleを計算
            var qBlock = QuantizeBlock(block, stats, BitWidth);
            qBlocks.Add(qBlock);
        }

        return new QuantizedMatrix
        {
            RowScale = rowScale,
            ColScale = colScale,
            Blocks = qBlocks,
            BitWidth = BitWidth
        };
    }

    // Sinkhorn型の行列正規化
    private (float[] rowScale, float[] colScale) ComputeDualScaling(float[,] W)
    {
        int R = W.GetLength(0);
        int C = W.GetLength(1);
        float[] r = Enumerable.Repeat(1f, R).ToArray();
        float[] c = Enumerable.Repeat(1f, C).ToArray();

        for (int t = 0; t < SinkhornIters; t++)
        {
            // 行方向の正規化
            for (int i = 0; i < R; i++)
            {
                float variance = RowVariance(W, i, c);
                r[i] = 1.0f / (float)Math.Sqrt(variance + Eps);
            }

            // 列方向の正規化
            for (int j = 0; j < C; j++)
            {
                float variance = ColVariance(W, j, r);
                c[j] = 1.0f / (float)Math.Sqrt(variance + Eps);
            }
        }
        return (r, c);
    }

    // スケーリング適用
    private float[,] ApplyScaling(float[,] W, float[] r, float[] c)
    {
        int R = W.GetLength(0);
        int C = W.GetLength(1);
        float[,] S = new float[R, C];
        for (int i = 0; i < R; i++)
        for (int j = 0; j < C; j++)
            S[i, j] = r[i] * W[i, j] * c[j];
        return S;
    }

    // ブロック分割(例: 行単位)
    private IEnumerable<float[]> PartitionByRow(float[,] S)
    {
        int R = S.GetLength(0);
        int C = S.GetLength(1);
        for (int i = 0; i < R; i++)
        {
            float[] row = new float[C];
            for (int j = 0; j < C; j++) row[j] = S[i, j];
            yield return row;
        }
    }

    // 範囲計算(min/max)
    private (float min, float max, float scale) ComputeRange(float[] block)
    {
        float min = block.Min();
        float max = block.Max();
        float scale = (max - min) / ((1 << BitWidth) - 1);
        return (min, max, scale);
    }

    // ブロック量子化
    private QuantizedBlock QuantizeBlock(float[] block, (float min, float max, float scale) stats, int bits)
    {
        int n = block.Length;
        byte[] qVals = new byte[n];
        for (int i = 0; i < n; i++)
        {
            int q = (int)Math.Round((block[i] - stats.min) / stats.scale);
            qVals[i] = (byte)Math.Clamp(q, 0, (1 << bits) - 1);
        }
        return new QuantizedBlock { Values = qVals, Stats = stats };
    }

    // 行方向の分散
    private float RowVariance(float[,] W, int row, float[] colScale)
    {
        int C = W.GetLength(1);
        float mean = 0f;
        for (int j = 0; j < C; j++)
            mean += W[row, j] * colScale[j];
        mean /= C;

        float var = 0f;
        for (int j = 0; j < C; j++)
        {
            float val = W[row, j] * colScale[j] - mean;
            var += val * val;
        }
        return var / C;
    }

    // 列方向の分散
    private float ColVariance(float[,] W, int col, float[] rowScale)
    {
        int R = W.GetLength(0);
        float mean = 0f;
        for (int i = 0; i < R; i++)
            mean += rowScale[i] * W[i, col];
        mean /= R;

        float var = 0f;
        for (int i = 0; i < R; i++)
        {
            float val = rowScale[i] * W[i, col] - mean;
            var += val * val;
        }
        return var / R;
    }
}

// データ構造
class QuantizedMatrix
{
    public float[] RowScale;
    public float[] ColScale;
    public List<QuantizedBlock> Blocks;
    public int BitWidth;
}

class QuantizedBlock
{
    public byte[] Values;
    public (float min, float max, float scale) Stats;
}

外部リンク 編集