diff --git a/internal/fp.lua b/internal/fp.lua index c4384db..361659a 100644 --- a/internal/fp.lua +++ b/internal/fp.lua @@ -1,5 +1,9 @@ local unpack = unpack or table.unpack +local function num(n) + return {n, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} +end + local function add(a, b) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a) local b00, b01, b02, b03, b04, b05, b06, b07, b08, b09, b10, b11 = unpack(b) @@ -497,6 +501,7 @@ local function decode(b) end return { + num = num, add = add, sub = sub, kmul = kmul, diff --git a/internal/x25519.lua b/internal/x25519.lua index 23ba994..e6b9886 100644 --- a/internal/x25519.lua +++ b/internal/x25519.lua @@ -2,6 +2,17 @@ local fp = require "ccryptolib.internal.fp" local G = {9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} +local function double(x1, z1) + local a = fp.add(x1, z1) + local aa = fp.square(a) + local b = fp.sub(x1, z1) + local bb = fp.square(b) + local c = fp.sub(aa, bb) + local x3 = fp.mul(aa, bb) + local z3 = fp.mul(c, fp.add(bb, fp.kmul(c, 121666))) + return x3, z3 +end + local function step(dx, x1, z1, x2, z2) local a = fp.add(x1, z1) local aa = fp.square(a) @@ -19,48 +30,8 @@ local function step(dx, x1, z1, x2, z2) return x3, z3, x4, z4 end -local function ladder(dx, bits) - local x1 = {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} - local z1 = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} - local x2, z2 = dx, x1 - - for i = #bits, 1, -1 do - if bits[i] == 0 then - x1, z1, x2, z2 = step(dx, x1, z1, x2, z2) - else - x2, z2, x1, z1 = step(dx, x2, z2, x1, z1) - end - end - - return fp.mul(x1, fp.invert(z1)) -end - -local function bits(str) - -- Decode. - local bytes = {str:byte(1, 32)} - local out = {} - for i = 1, 32 do - local byte = bytes[i] - for j = -7, 0 do - local bit = byte % 2 - out[8 * i + j] = bit - byte = (byte - bit) / 2 - end - end - - -- Clamp. - out[1] = 0 - out[2] = 0 - out[3] = 0 - out[256] = 0 - out[255] = 1 - - return out -end - return { G = G, + double = double, step = step, - ladder = ladder, - bits = bits, } diff --git a/x25519.lua b/x25519.lua index 65ea4ab..4686b29 100644 --- a/x25519.lua +++ b/x25519.lua @@ -2,13 +2,58 @@ local expect = require "cc.expect".expect local fp = require "ccryptolib.internal.fp" local x25519 = require "ccryptolib.internal.x25519" +local unpack = unpack or table.unpack + +local function bits(str) + -- Decode. + local bytes = {str:byte(1, 32)} + local out = {} + for i = 1, 32 do + local byte = bytes[i] + for j = -7, 0 do + local bit = byte % 2 + out[8 * i + j] = bit + byte = (byte - bit) / 2 + end + end + + -- Clamp. + out[256] = 0 + out[255] = 1 + + -- We remove the 3 lowest bits since the ladder already multiplies by 8. + return {unpack(out, 4)} +end + +local function ladder8(dx, bits) + local x1 = {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + local z1 = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + local x2, z2 = dx, x1 + + -- Standard ladder. + for i = #bits, 1, -1 do + if bits[i] == 0 then + x1, z1, x2, z2 = x25519.step(dx, x1, z1, x2, z2) + else + x2, z2, x1, z1 = x25519.step(dx, x2, z2, x1, z1) + end + end + + -- Multiply by 8 (double 3 times). + for _ = 1, 3 do + x1, z1 = x25519.double(x1, z1) + end + + return fp.mul(x1, fp.invert(z1)) +end + local mod = {} function mod.publicKey(sk) expect(1, sk, "string") assert(#sk == 32, "secret key length must be 32") - return fp.encode(x25519.ladder(x25519.G, x25519.bits(sk))) + return fp.encode(ladder8(x25519.G, bits(sk))) end function mod.exchange(sk, pk) @@ -17,7 +62,7 @@ function mod.exchange(sk, pk) expect(2, pk, "string") assert(#pk == 32, "public key length must be 32") - return fp.encode(x25519.ladder(fp.decode(pk), x25519.bits(sk))) + return fp.encode(ladder8(fp.decode(pk), bits(sk))) end return mod diff --git a/x25519c.lua b/x25519c.lua index 1854c70..e5e5c64 100644 --- a/x25519c.lua +++ b/x25519c.lua @@ -37,6 +37,34 @@ local function fqDecodeStd(str) return fq.montgomery(words) end +local function ladder8(dx, bits) + local x1 = fp.num(1) + local z1 = fp.num(0) + + -- Compute a randomization factor for randomized projective coordinates. + -- Biased but good enough. + local rf = fp.decode(random.random(32)) + + local x2 = fp.mul(rf, dx) + local z2 = rf + + -- Standard ladder. + for i = #bits, 1, -1 do + if bits[i] == 0 then + x1, z1, x2, z2 = x25519.step(dx, x1, z1, x2, z2) + else + x2, z2, x1, z1 = x25519.step(dx, x2, z2, x1, z1) + end + end + + -- Multiply by 8 (double 3 times). + for _ = 1, 3 do + x1, z1 = x25519.double(x1, z1) + end + + return fp.mul(x1, fp.invert(z1)) +end + local mod = {} function mod.secretKeyInit(sk) @@ -102,10 +130,8 @@ function mod.exchange(sk, pk, mc) -- We have our exponent modulo q. We also know that its value is 0 modulo 8. -- Use the Chinese Remainder Theorem to find its value modulo 8q. local bits = fq.bits(fq.mul(skmt, INV8Q)) - local bits8 = {0, 0, 0} - for i = 1, 253 do bits8[i + 3] = bits[i] end - return fp.encode(x25519.ladder(fp.decode(pk), bits8)) + return fp.encode(ladder8(fp.decode(pk), bits)) end return mod