diff --git a/internal/fq.lua b/internal/fq.lua index b0ca25c..c99f977 100644 --- a/internal/fq.lua +++ b/internal/fq.lua @@ -10,6 +10,7 @@ -- @module[kind=internal] internal.fq -- +local mp = require "ccrytpolib.internal.mp" local util = require "ccryptolib.internal.util" local unpack = unpack or table.unpack @@ -59,63 +60,6 @@ local T1 = { 00000283, } ---- Carries a number in base 2²⁴. --- --- @tparam {number...} a A number 0 <= a < 2 ^ (24 × (#a + 1)) as limbs in --- [-2⁵²..2⁵²]. --- @treturn {number...} a as #a + 1 limbs in [0..2²⁴). --- -local function carry(a) - local c = {unpack(a)} - c[#c + 1] = 0 - for i = 1, #c - 1 do - local val = c[i] - local rem = val % 2 ^ 24 - local quot = (val - rem) / 2 ^ 24 - c[i + 1] = c[i + 1] + quot - c[i] = rem - end - return c -end - ---- Adds two numbers. --- --- @tparam {number...} a An array limbs in [0..2²⁴). --- @tparam {number...} b An array of #a limbs in [0..2²⁴). --- @treturn {number...} a + b as #a + 1 limbs in [0..2²⁴). --- -local function intAdd(a, b) - local c = {} - for i = 1, #a do - c[i] = a[i] + b[i] - end - - -- c's limbs fit in [-2²⁵..2²⁵], since addition adds at most one bit. - return carry(c) -end - ---- Multiplies two numbers. --- --- @tparam {number...} a An array of 11 limbs in [0..2²⁴). --- @tparam {number...} b An array of 11 limbs in [0..2²⁴). --- @treturn {number...} a × b as 22 limbs in [0..2²⁴). --- -local function intMul(a, b) - local c = {} - for i = 1, 21 do c[i] = 0 end - for i = 1, 11 do - for j = 1, 11 do - local k = i + j - 1 - c[k] = c[k] + a[i] * b[j] - end - end - - -- {a, b} < 2²⁶⁴ means that c < 2⁵²⁸ = 2 ^ (24 × (21 + 1)). - -- c's limbs are smaller than 2⁴⁸ × 11 < 2⁵², since multiplication doubles - -- bit length, and 11 multiplied limbs are added together. - return carry(c) -end - --- Reduces a number modulo q. -- -- @tparam {number...} a A number a < 2q as 12 limbs in [0..2²⁴). @@ -140,7 +84,7 @@ local function reduce(a) -- c >= q means c - q >= 0. -- Since q < 2²⁸⁸, c < 2q means c - q < q < 2²⁸⁸ = 2^(24 × (11 + 1)). -- c's limbs fit in [-2²⁵..2²⁵], since subtraction adds at most one bit. - local cc = carry(c) + local cc = mp.carry(c) cc[12] = nil -- cc < q implies that cc[12] = 0. return cc end @@ -155,7 +99,7 @@ end -- @treturn {number...} a + b mod q as 11 limbs in [0..2²⁴). -- local function add(a, b) - return reduce(intAdd(a, b)) + return reduce(mp.add(a, b)) end --- Negates a scalar mod q. @@ -172,7 +116,7 @@ local function neg(a) -- 0 < c < q implies 0 < q - c < q < 2²⁸⁸ = 2^(24 × (11 + 1)). -- c's limbs fit in [-2²⁵..2²⁵], since subtraction adds at most one bit. -- q - c < q also implies q - c < 2q. - return reduce(carry(c)) + return reduce(mp.carry(c)) end --- Given a scalar a, computes 2⁻²⁶⁴ a mod q. @@ -182,10 +126,10 @@ end -- local function redc(a) local al = {unpack(a, 1, 11)} - local mm = intMul(al, T0) + local mm = mp.mul(al, T0) local m = {unpack(mm, 1, 11)} - local mr = intMul(m, Q) - local t = intAdd(a, mr) + local mr = mp.mul(m, Q) + local t = mp.add(a, mr) return reduce({unpack(t, 12, 23)}) end @@ -196,7 +140,7 @@ end -- local function montgomery(a) -- a < 2²⁶⁴ and T1 < q imply that a × T1 < 2²⁶⁴ × q. - return redc(intMul(a, T1)) + return redc(mp.mul(a, T1)) end --- Converts a scalar from Montgomery form. @@ -228,7 +172,7 @@ end -- local function mul(a, b) -- {a, b} < q so a × b < q² < 2²⁶⁴ × q. - return redc(intMul(a, b)) + return redc(mp.mul(a, b)) end --- Encodes a scalar. diff --git a/internal/mp.lua b/internal/mp.lua new file mode 100644 index 0000000..6cd6c6f --- /dev/null +++ b/internal/mp.lua @@ -0,0 +1,74 @@ +--- Multi-precision arithmetic on 264-bit integers. +-- +-- :::note Internal Module +-- This module is meant for internal use within the library. Its API is unstable +-- and subject to change without major version bumps. +-- ::: +-- +--
+-- +-- @module[kind=internal] internal.mp +-- + +--- Carries a number in base 2²⁴. +-- +-- @tparam {number...} a A number 0 <= a < 2 ^ (24 × (#a + 1)) as limbs in +-- [-2⁵²..2⁵²]. +-- @treturn {number...} a as #a + 1 limbs in [0..2²⁴). +-- +local function carry(a) + local c = {unpack(a)} + c[#c + 1] = 0 + for i = 1, #c - 1 do + local val = c[i] + local rem = val % 2 ^ 24 + local quot = (val - rem) / 2 ^ 24 + c[i + 1] = c[i + 1] + quot + c[i] = rem + end + return c +end + +--- Adds two numbers. +-- +-- @tparam {number...} a An array limbs in [0..2²⁴). +-- @tparam {number...} b An array of #a limbs in [0..2²⁴). +-- @treturn {number...} a + b as #a + 1 limbs in [0..2²⁴). +-- +local function add(a, b) + local c = {} + for i = 1, #a do + c[i] = a[i] + b[i] + end + + -- c's limbs fit in [-2²⁵..2²⁵], since addition adds at most one bit. + return carry(c) +end + +--- Multiplies two numbers. +-- +-- @tparam {number...} a An array of 11 limbs in [0..2²⁴). +-- @tparam {number...} b An array of 11 limbs in [0..2²⁴). +-- @treturn {number...} a × b as 22 limbs in [0..2²⁴). +-- +local function mul(a, b) + local c = {} + for i = 1, 21 do c[i] = 0 end + for i = 1, 11 do + for j = 1, 11 do + local k = i + j - 1 + c[k] = c[k] + a[i] * b[j] + end + end + + -- {a, b} < 2²⁶⁴ means that c < 2⁵²⁸ = 2 ^ (24 × (21 + 1)). + -- c's limbs are smaller than 2⁴⁸ × 11 < 2⁵², since multiplication doubles + -- bit length, and 11 multiplied limbs are added together. + return carry(c) +end + +return { + carry = carry, + add = add, + mul = mul, +}