Switch to older Fp code

This commit is contained in:
Miguel Oliveira 2022-03-02 17:29:26 -03:00
parent 59647d1a96
commit e0fd019b97
No known key found for this signature in database
GPG key ID: 2C2BE789E1377025

View file

@ -1,9 +1,125 @@
local unpack = unpack or table.unpack --- Arithmetic on Curve25519's base field.
--
-- @module internal.fq
--
local unpack = unpack or table.unpack
local bxor = bit32.bxor
local band = bit32.band
local bor = bit32.bor
--- The modular square root of -1.
local I = {
0958640 * 2 ^ 0,
0826664 * 2 ^ 22,
1613251 * 2 ^ 43,
1041528 * 2 ^ 64,
0013673 * 2 ^ 85,
0387171 * 2 ^ 107,
1824679 * 2 ^ 128,
0313839 * 2 ^ 149,
0709440 * 2 ^ 170,
0122635 * 2 ^ 192,
0262782 * 2 ^ 213,
0712905 * 2 ^ 234,
}
--- The difference between a non-canonical number and its canonical equivalent.
local CDIFF = {
2 ^ 22 - 19,
(2 ^ 21 - 1) * 2 ^ 22,
(2 ^ 21 - 1) * 2 ^ 43,
(2 ^ 21 - 1) * 2 ^ 64,
(2 ^ 22 - 1) * 2 ^ 85,
(2 ^ 21 - 1) * 2 ^ 107,
(2 ^ 21 - 1) * 2 ^ 128,
(2 ^ 21 - 1) * 2 ^ 149,
(2 ^ 22 - 1) * 2 ^ 170,
(2 ^ 21 - 1) * 2 ^ 192,
(2 ^ 21 - 1) * 2 ^ 213,
(2 ^ 21 - 1) * 2 ^ 234,
}
--- A base field polynomial.
--
-- The Curve25519 paper represents its numbers as "polynomals" that slice the
-- bigint into a little-endian array of floats. Each float slice is such that
-- the (infinite precision) sum of all of them is equal to the represented
-- number.
--
-- For our implementation, we use an array of 12 floats. Each one has a specific
-- exponent and mantissa range.
--
-- <!-- My best wishes to whoever is doing the Markdown parsing. -->
-- <style>
-- table.mdt {
-- border-collapse: collapse;
-- }
-- table.mdt td, table.mdt th {
-- border: 1px solid #cccccc;
-- padding: 5px;
-- text-align: center;
-- }
-- table.mdt th {
-- background-color: var(--background-2);
-- }
-- </style>
-- <table class="mdt">
-- <tr><th>Index</th><th>Coefficient Range</th><th>Multiplier</th></tr>
-- <tr><td> 0 </td><td> (-2²²..2²²) </td><td> 2⁰ </td></tr>
-- <tr><td> 1 </td><td> (-2²¹..2²¹) </td><td> 2²² </td></tr>
-- <tr><td> 2 </td><td> (-2²¹..2²¹) </td><td> 2⁴³ </td></tr>
-- <tr><td> 3 </td><td> (-2²¹..2²¹) </td><td> 2⁶⁴ </td></tr>
-- <tr><td> 4 </td><td> (-2²²..2²²) </td><td> 2⁸⁵ </td></tr>
-- <tr><td> 5 </td><td> (-2²¹..2²¹) </td><td> 2¹⁰⁷ </td></tr>
-- <tr><td> 6 </td><td> (-2²¹..2²¹) </td><td> 2¹²⁸ </td></tr>
-- <tr><td> 7 </td><td> (-2²¹..2²¹) </td><td> 2¹⁴⁹ </td></tr>
-- <tr><td> 8 </td><td> (-2²²..2²²) </td><td> 2¹⁷⁰ </td></tr>
-- <tr><td> 9 </td><td> (-2²¹..2²¹) </td><td> 2¹⁹² </td></tr>
-- <tr><td> 10 </td><td> (-2²¹..2²¹) </td><td> 2²¹³ </td></tr>
-- <tr><td> 11 </td><td> (-2²¹..2²¹) </td><td> 2²³⁴ </td></tr>
-- </table>
--
-- @type fq
--
local fq = nil
if fq ~= nil then return end
--- A nonnegative @{fq}.
--
-- This type represents elements that have no negative coefficients.
--
-- @type fqAbs
--
local fqAbs = nil
if fqAbs ~= nil then return end
--- An uncarried @{fq}.
--
-- This type represents elements that have coefficients in a wider range than
-- the limits specified in @{fq}. Specifically, this represents all the results
-- of uncarried float-wise additions of two elements.
--
-- @type fqUncarried
--
local fqUncarried = nil
if fqUncarried ~= nil then return end
--- Converts a Lua number to an element.
--
-- @tparam number n A number n in [0..2²²).
-- @treturn fqAbs n as a base field element.
--
local function num(n) local function num(n)
return {n, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} return {n, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
end end
--- Adds two elements.
--
-- @tparam fq a
-- @tparam fq b
-- @treturn fqUncarried
--
local function add(a, b) local function add(a, b)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a)
local b00, b01, b02, b03, b04, b05, b06, b07, b08, b09, b10, b11 = unpack(b) local b00, b01, b02, b03, b04, b05, b06, b07, b08, b09, b10, b11 = unpack(b)
@ -23,6 +139,42 @@ local function add(a, b)
} }
end end
--- Negates an element.
--
-- @tparam fq a
-- @treturn fq
--
local function neg(a)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a)
return {
-a00,
-a01,
-a02,
-a03,
-a04,
-a05,
-a06,
-a07,
-a08,
-a09,
-a10,
-a11,
}
end
--- Subtracts an element from another.
--
-- If both elements are positive, then the result can be guaranteed to fit in
-- a single @{fq} without needing any carrying.
--
-- @tparam[1] fq a
-- @tparam[1] fq b
-- @treturn[1] fqUncarried
--
-- @tparam[2] fqAbs a
-- @tparam[2] fqAbs b
-- @treturn[2] fq
--
local function sub(a, b) local function sub(a, b)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a)
local b00, b01, b02, b03, b04, b05, b06, b07, b08, b09, b10, b11 = unpack(b) local b00, b01, b02, b03, b04, b05, b06, b07, b08, b09, b10, b11 = unpack(b)
@ -42,50 +194,103 @@ local function sub(a, b)
} }
end end
local function kmul(a, k) --- Carries an element.
--
-- @tparam fqUncarried a
-- @treturn fqAbs
--
local function carry(a)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a)
local c00, c01, c02, c03, c04, c05, c06, c07, c08, c09, c10, c11 local r00, r01, r02, r03, r04, r05, r06, r07, r08, r09, r10, r11
-- Multiply. r11 = a11 % 2 ^ 255
c00 = a00 * k a00 = a00 + (a11 - r11) * (19 / 2 ^ 255)
c01 = a01 * k
c02 = a02 * k
c03 = a03 * k
c04 = a04 * k
c05 = a05 * k
c06 = a06 * k
c07 = a07 * k
c08 = a08 * k
c09 = a09 * k
c10 = a10 * k
c11 = a11 * k
-- Carry and reduce. r00 = a00 % 2 ^ 22 a01 = a01 + (a00 - r00)
a11 = c11 % 2 ^ 255 c00 = c00 + (c11 - a11) * (19 / 2 ^ 255) r01 = a01 % 2 ^ 43 a02 = a02 + (a01 - r01)
r02 = a02 % 2 ^ 64 a03 = a03 + (a02 - r02)
r03 = a03 % 2 ^ 85 a04 = a04 + (a03 - r03)
r04 = a04 % 2 ^ 107 a05 = a05 + (a04 - r04)
r05 = a05 % 2 ^ 128 a06 = a06 + (a05 - r05)
r06 = a06 % 2 ^ 149 a07 = a07 + (a06 - r06)
r07 = a07 % 2 ^ 170 a08 = a08 + (a07 - r07)
r08 = a08 % 2 ^ 192 a09 = a09 + (a08 - r08)
r09 = a09 % 2 ^ 213 a10 = a10 + (a09 - r09)
r10 = a10 % 2 ^ 234 a11 = r11 + (a10 - r10)
a00 = c00 % 2 ^ 22 c01 = c01 + (c00 - a00) r11 = a11 % 2 ^ 255 r00 = r00 + (a11 - r11) * (19 / 2 ^ 255)
a01 = c01 % 2 ^ 43 c02 = c02 + (c01 - a01)
a02 = c02 % 2 ^ 64 c03 = c03 + (c02 - a02)
a03 = c03 % 2 ^ 85 c04 = c04 + (c03 - a03)
a04 = c04 % 2 ^ 107 c05 = c05 + (c04 - a04)
a05 = c05 % 2 ^ 128 c06 = c06 + (c05 - a05)
a06 = c06 % 2 ^ 149 c07 = c07 + (c06 - a06)
a07 = c07 % 2 ^ 170 c08 = c08 + (c07 - a07)
a08 = c08 % 2 ^ 192 c09 = c09 + (c08 - a08)
a09 = c09 % 2 ^ 213 c10 = c10 + (c09 - a09)
a10 = c10 % 2 ^ 234 c11 = a11 + (c10 - a10)
a11 = c11 % 2 ^ 255 a00 = a00 + (c11 - a11) * (19 / 2 ^ 255) return {r00, r01, r02, r03, r04, r05, r06, r07, r08, r09, r10, r11}
return {a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11}
end end
--- Returns whether the modp number is the canonical representative.
--
-- @see canonicalize
--
-- @tparam fqAbs a
-- @treturn boolean
--
local function isCanonical(a)
local e11 = bxor(a[12] * 2 ^ -234, 2 ^ 21 - 1)
local e10 = bxor(a[11] * 2 ^ -213, 2 ^ 21 - 1)
local e09 = bxor(a[10] * 2 ^ -192, 2 ^ 21 - 1)
local e08 = bxor(a[09] * 2 ^ -170, 2 ^ 22 - 1)
local e07 = bxor(a[08] * 2 ^ -149, 2 ^ 21 - 1)
local e06 = bxor(a[07] * 2 ^ -128, 2 ^ 21 - 1)
local e05 = bxor(a[06] * 2 ^ -107, 2 ^ 21 - 1)
local e04 = bxor(a[05] * 2 ^ -85, 2 ^ 22 - 1)
local e03 = bxor(a[04] * 2 ^ -64, 2 ^ 21 - 1)
local e02 = bxor(a[03] * 2 ^ -43, 2 ^ 21 - 1)
local e01 = bxor(a[02] * 2 ^ -22, 2 ^ 21 - 1)
local e00 = band(a[01] - (2 ^ 22 - 19), 2 ^ 31)
return 0 ~= bor(e00, e01, e02, e03, e04, e05, e06, e07, e08, e09, e10, e11)
end
--- Returns the canoncal representative of a modp number.
--
-- Some elements can be represented by two different arrays of floats. This
-- returns the canonical element of the represented equivalence class. We define
-- an element as canonical if it's the smallest nonnegative number in its class.
--
-- @tparam fq a
-- @treturn fqAbs
--
local function canonicalize(a)
a = carry(a)
local zero = num(0)
local diff = isCanonical(a) and zero or CDIFF
return sub(a, diff)
end
--- Returns whether two elements are the same.
--
-- @tparam fqAbs a
-- @tparam fqAbs b
-- @treturn boolean
--
local function eq(a, b)
a = canonicalize(a)
b = canonicalize(b)
for i = 1, 12 do
if a[i] ~= b[i] then
return false
end
end
return true
end
--- Multiplies two elements.
--
-- @tparam fqUncarried a
-- @tparam fqUncarried b
-- @treturn fqAbs
--
local function mul(a, b) local function mul(a, b)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a)
local b00, b01, b02, b03, b04, b05, b06, b07, b08, b09, b10, b11 = unpack(b) local b00, b01, b02, b03, b04, b05, b06, b07, b08, b09, b10, b11 = unpack(b)
local c00, c01, c02, c03, c04, c05, c06, c07, c08, c09, c10, c11 local c00, c01, c02, c03, c04, c05, c06, c07, c08, c09, c10, c11
-- Multiply high half. -- Multiply high half into c00..c11.
c00 = a11 * b01 c00 = a11 * b01
+ a10 * b02 + a10 * b02
+ a09 * b03 + a09 * b03
@ -153,7 +358,7 @@ local function mul(a, b)
+ a10 * b11 + a10 * b11
c10 = a11 * b11 c10 = a11 * b11
-- Multiply low half with reduction. -- Multiply low half with reduction into c00..c11.
c00 = c00 * (19 / 2 ^ 255) c00 = c00 * (19 / 2 ^ 255)
+ a00 * b00 + a00 * b00
c01 = c01 * (19 / 2 ^ 255) c01 = c01 * (19 / 2 ^ 255)
@ -265,6 +470,11 @@ local function mul(a, b)
return {a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11} return {a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11}
end end
--- Squares an element.
--
-- @tparam fqUncarried a
-- @treturn fqAbs
--
local function square(a) local function square(a)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a)
local d00, d01, d02, d03, d04, d05, d06, d07, d08, d09, d10 local d00, d01, d02, d03, d04, d05, d06, d07, d08, d09, d10
@ -283,7 +493,7 @@ local function square(a)
d09 = a09 + a09 d09 = a09 + a09
d10 = a10 + a10 d10 = a10 + a10
-- Multiply high half. -- Multiply high half into c00..c11.
c00 = a11 * d01 c00 = a11 * d01
+ a10 * d02 + a10 * d02
+ a09 * d03 + a09 * d03
@ -321,7 +531,7 @@ local function square(a)
c09 = a11 * d10 c09 = a11 * d10
c10 = a11 * a11 c10 = a11 * a11
-- Multiply low half with reduction. -- Multiply low half with reduction into c00..c11.
c00 = c00 * (19 / 2 ^ 255) c00 = c00 * (19 / 2 ^ 255)
+ a00 * a00 + a00 * a00
c01 = c01 * (19 / 2 ^ 255) c01 = c01 * (19 / 2 ^ 255)
@ -397,15 +607,56 @@ local function square(a)
return {a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11} return {a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11}
end end
--- Multiplies an element by a number.
--
-- @tparam fqUncarried a
-- @tparam number k A number k in [0..2²¹).
-- @treturn fqAbs
--
local function kmul(a, k)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a)
-- TODO WHY ARE TYPE CONSTRAINTS SO DIFFICULT TO SPECIFY
return carry {
a00 * k,
a01 * k,
a02 * k,
a03 * k,
a04 * k,
a05 * k,
a06 * k,
a07 * k,
a08 * k,
a09 * k,
a10 * k,
a11 * k,
}
end
--- Squares a modp number n times.
--
-- @tparam fqUncarried a
-- @tparam number n
-- @treturn fqAbs
--
local function nsquare(a, n) local function nsquare(a, n)
for _ = 1, n do a = square(a) end for _ = 1, n do a = square(a) end
return a return a
end end
--- Computes the inverse of an element.
--
-- Computation of the inverse requires 11 multiplicationss and 252 squarings.
--
-- @tparam fqUncarried a
-- @treturn[1] fqAbs a⁻¹
-- @treturn[2] fqAbs 0 if the argument is 0, which has no inverse.
--
local function invert(a) local function invert(a)
local a2 = square(a) local a2 = square(a)
local a9 = mul(a, nsquare(a2, 2)) local a9 = mul(a, nsquare(a2, 2))
local a11 = mul(a9, a2) local a11 = mul(a9, a2)
local x5 = mul(square(a11), a9) local x5 = mul(square(a11), a9)
local x10 = mul(nsquare(x5, 5), x5) local x10 = mul(nsquare(x5, 5), x5)
local x20 = mul(nsquare(x10, 10), x10) local x20 = mul(nsquare(x10, 10), x10)
@ -414,44 +665,71 @@ local function invert(a)
local x100 = mul(nsquare(x50, 50), x50) local x100 = mul(nsquare(x50, 50), x50)
local x200 = mul(nsquare(x100, 100), x100) local x200 = mul(nsquare(x100, 100), x100)
local x250 = mul(nsquare(x200, 50), x50) local x250 = mul(nsquare(x200, 50), x50)
return mul(nsquare(x250, 5), a11) return mul(nsquare(x250, 5), a11)
end end
local function encode(a) --- Returns an element x that satisfies v * x² = u.
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a) --
-- Note that when v = 0, the returned value can take any @{fqAbs} value.
--
-- @tparam fqUncarried u
-- @tparam fqUncarried v
-- @treturn[1] fqAbs x
-- @treturn[2] nil if there is no solution.
--
local function sqrtDiv(u, v)
u = carry(u)
-- Canonicalize. local v2 = square(v)
if a11 == (2 ^ 21 - 1) * 2 ^ 234 local v3 = mul(v, v2)
and a10 == (2 ^ 21 - 1) * 2 ^ 213 local v6 = square(v3)
and a09 == (2 ^ 21 - 1) * 2 ^ 192 local v7 = mul(v, v6)
and a08 == (2 ^ 22 - 1) * 2 ^ 170 local uv7 = mul(u, v7)
and a07 == (2 ^ 21 - 1) * 2 ^ 149
and a06 == (2 ^ 21 - 1) * 2 ^ 128 local x2 = mul(square(uv7), uv7)
and a05 == (2 ^ 21 - 1) * 2 ^ 107 local x4 = mul(nsquare(x2, 2), x2)
and a04 == (2 ^ 22 - 1) * 2 ^ 85 local x8 = mul(nsquare(x4, 4), x4)
and a03 == (2 ^ 21 - 1) * 2 ^ 64 local x16 = mul(nsquare(x8, 8), x8)
and a02 == (2 ^ 21 - 1) * 2 ^ 43 local x18 = mul(nsquare(x16, 2), x2)
and a01 == (2 ^ 21 - 1) * 2 ^ 22 local x32 = mul(nsquare(x16, 16), x16)
and a00 >= 2 ^ 22 - 19 local x50 = mul(nsquare(x32, 18), x18)
then local x100 = mul(nsquare(x50, 50), x50)
a11 = 0 local x200 = mul(nsquare(x100, 100), x100)
a10 = 0 local x250 = mul(nsquare(x200, 50), x50)
a09 = 0 local pr = mul(nsquare(x250, 2), uv7)
a08 = 0
a07 = 0 local uv3 = mul(u, v3)
a06 = 0 local b = mul(uv3, pr)
a05 = 0 local b2 = square(b)
a04 = 0 local vb2 = mul(v, b2)
a03 = 0
a02 = 0 if not eq(vb2, u) then
a01 = 0 -- Found sqrt(-u/v), multiply by i.
a00 = a00 - (2 ^ 22 - 19) b = mul(b, I)
b2 = square(b)
vb2 = mul(v, b2)
end end
-- Encode. if eq(vb2, u) then
-- TODO this can be improved. return b
else
return nil
end
end
--- Encodes an element in little-endian.
--
-- @tparam fqAbs a
-- @treturn string A 32-byte string. Always represents the canonical element.
--
local function encode(a)
a = canonicalize(a)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a)
local bytes = {} local bytes = {}
local acc = a00 local acc = a00
local function putBytes(n) local function putBytes(n)
for _ = 1, n do for _ = 1, n do
local byte = acc % 256 local byte = acc % 256
@ -476,13 +754,18 @@ local function encode(a)
return string.char(unpack(bytes)) return string.char(unpack(bytes))
end end
--- Decodes an element in little-endian.
--
-- @tparam string b A 32-byte string. The most-significant bit is discarded.
-- @treturn fqAbs The decoded element. May not be canonical.
--
local function decode(b) local function decode(b)
local w00, w01, w02, w03, w04, w05, w06, w07, w08, w09, w10, w11 = local w00, w01, w02, w03, w04, w05, w06, w07, w08, w09, w10, w11 =
("<I3I3I2I3I3I2I3I3I2I3I3I2"):unpack(b) ("<I3I3I2I3I3I2I3I3I2I3I3I2"):unpack(b)
w11 = w11 % 2 ^ 15 w11 = w11 % 2 ^ 15
local out = { return carry {
w00, w00,
w01 * 2 ^ 24, w01 * 2 ^ 24,
w02 * 2 ^ 48, w02 * 2 ^ 48,
@ -496,18 +779,20 @@ local function decode(b)
w10 * 2 ^ 216, w10 * 2 ^ 216,
w11 * 2 ^ 240, w11 * 2 ^ 240,
} }
return kmul(out, 1)
end end
return { return {
num = num, num = num,
add = add, add = add,
neg = neg,
sub = sub, sub = sub,
kmul = kmul, kmul = kmul,
mul = mul, mul = mul,
canonicalize = canonicalize,
square = square, square = square,
carry = carry,
invert = invert, invert = invert,
sqrtDiv = sqrtDiv,
encode = encode, encode = encode,
decode = decode, decode = decode,
} }