Add PRAC-based twofold multiplication

This commit is contained in:
Miguel Oliveira 2022-04-08 11:56:03 -03:00
parent a57c5e1ded
commit db4c272aea
No known key found for this signature in database
GPG key ID: 2C2BE789E1377025
8 changed files with 506 additions and 59 deletions

View file

@ -11,6 +11,7 @@
--
local fp = require "ccryptolib.internal.fp"
local ed = require "ccryptolib.internal.edwards25519"
local random = require "ccryptolib.random"
local function double(P1)
@ -25,7 +26,31 @@ local function double(P1)
return {x3, z3}
end
local function step(dxmul, dx, P1, P2)
--- Performs a differential addition on projective {x, z} points.
--
-- @tparam point DP The difference P1 - P2.
-- @tparam point P1 The first point.
-- @tparam point P2 The second point.
-- @treturn point P1 + P2.
--
local function dadd(DP, P1, P2)
    local diffX, diffZ = DP[1], DP[2]
    local xa, za = P1[1], P1[2]
    local xb, zb = P2[1], P2[2]

    -- Cross products (x2 - z2)(x1 + z1) and (x2 + z2)(x1 - z1).
    local left = fp.mul(fp.sub(xb, zb), fp.add(xa, za))
    local right = fp.mul(fp.add(xb, zb), fp.sub(xa, za))

    -- The squared sum is weighted by the difference's z, the squared
    -- difference by the difference's x.
    local sumSq = fp.square(fp.add(left, right))
    local diffSq = fp.square(fp.sub(left, right))
    return {fp.mul(diffZ, sumSq), fp.mul(diffX, diffSq)}
end
--- Performs a step on the Montgomery ladder.
--
-- @param DP The difference P1 - P2.
-- @param P1 The first point.
-- @param P2 The second point.
-- @return 2 * P1
-- @return P1 + P2
--
local function step(DP, P1, P2)
local dx, dz = DP[1], DP[2]
local x1, z1 = P1[1], P1[2]
local x2, z2 = P2[1], P2[2]
local a = fp.add(x1, z1)
@ -37,40 +62,213 @@ local function step(dxmul, dx, P1, P2)
local d = fp.sub(x2, z2)
local da = fp.mul(d, a)
local cb = fp.mul(c, b)
local x4 = fp.square(fp.add(da, cb))
local z4 = dxmul(fp.square(fp.sub(da, cb)), dx)
local x4 = fp.mul(dz, fp.square(fp.add(da, cb)))
local z4 = fp.mul(dx, fp.square(fp.sub(da, cb)))
local x3 = fp.mul(aa, bb)
local z3 = fp.mul(e, fp.add(bb, fp.kmul(e, 121666)))
return {x3, z3}, {x4, z4}
end
--- Performs a Montgomery ladder operation with multiplication by 8.
--
-- @tparam function(a:internal.fp.fp1, dx:any):internal.fp.fpq dxmul A function
-- to multiply an element in Fp by dx.
-- @tparam any dx The base point's x coordinate. Z is assumed to be equal to 1.
-- @tparam {number...} bits The multiplier scalar divided by 8, in little-endian
-- bits.
--
local function ladder8(dxmul, dx, bits)
local function ladder(DP, bits)
local P = {fp.num(1), fp.num(0)}
local Q = {fp.decode(random.random(32)), dxmul(z2, dx)}
local Q = DP
-- Standard ladder.
for i = #bits, 1, -1 do
if bits[i] == 0 then
P, Q = step(dxmul, dx, P, Q)
P, Q = step(DP, P, Q)
else
Q, P = step(dxmul, dx, Q, P)
Q, P = step(DP, Q, P)
end
end
-- Multiply by 8 (double 3 times).
for _ = 1, 3 do P = double(P) end
return P
end
return fp.mul(P[1], fp.invert(P[2]))
--- Performs a scalar multiplication operation with multiplication by 8.
--
-- @tparam point P The base point.
-- @tparam {number...} bits The scalar multiplier, in little-endian bits.
-- @treturn point The product, multiplied by 8.
--
local function ladder8(P, bits)
    -- Blind the projective representation: both coordinates are scaled by
    -- the same random field element, so the point itself is unchanged.
    local mask = fp.decode(random.random(32))
    local blinded = {fp.mul(P[1], mask), fp.mul(P[2], mask)}

    -- Run the ladder, then double three times to multiply by 8.
    local R = ladder(blinded, bits)
    for _ = 1, 3 do
        R = double(R)
    end
    return R
end
--- Normalizes a projective point so that its z coordinate is 1.
--
-- @tparam point P The point {x, z} to scale.
-- @treturn point The point {x * invert(z), 1}.
--
local function scale(P)
    local zInv = fp.invert(P[2])
    local x = fp.mul(P[1], zInv)
    return {x, fp.num(1)}
end
--- Encodes a point.
--
-- Only the x coordinate is encoded; the z coordinate is ignored, so the point
-- is expected to have been scaled beforehand.
--
-- @tparam point P The scaled point to encode.
-- @treturn string The 32-byte encoded point.
--
local function encode(P)
    local x = P[1]
    return fp.encode(x)
end
--- Decodes a point.
--
-- @tparam string str A 32-byte encoded point.
-- @treturn point The decoded point.
--
local function decode(str)
    -- The decoded field element becomes x; z starts out as 1.
    local x = fp.decode(str)
    local z = fp.num(1)
    return {x, z}
end
--- Performs a scalar multiplication by the base point G.
--
-- @tparam {number...} bits The scalar multiplier, in little-endian bits.
-- @return The product point.
--
local function mulG(bits)
    -- Do the multiplication by G on Edwards25519.
    local E = ed.mulG(bits)
    local y, z = E[2], E[3]

    -- Birational map to Curve25519: x = (z + y) / (z - y), kept projective.
    -- Never fails since G is in the large group, and the exponent is clamped.
    return {fp.carry(fp.add(y, z)), fp.carry(fp.sub(z, y))}
end
--- Computes a twofold product from a ruleset.
--
-- The base point is blinded by a random projective factor and multiplied by 8
-- before the ruleset is applied. If the result is a small order point, that
-- same point is returned for all three results.
--
-- @tparam point P The base point.
-- @tparam {{number...}, {number...}} ruleset The ruleset generated by scalars
-- m, n. Element 1 is a bit list fed to the ladder; element 2 is the list of
-- rule numbers, which is consumed from last to first.
-- @treturn point [8m]P, or nil if the ruleset has no rules (m = n).
-- @treturn point [8n]P
-- @treturn point [8m]P - [8n]P
--
local function prac(P, ruleset)
    -- Randomize.
    local rf = fp.decode(random.random(32))
    local A = {fp.mul(P[1], rf), fp.mul(P[2], rf)}

    -- Start the base at [8]P.
    A = double(double(double(A)))

    -- Throw away small order points.
    -- All three documented results are returned so that callers which feed
    -- them into dadd never index a nil value.
    if fp.eqz(fp.mul(A[1], A[2])) then return A, A, A end

    -- Now e = d = gcd(m, n) / 8.
    -- Update A from [8]P to [8u]P.
    A = ladder(A, ruleset[1])

    -- Reject rulesets where m = n.
    local rules = ruleset[2]
    if #rules == 0 then return nil end

    -- Evaluate the first rule.
    -- Since e = d, this means A - B = C = O. Differential addition fails when
    -- C = O, so we need to treat this case specially.
    -- Note that rules 0 and 1 never happen last, since the algorithm would stop
    -- one step earlier if they did.
    local B, C
    local rule = rules[#rules]
    if rule == 2 then
        -- (A, B, C) ← (2A + B, B, 2A) = (3A, A, 2A)
        local A2 = double(A)
        A, B, C = dadd(A, A2, A), A, A2
    elseif rule == 3 or rule == 5 then
        -- (A, B, C) ← (A + B, B, A) = (2A, A, A)
        -- or (A, B, C) ← (2A, B, 2A - B) = (2A, A, A)
        A, B, C = double(A), A, A
    elseif rule == 6 then
        -- (A, B, C) ← (3A + 3B, B, 3A + 2B) = (6A, A, 5A)
        local A2 = double(A)
        local A3 = dadd(A, A2, A)
        A, B, C = double(A3), A, dadd(A, A3, A2)
    elseif rule == 7 then
        -- (A, B, C) ← (3A + 2B, B, 3A + B) = (5A, A, 4A)
        local A2 = double(A)
        local A3 = dadd(A, A2, A)
        local A4 = double(A2)
        A, B, C = dadd(A3, A4, A), A, A4
    elseif rule == 8 then
        -- (A, B, C) ← (3A + B, B, 3A) = (4A, A, 3A)
        local A2 = double(A)
        local A3 = dadd(A, A2, A)
        A, B, C = double(A2), A, A3
    else
        -- (A, B, C) ← (A, 2B, A - 2B) = (A, 2A, A)
        A, B, C = A, double(A), A
    end

    -- Evaluate the other rules.
    -- Let's assume addition is undefined here, this happens when A - B = 0.
    -- Since A = [d]P and B = [e]P, A = B happens when:
    -- (1) P is on the large order base group and d ≡ e (mod q).
    -- (2) P is on the large order twist group and d ≡ e (mod q').
    -- (3) P is on a small order group.
    -- Case (3) never happens since we throw small order points away above.
    -- Since 0 ≤ {d, e} < q < q', a modular equivalence here means an integer
    -- equivalence. Therefore d = e.
    -- However, the ruleset stops when d = e, therefore the algorithm must have
    -- stopped earlier than when it did. Contradiction.
    -- Therefore, addition is always defined.
    -- Furthermore, the PRAC invariants mean that this product is the same as
    -- if the points were multiplied separately.
    for i = #rules - 1, 1, -1 do
        local rule = rules[i]
        if rule == 0 then
            -- (A, B, C) ← (B, A, B - A)
            A, B = B, A
        elseif rule == 1 then
            -- (A, B, C) ← (2A + B, A + 2B, A - B)
            local AB = dadd(C, A, B)
            A, B = dadd(B, AB, A), dadd(A, AB, B)
        elseif rule == 2 then
            -- (A, B, C) ← (2A + B, B, 2A)
            A, C = dadd(B, dadd(C, A, B), A), double(A)
        elseif rule == 3 then
            -- (A, B, C) ← (A + B, B, A)
            A, C = dadd(C, A, B), A
        elseif rule == 5 then
            -- (A, B, C) ← (2A, B, 2A - B)
            A, C = double(A), dadd(B, A, C)
        elseif rule == 6 then
            -- (A, B, C) ← (3A + 3B, B, 3A + 2B)
            local AB = dadd(C, A, B)
            local AABB = double(AB)
            A, C = dadd(AB, AABB, AB), dadd(dadd(A, AB, B), AABB, A)
        elseif rule == 7 then
            -- (A, B, C) ← (3A + 2B, B, 3A + B)
            local AB = dadd(C, A, B)
            local AAB = dadd(B, AB, A)
            A, C = dadd(A, AAB, AB), dadd(AB, AAB, A)
        elseif rule == 8 then
            -- (A, B, C) ← (3A + B, B, 3A)
            local AA = double(A)
            A, C = dadd(C, AA, dadd(C, A, B)), dadd(A, AA, A)
        else
            -- (A, B, C) ← (A, 2B, A - 2B)
            B, C = double(B), dadd(A, C, B)
        end
    end

    return A, B, C
end
--- Multiplies a point's x coordinate by an encoded field element.
--
-- @tparam point P The point {x, z}. Its z coordinate is passed through as-is.
-- @param m The encoded field element, as accepted by fp.decode.
-- @treturn point The point {x * decode(m), z}.
--
local function fieldMul(P, m)
    local factor = fp.decode(m)
    return {fp.mul(P[1], factor), P[2]}
end
-- Module exports.
return {
-- Point with x = 9 and z = 1.
G = {fp.num(9), fp.num(1)},
-- Point arithmetic and representation helpers.
dadd = dadd,
scale = scale,
encode = encode,
decode = decode,
-- Scalar multiplication entry points.
ladder8 = ladder8,
mulG = mulG,
prac = prac,
fieldMul = fieldMul,
}

View file

@ -178,7 +178,7 @@ end
local function radixWTable(P, w)
local out = {}
for i = 1, math.ceil(255 / w) do
for i = 1, math.ceil(256 / w) do
local row = {niels(P)}
for j = 2, 2 ^ w / 2 do
P = add(P, row[1])

View file

@ -771,6 +771,18 @@ local function decode(b)
}
end
--- Checks if an element is congruent to zero.
--
-- @tparam fp2 a
-- @treturn boolean Whether a ≡ 0 (mod p).
--
local function eqz(a)
    -- Canonicalize, then test whether the first 12 limbs add up to zero.
    local limbs = canonicalize(a)
    local total = limbs[1]
    for i = 2, 12 do
        total = total + limbs[i]
    end
    return total == 0
end
return {
num = num,
neg = neg,
@ -785,4 +797,5 @@ return {
sqrtDiv = sqrtDiv,
encode = encode,
decode = decode,
eqz = eqz,
}

View file

@ -60,6 +60,20 @@ local T1 = {
00000283,
}
-- Constant that decodeClamped8 multiplies the clamped scalar by.
-- NOTE(review): presumably this is the factor that realizes the division by 8
-- under the Montgomery multiplication used by mul — confirm the derivation.
-- Assumed to use the same 11-limb layout as the other scalars in this file.
local T8 = {
01130678,
05563041,
03870191,
01622646,
01247520,
12151703,
16693196,
09337410,
04700637,
07308819,
00002083,
}
local ZERO = mp.num(0)
--- Reduces a number modulo q.
@ -103,6 +117,19 @@ local function neg(a)
return reduce(mp.sub(Q, a))
end
--- Subtracts scalars mod q.
--
-- If the two operands are in Montgomery form, returns the correct result also
-- in Montgomery form, since (2²⁶⁴ × a) - (2²⁶⁴ × b) ≡ 2²⁶⁴ × (a - b) (mod q).
--
-- @tparam {number...} a A number a < q as 11 limbs in [0..2²⁴).
-- @tparam {number...} b A number b < q as 11 limbs in [0..2²⁴).
-- @treturn {number...} a - b mod q as 11 limbs in [0..2²⁴).
--
local function sub(a, b)
    -- a - b is computed as a + (-b).
    local negated = neg(b)
    return add(a, negated)
end
--- Given two scalars a and b, computes 2⁻²⁶⁴ × a × b mod q.
--
-- @tparam {number...} a A number a as 11 limbs in [0..2²⁴).
@ -194,6 +221,24 @@ local function decodeClamped(str)
return montgomery(words)
end
--- Decodes a scalar using the X25519/Ed25519 bit clamping scheme and division
-- by 8.
--
-- @tparam string str A 32-byte string encoding some little-endian number a.
-- @treturn {number...} 2²⁶⁴ × clamp(a) ÷ 8 mod q as 11 limbs in [0..2²⁴).
--
local function decodeClamped8(str)
    -- Read ten 24-bit words followed by one 16-bit word, little-endian.
    local fmt = "<I3I3I3I3I3I3I3I3I3I3I2"
    local words = {fmt:unpack(str)}
    -- unpack also returns the next read position; drop it.
    words[12] = nil

    -- Clamp: clear the 3 lowest bits of the first word, clear bit 15 of the
    -- last word and set its bit 14.
    words[1] = bit32.band(words[1], 0xfffff8)
    words[11] = bit32.bor(bit32.band(words[11], 0x7fff), 0x4000)

    -- Multiply by the T8 constant.
    return mul(words, T8)
end
--- Returns a scalar in binary.
--
-- @tparam {number...} a A number a < q as 11 limbs in [0..2²⁴).
@ -203,26 +248,147 @@ local function bits(a)
return util.rebaseLE(demontgomery(a), 2 ^ 24, 2)
end
--- Clones a scalar.
--- Makes a PRAC ruleset from a pair of scalars.
--
-- @tparam {number...} a The scalar to clone.
-- @treturn {number...} The exact same value but as a different object.
-- @tparam {number...} a A scalar a < q as 11 limbs in [0..2²⁴).
-- @tparam {number...} b A scalar b < q as 11 limbs in [0..2²⁴).
-- @treturn {{number...}, {number...}} The generated ruleset.
--
local function clone(a)
return {unpack(a)}
local function makeRuleset(a, b)
-- Builds the rule list by repeatedly shrinking the pair (d, e) until the
-- approximation of d - e reaches zero. Rules are appended in the order they are
-- chosen. The result is {ubits, rules}, where ubits is what remains of d as
-- little-endian bits with trailing zero bits stripped.
-- The numbers in raw multiprecision tables.
local dt = demontgomery(a) -- (-2²⁴..2²⁴)
local et = demontgomery(b) -- (-2²⁴..2²⁴)
local ft = mp.sub(dt, et) -- (-2²⁵..2²⁵)
-- Residue classes of (d, e) modulo 2.
local d2 = mp.mod2(dt)
local e2 = mp.mod2(et)
-- Residue classes of (d, e) modulo 3.
local d3 = mp.mod3(dt)
local e3 = mp.mod3(et)
-- (e, d - e) in limited-precision floating-point numbers.
local ef = mp.approx(et)
local ff = mp.approx(ft)
-- Lookup table for inversions and halvings modulo 3.
local lut3 = {[0] = 0, 2, 1}
local rules = {}
-- NOTE(review): termination relies on mp.approx(d - e) being exactly 0 only
-- when d = e — confirm against mp.approx.
while ff ~= 0 do
if ff < 0 then
-- M0.
rules[#rules + 1] = 0
-- (d, e) ← (e, d)
dt, et = et, dt
d2, e2 = e2, d2
d3, e3 = e3, d3
ef = mp.approx(et)
ft = mp.sub(dt, et)
-- Swapping d and e only flips the sign of d - e.
ff = -ff
elseif 4 * ff < ef and d3 == lut3[e3] then
-- M1.
rules[#rules + 1] = 1
-- (d, e) ← ((2d - e)/3, (2e - d)/3)
-- The difference d - e is unchanged by this rule, so ft and ff are kept.
dt, et = mp.third(mp.add(dt, ft)), mp.third(mp.sub(et, ft))
d2, e2 = e2, d2
d3, e3 = mp.mod3(dt), mp.mod3(et)
ef = mp.approx(et)
elseif 4 * ff < ef and d2 == e2 and d3 == e3 then
-- M2.
rules[#rules + 1] = 2
-- (d, e) ← ((d - e)/2, e)
dt = mp.half(ft)
d2 = mp.mod2(dt)
d3 = lut3[(d3 - e3) % 3]
ft = mp.sub(dt, et)
ff = mp.approx(ft)
elseif ff < 3 * ef then
-- M3.
rules[#rules + 1] = 3
-- (d, e) ← (d - e, e)
dt = mp.carryWeak(ft)
d2 = (d2 - e2) % 2
d3 = (d3 - e3) % 3
ft = mp.sub(dt, et)
ff = mp.approx(ft)
elseif d2 == e2 then
-- M4 (same as M2).
-- Emitted as rule 2, so consumers never see a rule numbered 4.
rules[#rules + 1] = 2
-- (d, e) ← ((d - e)/2, e)
dt = mp.half(ft)
d2 = mp.mod2(dt)
d3 = lut3[(d3 - e3) % 3]
ft = mp.sub(dt, et)
ff = mp.approx(ft)
elseif d2 == 0 then
-- M5.
rules[#rules + 1] = 5
-- (d, e) ← (d/2, e)
dt = mp.half(dt)
d2 = mp.mod2(dt)
d3 = lut3[d3]
ft = mp.sub(dt, et)
ff = mp.approx(ft)
elseif d3 == 0 then
-- M6.
rules[#rules + 1] = 6
-- (d, e) ← (d/3 - e, e)
dt = mp.carryWeak(mp.sub(mp.third(dt), et))
d2 = (d2 - e2) % 2
d3 = mp.mod3(dt)
ft = mp.sub(dt, et)
ff = mp.approx(ft)
elseif d3 == lut3[e3] then
-- M7.
rules[#rules + 1] = 7
-- (d, e) ← ((d - 2e)/3, e)
-- d2 is left untouched: (d - 2e)/3 has the same parity as d.
dt = mp.third(mp.sub(ft, et))
d3 = mp.mod3(dt)
ft = mp.sub(dt, et)
ff = mp.approx(ft)
elseif d3 == e3 then
-- M8.
rules[#rules + 1] = 8
-- (d, e) ← ((d - e)/3, e)
dt = mp.third(ft)
d2 = (d2 - e2) % 2
d3 = mp.mod3(dt)
ft = mp.sub(dt, et)
ff = mp.approx(ft)
else
-- M9.
rules[#rules + 1] = 9
-- (d, e) ← (d, e/2)
et = mp.half(et)
e2 = mp.mod2(et)
e3 = lut3[e3]
ef = mp.approx(et)
ft = mp.sub(dt, et)
ff = mp.approx(ft)
end
end
-- Convert what is left of d from base 2²⁴ limbs to bits and strip the
-- trailing (most significant) zero bits.
local ubits = util.rebaseLE(dt, 2 ^ 24, 2)
while ubits[#ubits] == 0 do ubits[#ubits] = nil end
return {ubits, rules}
end
return {
num = num,
add = add,
neg = neg,
sub = sub,
montgomery = montgomery,
demontgomery = demontgomery,
mul = mul,
encode = encode,
decode = decode,
decodeWide = decodeWide,
decodeClamped8 = decodeClamped8,
decodeClamped = decodeClamped,
bits = bits,
clone = clone,
makeRuleset = makeRuleset,
}

View file

@ -209,13 +209,13 @@ end
--- Computes half of a number.
--
-- @tparam {number...} a An even positive integer as 11 limbs in (-2²⁴..2²⁴).
-- @treturn {number...} a ÷ 2 as 11 limbs in (-2..2⁴).
-- @treturn {number...} a ÷ 2 as 11 limbs in (-2²⁴..2²⁴).
--
local function half(a)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10 = unpack(a)
return {
a00 + a01 * 2 ^ 23,
local out = carryWeak {
a00 * 0.5 + a01 * 2 ^ 23,
a02 * 2 ^ 23,
a03 * 2 ^ 23,
a04 * 2 ^ 23,
@ -227,6 +227,10 @@ local function half(a)
a10 * 2 ^ 23,
0,
}
out[12] = nil
return out
end
--- Computes a third of a number.

View file

@ -1,5 +1,3 @@
local mod = {}
--- Converts a little-endian array from one power-of-two base to another.
--
-- @tparam {number...} a The array to convert, in little-endian.
@ -7,7 +5,7 @@ local mod = {}
-- @tparam number base2 The base to convert to. Must be a power of 2.
-- @treturn {number...}
--
function mod.rebaseLE(a, base1, base2)
local function rebaseLE(a, base1, base2) -- TODO Write contract properly.
local out = {}
local outlen = 1
local acc = 0
@ -29,4 +27,46 @@ function mod.rebaseLE(a, base1, base2)
return out
end
return mod
--- Decodes bits with X25519/Ed25519 exponent clamping.
--
-- @tparam string str The 32-byte encoded exponent.
-- @treturn {number...} The decoded clamped bits, as 256 little-endian bits.
--
local function bits(str)
    -- Expand each of the 32 bytes into 8 bits, least significant bit first.
    local raw = {str:byte(1, 32)}
    local out = {}
    for i = 1, 32 do
        local rest = raw[i]
        local base = 8 * (i - 1)
        for k = 1, 8 do
            local low = rest % 2
            out[base + k] = low
            rest = (rest - low) / 2
        end
    end

    -- Clamp: clear the 3 lowest bits, set bit 255 and clear bit 256.
    for k = 1, 3 do
        out[k] = 0
    end
    out[255] = 1
    out[256] = 0

    return out
end
--- Decodes bits with X25519/Ed25519 exponent clamping and division by 8.
--
-- @tparam string str The 32-byte encoded exponent.
-- @treturn {number...} The decoded clamped bits, divided by 8.
--
local function bits8(str)
    -- Dropping the 3 lowest bits of the clamped exponent divides it by 8.
    local full = bits(str)
    local out = {}
    for i = 4, #full do
        out[i - 3] = full[i]
    end
    return out
end
-- Module exports.
return {
rebaseLE = rebaseLE,
-- Clamped exponent decoding, with and without the 3 lowest bits.
bits = bits,
bits8 = bits8,
}

View file

@ -4,30 +4,8 @@
--
local expect = require "cc.expect".expect
local fp = require "ccryptolib.internal.fp"
local mont = require "ccryptolib.internal.curve25519"
-- TODO This function feels out of place anywhere I try putting it on.
local function bits(str)
-- Decode.
local bytes = {str:byte(1, 32)}
local out = {}
for i = 1, 32 do
local byte = bytes[i]
for j = -7, 0 do
local bit = byte % 2
out[8 * i + j] = bit
byte = (byte - bit) / 2
end
end
-- Clamp.
out[256] = 0
out[255] = 1
-- We remove the 3 lowest bits since the ladder already multiplies by 8.
return {unpack(out, 4)}
end
local util = require "ccryptolib.internal.util"
local c25 = require "ccryptolib.internal.curve25519"
local mod = {}
@ -39,7 +17,7 @@ local mod = {}
function mod.publicKey(sk)
expect(1, sk, "string")
assert(#sk == 32, "secret key length must be 32")
return fp.encode(mont.ladder8(fp.kmul, 9, bits(sk)))
return c25.encode(c25.scale(c25.mulG(util.bits(sk))))
end
--- Performs the key exchange.
@ -53,7 +31,7 @@ function mod.exchange(sk, pk)
assert(#sk == 32, "secret key length must be 32")
expect(2, pk, "string")
assert(#pk == 32, "public key length must be 32")
return fp.encode(mont.ladder8(fp.mul, fp.decode(pk), bits(sk)))
return c25.encode(c25.scale(c25.ladder8(c25.decode(pk), util.bits8(sk))))
end
return mod

48
x25519c.lua Normal file
View file

@ -0,0 +1,48 @@
local expect = require "cc.expect".expect
local fq = require "ccryptolib.internal.fq"
local util = require "ccryptolib.internal.util"
local c25 = require "ccryptolib.internal.curve25519"
local random = require "ccryptolib.random"
local mod = {}
--- Generates a key pair whose secret scalar is split by a random mask.
--
-- @treturn string The encoded difference of the two clamped scalars.
-- @treturn string The 32-byte random mask. Presumably this is the "ephemeral
-- secret key" expected by remask and exchange — TODO confirm.
-- @treturn string The encoded public point for the unmasked secret.
--
function mod.keypair()
    -- Draw the secret first and the mask second.
    local secret = random.random(32)
    local mask = random.random(32)

    -- The public point comes from the unmasked secret.
    local public = c25.mulG(util.bits(secret))

    -- Keep only the difference between the clamped secret and the clamped mask.
    local masked = fq.sub(fq.decodeClamped8(secret), fq.decodeClamped8(mask))

    return fq.encode(masked), mask, c25.encode(c25.scale(public))
end
--- Replaces the mask of a masked secret key with a fresh random one.
--
-- @tparam string sk The 32-byte masked secret key.
-- @tparam string ek The 32-byte ephemeral secret key currently masking sk.
-- @treturn string The secret key re-encoded under the new mask.
-- @treturn string The new 32-byte ephemeral secret key.
--
function mod.remask(sk, ek)
    expect(1, sk, "string")
    assert(#sk == 32, "secret key length must be 32")
    expect(2, ek, "string")
    assert(#ek == 32, "ephemeral secret key length must be 32")

    local fresh = random.random(32)

    -- Adding (old mask - new mask) swaps the old mask for the new one.
    local oldMask = fq.decodeClamped8(ek)
    local newMask = fq.decodeClamped8(fresh)
    local shift = fq.sub(oldMask, newMask)
    local remasked = fq.add(fq.decode(sk), shift)

    return fq.encode(remasked), fresh
end
--- Performs the key exchange with a masked secret key.
--
-- @tparam string sk The 32-byte masked secret key.
-- @tparam string ek The 32-byte ephemeral secret key masking sk.
-- @tparam string pk The 32-byte public key of the other party.
-- @treturn string The encoded sum of the two partial products.
-- @treturn string The encoded product of the mask alone with pk.
--
function mod.exchange(sk, ek, pk)
    expect(1, sk, "string")
    assert(#sk == 32, "secret key length must be 32")
    expect(2, ek, "string")
    assert(#ek == 32, "ephemeral secret key length must be 32")
    expect(3, pk, "string")
    assert(#pk == 32, "public key length must be 32")

    local point = c25.decode(pk)
    local mask = fq.decodeClamped8(ek)
    local masked = fq.decode(sk)

    -- Multiply the point by both scalars at once; the third result is the
    -- difference that the differential addition below requires.
    local rules = fq.makeRuleset(mask, masked)
    local maskP, maskedP, diffP = c25.prac(point, rules)

    -- Recombine the two partial products on the curve.
    local shared = c25.dadd(diffP, maskP, maskedP)

    return c25.encode(c25.scale(shared)), c25.encode(c25.scale(maskP))
end
return mod