From e0fd019b9796ff881a75e72b0389fe56ec386080 Mon Sep 17 00:00:00 2001 From: Miguel Oliveira Date: Wed, 2 Mar 2022 17:29:26 -0300 Subject: [PATCH] Switch to older Fp code --- internal/fp.lua | 423 ++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 354 insertions(+), 69 deletions(-) diff --git a/internal/fp.lua b/internal/fp.lua index 361659a..201c177 100644 --- a/internal/fp.lua +++ b/internal/fp.lua @@ -1,9 +1,125 @@ -local unpack = unpack or table.unpack +--- Arithmetic on Curve25519's base field. +-- +-- @module internal.fq +-- +local unpack = unpack or table.unpack +local bxor = bit32.bxor +local band = bit32.band +local bor = bit32.bor + +--- The modular square root of -1. +local I = { + 0958640 * 2 ^ 0, + 0826664 * 2 ^ 22, + 1613251 * 2 ^ 43, + 1041528 * 2 ^ 64, + 0013673 * 2 ^ 85, + 0387171 * 2 ^ 107, + 1824679 * 2 ^ 128, + 0313839 * 2 ^ 149, + 0709440 * 2 ^ 170, + 0122635 * 2 ^ 192, + 0262782 * 2 ^ 213, + 0712905 * 2 ^ 234, +} + +--- The difference between a non-canonical number and its canonical equivalent. +local CDIFF = { + 2 ^ 22 - 19, + (2 ^ 21 - 1) * 2 ^ 22, + (2 ^ 21 - 1) * 2 ^ 43, + (2 ^ 21 - 1) * 2 ^ 64, + (2 ^ 22 - 1) * 2 ^ 85, + (2 ^ 21 - 1) * 2 ^ 107, + (2 ^ 21 - 1) * 2 ^ 128, + (2 ^ 21 - 1) * 2 ^ 149, + (2 ^ 22 - 1) * 2 ^ 170, + (2 ^ 21 - 1) * 2 ^ 192, + (2 ^ 21 - 1) * 2 ^ 213, + (2 ^ 21 - 1) * 2 ^ 234, +} + +--- A base field polynomial. +-- +-- The Curve25519 paper represents its numbers as "polynomals" that slice the +-- bigint into a little-endian array of floats. Each float slice is such that +-- the (infinite precision) sum of all of them is equal to the represented +-- number. +-- +-- For our implementation, we use an array of 12 floats. Each one has a specific +-- exponent and mantissa range. +-- +-- +-- +-- +-- +-- +-- +-- +-- +-- +-- +-- +-- +-- +-- +-- +-- +--
IndexCoefficient RangeMultiplier
0 (-2²²..2²²) 2⁰
1 (-2²¹..2²¹) 2²²
2 (-2²¹..2²¹) 2⁴³
3 (-2²¹..2²¹) 2⁶⁴
4 (-2²²..2²²) 2⁸⁵
5 (-2²¹..2²¹) 2¹⁰⁷
6 (-2²¹..2²¹) 2¹²⁸
7 (-2²¹..2²¹) 2¹⁴⁹
8 (-2²²..2²²) 2¹⁷⁰
9 (-2²¹..2²¹) 2¹⁹²
10 (-2²¹..2²¹) 2²¹³
11 (-2²¹..2²¹) 2²³⁴
+-- +-- @type fq +-- +local fq = nil +if fq ~= nil then return end + +--- A nonnegative @{fq}. +-- +-- This type represents elements that have no negative coefficients. +-- +-- @type fqAbs +-- +local fqAbs = nil +if fqAbs ~= nil then return end + +--- An uncarried @{fq}. +-- +-- This type represents elements that have coefficients in a wider range than +-- the limits specified in @{fq}. Specifically, this represents all the results +-- of uncarried float-wise additions of two elements. +-- +-- @type fqUncarried +-- +local fqUncarried = nil +if fqUncarried ~= nil then return end + +--- Converts a Lua number to an element. +-- +-- @tparam number n A number n in [0..2²²). +-- @treturn fqAbs n as a base field element. +-- local function num(n) return {n, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} end +--- Adds two elements. +-- +-- @tparam fq a +-- @tparam fq b +-- @treturn fqUncarried +-- local function add(a, b) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a) local b00, b01, b02, b03, b04, b05, b06, b07, b08, b09, b10, b11 = unpack(b) @@ -23,6 +139,42 @@ local function add(a, b) } end +--- Negates an element. +-- +-- @tparam fq a +-- @treturn fq +-- +local function neg(a) + local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a) + return { + -a00, + -a01, + -a02, + -a03, + -a04, + -a05, + -a06, + -a07, + -a08, + -a09, + -a10, + -a11, + } +end + +--- Subtracts an element from another. +-- +-- If both elements are positive, then the result can be guaranteed to fit in +-- a single @{fq} without needing any carrying. +-- +-- @tparam[1] fq a +-- @tparam[1] fq b +-- @treturn[1] fqUncarried +-- +-- @tparam[2] fqAbs a +-- @tparam[2] fqAbs b +-- @treturn[2] fq +-- local function sub(a, b) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a) local b00, b01, b02, b03, b04, b05, b06, b07, b08, b09, b10, b11 = unpack(b) @@ -42,50 +194,103 @@ local function sub(a, b) } end -local function kmul(a, k) +--- Carries an element. +-- +-- @tparam fqUncarried a +-- @treturn fqAbs +-- +local function carry(a) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a) - local c00, c01, c02, c03, c04, c05, c06, c07, c08, c09, c10, c11 + local r00, r01, r02, r03, r04, r05, r06, r07, r08, r09, r10, r11 - -- Multiply. - c00 = a00 * k - c01 = a01 * k - c02 = a02 * k - c03 = a03 * k - c04 = a04 * k - c05 = a05 * k - c06 = a06 * k - c07 = a07 * k - c08 = a08 * k - c09 = a09 * k - c10 = a10 * k - c11 = a11 * k + r11 = a11 % 2 ^ 255 + a00 = a00 + (a11 - r11) * (19 / 2 ^ 255) - -- Carry and reduce. - a11 = c11 % 2 ^ 255 c00 = c00 + (c11 - a11) * (19 / 2 ^ 255) + r00 = a00 % 2 ^ 22 a01 = a01 + (a00 - r00) + r01 = a01 % 2 ^ 43 a02 = a02 + (a01 - r01) + r02 = a02 % 2 ^ 64 a03 = a03 + (a02 - r02) + r03 = a03 % 2 ^ 85 a04 = a04 + (a03 - r03) + r04 = a04 % 2 ^ 107 a05 = a05 + (a04 - r04) + r05 = a05 % 2 ^ 128 a06 = a06 + (a05 - r05) + r06 = a06 % 2 ^ 149 a07 = a07 + (a06 - r06) + r07 = a07 % 2 ^ 170 a08 = a08 + (a07 - r07) + r08 = a08 % 2 ^ 192 a09 = a09 + (a08 - r08) + r09 = a09 % 2 ^ 213 a10 = a10 + (a09 - r09) + r10 = a10 % 2 ^ 234 a11 = r11 + (a10 - r10) - a00 = c00 % 2 ^ 22 c01 = c01 + (c00 - a00) - a01 = c01 % 2 ^ 43 c02 = c02 + (c01 - a01) - a02 = c02 % 2 ^ 64 c03 = c03 + (c02 - a02) - a03 = c03 % 2 ^ 85 c04 = c04 + (c03 - a03) - a04 = c04 % 2 ^ 107 c05 = c05 + (c04 - a04) - a05 = c05 % 2 ^ 128 c06 = c06 + (c05 - a05) - a06 = c06 % 2 ^ 149 c07 = c07 + (c06 - a06) - a07 = c07 % 2 ^ 170 c08 = c08 + (c07 - a07) - a08 = c08 % 2 ^ 192 c09 = c09 + (c08 - a08) - a09 = c09 % 2 ^ 213 c10 = c10 + (c09 - a09) - a10 = c10 % 2 ^ 234 c11 = a11 + (c10 - a10) + r11 = a11 % 2 ^ 255 r00 = r00 + (a11 - r11) * (19 / 2 ^ 255) - a11 = c11 % 2 ^ 255 a00 = a00 + (c11 - a11) * (19 / 2 ^ 255) - - return {a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11} + return {r00, r01, r02, r03, r04, r05, r06, r07, r08, r09, r10, r11} end +--- Returns whether the modp number is the canonical representative. +-- +-- @see canonicalize +-- +-- @tparam fqAbs a +-- @treturn boolean +-- +local function isCanonical(a) + local e11 = bxor(a[12] * 2 ^ -234, 2 ^ 21 - 1) + local e10 = bxor(a[11] * 2 ^ -213, 2 ^ 21 - 1) + local e09 = bxor(a[10] * 2 ^ -192, 2 ^ 21 - 1) + local e08 = bxor(a[09] * 2 ^ -170, 2 ^ 22 - 1) + local e07 = bxor(a[08] * 2 ^ -149, 2 ^ 21 - 1) + local e06 = bxor(a[07] * 2 ^ -128, 2 ^ 21 - 1) + local e05 = bxor(a[06] * 2 ^ -107, 2 ^ 21 - 1) + local e04 = bxor(a[05] * 2 ^ -85, 2 ^ 22 - 1) + local e03 = bxor(a[04] * 2 ^ -64, 2 ^ 21 - 1) + local e02 = bxor(a[03] * 2 ^ -43, 2 ^ 21 - 1) + local e01 = bxor(a[02] * 2 ^ -22, 2 ^ 21 - 1) + local e00 = band(a[01] - (2 ^ 22 - 19), 2 ^ 31) + return 0 ~= bor(e00, e01, e02, e03, e04, e05, e06, e07, e08, e09, e10, e11) +end + +--- Returns the canoncal representative of a modp number. +-- +-- Some elements can be represented by two different arrays of floats. This +-- returns the canonical element of the represented equivalence class. We define +-- an element as canonical if it's the smallest nonnegative number in its class. +-- +-- @tparam fq a +-- @treturn fqAbs +-- +local function canonicalize(a) + a = carry(a) + local zero = num(0) + local diff = isCanonical(a) and zero or CDIFF + return sub(a, diff) +end + +--- Returns whether two elements are the same. +-- +-- @tparam fqAbs a +-- @tparam fqAbs b +-- @treturn boolean +-- +local function eq(a, b) + a = canonicalize(a) + b = canonicalize(b) + for i = 1, 12 do + if a[i] ~= b[i] then + return false + end + end + return true +end + +--- Multiplies two elements. +-- +-- @tparam fqUncarried a +-- @tparam fqUncarried b +-- @treturn fqAbs +-- local function mul(a, b) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a) local b00, b01, b02, b03, b04, b05, b06, b07, b08, b09, b10, b11 = unpack(b) local c00, c01, c02, c03, c04, c05, c06, c07, c08, c09, c10, c11 - -- Multiply high half. + -- Multiply high half into c00..c11. c00 = a11 * b01 + a10 * b02 + a09 * b03 @@ -153,7 +358,7 @@ local function mul(a, b) + a10 * b11 c10 = a11 * b11 - -- Multiply low half with reduction. + -- Multiply low half with reduction into c00..c11. c00 = c00 * (19 / 2 ^ 255) + a00 * b00 c01 = c01 * (19 / 2 ^ 255) @@ -265,6 +470,11 @@ local function mul(a, b) return {a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11} end +--- Squares an element. +-- +-- @tparam fqUncarried a +-- @treturn fqAbs +-- local function square(a) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a) local d00, d01, d02, d03, d04, d05, d06, d07, d08, d09, d10 @@ -283,7 +493,7 @@ local function square(a) d09 = a09 + a09 d10 = a10 + a10 - -- Multiply high half. + -- Multiply high half into c00..c11. c00 = a11 * d01 + a10 * d02 + a09 * d03 @@ -321,7 +531,7 @@ local function square(a) c09 = a11 * d10 c10 = a11 * a11 - -- Multiply low half with reduction. + -- Multiply low half with reduction into c00..c11. c00 = c00 * (19 / 2 ^ 255) + a00 * a00 c01 = c01 * (19 / 2 ^ 255) @@ -397,15 +607,56 @@ local function square(a) return {a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11} end +--- Multiplies an element by a number. +-- +-- @tparam fqUncarried a +-- @tparam number k A number k in [0..2²¹). +-- @treturn fqAbs +-- +local function kmul(a, k) + local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a) + + -- TODO WHY ARE TYPE CONSTRAINTS SO DIFFICULT TO SPECIFY + return carry { + a00 * k, + a01 * k, + a02 * k, + a03 * k, + a04 * k, + a05 * k, + a06 * k, + a07 * k, + a08 * k, + a09 * k, + a10 * k, + a11 * k, + } +end + +--- Squares a modp number n times. +-- +-- @tparam fqUncarried a +-- @tparam number n +-- @treturn fqAbs +-- local function nsquare(a, n) for _ = 1, n do a = square(a) end return a end +--- Computes the inverse of an element. +-- +-- Computation of the inverse requires 11 multiplicationss and 252 squarings. +-- +-- @tparam fqUncarried a +-- @treturn[1] fqAbs a⁻¹ +-- @treturn[2] fqAbs 0 if the argument is 0, which has no inverse. +-- local function invert(a) local a2 = square(a) local a9 = mul(a, nsquare(a2, 2)) local a11 = mul(a9, a2) + local x5 = mul(square(a11), a9) local x10 = mul(nsquare(x5, 5), x5) local x20 = mul(nsquare(x10, 10), x10) @@ -414,44 +665,71 @@ local function invert(a) local x100 = mul(nsquare(x50, 50), x50) local x200 = mul(nsquare(x100, 100), x100) local x250 = mul(nsquare(x200, 50), x50) + return mul(nsquare(x250, 5), a11) end -local function encode(a) - local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a) +--- Returns an element x that satisfies v * x² = u. +-- +-- Note that when v = 0, the returned value can take any @{fqAbs} value. +-- +-- @tparam fqUncarried u +-- @tparam fqUncarried v +-- @treturn[1] fqAbs x +-- @treturn[2] nil if there is no solution. +-- +local function sqrtDiv(u, v) + u = carry(u) - -- Canonicalize. - if a11 == (2 ^ 21 - 1) * 2 ^ 234 - and a10 == (2 ^ 21 - 1) * 2 ^ 213 - and a09 == (2 ^ 21 - 1) * 2 ^ 192 - and a08 == (2 ^ 22 - 1) * 2 ^ 170 - and a07 == (2 ^ 21 - 1) * 2 ^ 149 - and a06 == (2 ^ 21 - 1) * 2 ^ 128 - and a05 == (2 ^ 21 - 1) * 2 ^ 107 - and a04 == (2 ^ 22 - 1) * 2 ^ 85 - and a03 == (2 ^ 21 - 1) * 2 ^ 64 - and a02 == (2 ^ 21 - 1) * 2 ^ 43 - and a01 == (2 ^ 21 - 1) * 2 ^ 22 - and a00 >= 2 ^ 22 - 19 - then - a11 = 0 - a10 = 0 - a09 = 0 - a08 = 0 - a07 = 0 - a06 = 0 - a05 = 0 - a04 = 0 - a03 = 0 - a02 = 0 - a01 = 0 - a00 = a00 - (2 ^ 22 - 19) + local v2 = square(v) + local v3 = mul(v, v2) + local v6 = square(v3) + local v7 = mul(v, v6) + local uv7 = mul(u, v7) + + local x2 = mul(square(uv7), uv7) + local x4 = mul(nsquare(x2, 2), x2) + local x8 = mul(nsquare(x4, 4), x4) + local x16 = mul(nsquare(x8, 8), x8) + local x18 = mul(nsquare(x16, 2), x2) + local x32 = mul(nsquare(x16, 16), x16) + local x50 = mul(nsquare(x32, 18), x18) + local x100 = mul(nsquare(x50, 50), x50) + local x200 = mul(nsquare(x100, 100), x100) + local x250 = mul(nsquare(x200, 50), x50) + local pr = mul(nsquare(x250, 2), uv7) + + local uv3 = mul(u, v3) + local b = mul(uv3, pr) + local b2 = square(b) + local vb2 = mul(v, b2) + + if not eq(vb2, u) then + -- Found sqrt(-u/v), multiply by i. + b = mul(b, I) + b2 = square(b) + vb2 = mul(v, b2) end - -- Encode. - -- TODO this can be improved. + if eq(vb2, u) then + return b + else + return nil + end +end + +--- Encodes an element in little-endian. +-- +-- @tparam fqAbs a +-- @treturn string A 32-byte string. Always represents the canonical element. +-- +local function encode(a) + a = canonicalize(a) + local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a) + local bytes = {} local acc = a00 + local function putBytes(n) for _ = 1, n do local byte = acc % 256 @@ -476,13 +754,18 @@ local function encode(a) return string.char(unpack(bytes)) end +--- Decodes an element in little-endian. +-- +-- @tparam string b A 32-byte string. The most-significant bit is discarded. +-- @treturn fqAbs The decoded element. May not be canonical. +-- local function decode(b) local w00, w01, w02, w03, w04, w05, w06, w07, w08, w09, w10, w11 = ("