diff --git a/ed25519.lua b/ed25519.lua index bb067d2..1385611 100644 --- a/ed25519.lua +++ b/ed25519.lua @@ -4,231 +4,11 @@ -- local expect = require "cc.expect".expect -local fp = require "ccryptolib.internal.fp" local fq = require "ccryptolib.internal.fq" local sha512 = require "ccryptolib.internal.sha512" +local ed = require "ccryptolib.internal.edwards25519" local random = require "ccryptolib.random" -local unpack = unpack or table.unpack - -local D = fp.mul(fp.num(-121665), fp.invert(fp.num(121666))) -local K = fp.kmul(D, 2) - -local O = {fp.num(0), fp.num(1), fp.num(1), fp.num(0)} -local G = nil - -local function double(P1) - -- Unsoundness: fp.sub(g, e), and fp.sub(d, i) break fp.sub's contract since - -- it doesn't accept an fp2. Although not ideal, in practice this doesn't - -- matter since fp.carry handles the larger sum. - local P1x, P1y, P1z = unpack(P1) - local a = fp.square(P1x) - local b = fp.square(P1y) - local c = fp.square(P1z) - local d = fp.add(c, c) - local e = fp.add(a, b) - local f = fp.add(P1x, P1y) - local g = fp.square(f) - local h = fp.carry(fp.sub(g, e)) - local i = fp.sub(b, a) - local j = fp.carry(fp.sub(d, i)) - local P3x = fp.mul(h, j) - local P3y = fp.mul(i, e) - local P3z = fp.mul(j, i) - local P3t = fp.mul(h, e) - return {P3x, P3y, P3z, P3t} -end - -local function add(P1, N1) - local P1x, P1y, P1z, P1t = unpack(P1) - local N1p, N1m, N1z, N1t = unpack(N1) - local a = fp.sub(P1y, P1x) - local b = fp.mul(a, N1m) - local c = fp.add(P1y, P1x) - local d = fp.mul(c, N1p) - local e = fp.mul(P1t, N1t) - local f = fp.mul(P1z, N1z) - local g = fp.sub(d, b) - local h = fp.sub(f, e) - local i = fp.add(f, e) - local j = fp.add(d, b) - local P3x = fp.mul(g, h) - local P3y = fp.mul(i, j) - local P3z = fp.mul(h, i) - local P3t = fp.mul(g, j) - return {P3x, P3y, P3z, P3t} -end - -local function sub(P1, N1) - local P1x, P1y, P1z, P1t = unpack(P1) - local N1p, N1m, N1z, N1t = unpack(N1) - local a = fp.sub(P1y, P1x) - local b = fp.mul(a, N1p) - local c = fp.add(P1y, P1x) - local d = fp.mul(c, N1m) - local e = fp.mul(P1t, N1t) - local f = fp.mul(P1z, N1z) - local g = fp.sub(d, b) - local h = fp.add(f, e) - local i = fp.sub(f, e) - local j = fp.add(d, b) - local P3x = fp.mul(g, h) - local P3y = fp.mul(i, j) - local P3z = fp.mul(h, i) - local P3t = fp.mul(g, j) - return {P3x, P3y, P3z, P3t} -end - -local function niels(P1) - local P1x, P1y, P1z, P1t = unpack(P1) - local N3p = fp.add(P1y, P1x) - local N3m = fp.sub(P1y, P1x) - local N3z = fp.add(P1z, P1z) - local N3t = fp.mul(P1t, K) - return {N3p, N3m, N3z, N3t} -end - -local function scale(P1) - local P1x, P1y, P1z = unpack(P1) - local zInv = fp.invert(P1z) - local P3x = fp.mul(P1x, zInv) - local P3y = fp.mul(P1y, zInv) - local P3z = fp.num(1) - local P3t = fp.mul(P3x, P3y) - return {P3x, P3y, P3z, P3t} -end - -local function encode(P1) - local P1x, P1y = unpack(P1) - local y = fp.encode(P1y) - local xBit = fp.canonicalize(P1x)[1] % 2 - return y:sub(1, -2) .. string.char(y:byte(-1) + xBit * 128) -end - -local function decode(str) - local P3y = fp.decode(str) - local a = fp.square(P3y) - local b = fp.sub(a, fp.num(1)) - local c = fp.mul(a, D) - local d = fp.add(c, fp.num(1)) - local P3x = fp.sqrtDiv(b, d) - if not P3x then return nil end - local xBit = fp.canonicalize(P3x)[1] % 2 - if xBit ~= bit32.extract(str:byte(-1), 7) then - P3x = fp.carry(fp.sub(fp.P, P3x)) - end - local P3z = fp.num(1) - local P3t = fp.mul(P3x, P3y) - return {P3x, P3y, P3z, P3t} -end - -G = decode("Xfffffffffffffffffffffffffffffff") - -local function signedRadixW(bits, w) - -- TODO Find a more elegant way of doing this. - local wPow = 2 ^ w - local wPowh = wPow / 2 - local out = {} - local acc = 0 - local mul = 1 - for i = 1, #bits do - acc = acc + bits[i] * mul - mul = mul * 2 - while i == #bits and acc > 0 or mul > wPow do - local rem = acc % wPow - if rem >= wPowh then rem = rem - wPow end - acc = (acc - rem) / wPow - mul = mul / wPow - out[#out + 1] = rem - end - end - return out -end - -local function radixWTable(P, w) - local out = {} - for i = 1, 255 / w do - local row = {niels(P)} - for j = 2, 2 ^ w / 2 do - P = add(P, row[1]) - row[j] = niels(P) - end - out[i] = row - P = double(P) - end - return out -end - -local G_W = 5 -local G_TABLE = radixWTable(G, G_W) - -local function WNAF(bits, w) - -- TODO Find a more elegant way of doing this. - local wPow = 2 ^ w - local wPowh = wPow / 2 - local out = {} - local acc = 0 - local mul = 1 - for i = 1, #bits do - acc = acc + bits[i] * mul - mul = mul * 2 - while i == #bits and acc > 0 or mul > wPow do - if acc % 2 == 0 then - acc = acc / 2 - mul = mul / 2 - out[#out + 1] = 0 - else - local rem = acc % wPow - if rem >= wPowh then rem = rem - wPow end - acc = acc - rem - out[#out + 1] = rem - end - end - end - while out[#out] == 0 do out[#out] = nil end - return out -end - -local function WNAFTable(P, w) - local dP = double(P) - local out = {niels(P)} - for i = 3, 2 ^ w, 2 do - out[i] = niels(add(dP, out[i - 2])) - end - return out -end - -local function mulG(bits) - local sw = signedRadixW(bits, G_W) - local R = O - for i = 1, #sw do - local b = sw[i] - if b > 0 then - R = add(R, G_TABLE[i][b]) - elseif b < 0 then - R = sub(R, G_TABLE[i][-b]) - end - end - return R -end - -local function mul(P, bits) - local naf = WNAF(bits, 5) - local tbl = WNAFTable(P, 5) - local R = O - for i = #naf, 1, -1 do - local b = naf[i] - if b == 0 then - R = double(R) - elseif b > 0 then - R = add(R, tbl[b]) - else - R = sub(R, tbl[-b]) - end - end - return R -end - local mod = {} --- Computes a public key from a secret key. @@ -243,7 +23,7 @@ function mod.publicKey(sk) local h = sha512.digest(sk) local x = fq.decodeClamped(h:sub(1, 32)) - return encode(scale(mulG(fq.bits(x)))) + return ed.encode(ed.scale(ed.mulG(fq.bits(x)))) end --- Signs a message. @@ -266,8 +46,8 @@ function mod.sign(sk, pk, msg) -- Commitment. local k = fq.decodeWide(random.random(64)) - local r = mulG(fq.bits(k)) - local rStr = encode(scale(r)) + local r = ed.mulG(fq.bits(k)) + local rStr = ed.encode(ed.scale(r)) -- Challenge. local e = fq.decodeWide(sha512.digest(rStr .. pk .. msg)) @@ -294,7 +74,7 @@ function mod.verify(pk, msg, sig) expect(3, sig, "string") assert(#sig == 64, "signature length must be 64") - local y = decode(pk) + local y = ed.decode(pk) if not y then return nil end local rStr = sig:sub(1, 32) @@ -302,11 +82,11 @@ function mod.verify(pk, msg, sig) local e = fq.decodeWide(sha512.digest(rStr .. pk .. msg)) - local gs = mulG(fq.bits(fq.decode(sStr))) - local ye = mul(y, fq.bits(e)) - local rv = add(gs, niels(ye)) + local gs = ed.mulG(fq.bits(fq.decode(sStr))) + local ye = ed.mul(y, fq.bits(e)) + local rv = ed.add(gs, ed.niels(ye)) - return encode(scale(rv)) == rStr + return ed.encode(ed.scale(rv)) == rStr end return mod diff --git a/internal/curve25519.lua b/internal/curve25519.lua new file mode 100644 index 0000000..83b5999 --- /dev/null +++ b/internal/curve25519.lua @@ -0,0 +1,95 @@ +--- Point arithmetic on the Curve25519 Montgomery curve. +-- +-- :::note Internal Module +-- This module is meant for internal use within the library. Its API is unstable +-- and subject to change without major version bumps. +-- ::: +-- +--
+-- +-- @module[kind=internal] internal.curve25519 +-- + +local fp = require "ccryptolib.internal.fp" +local random = require "ccryptolib.random" + +local unpack = unpack or table.unpack + +local function double(x1, z1) + local a = fp.add(x1, z1) + local aa = fp.square(a) + local b = fp.sub(x1, z1) + local bb = fp.square(b) + local c = fp.sub(aa, bb) + local x3 = fp.mul(aa, bb) + local z3 = fp.mul(c, fp.add(bb, fp.kmul(c, 121666))) + return x3, z3 +end + +local function step(dxmul, dx, x1, z1, x2, z2) + local a = fp.add(x1, z1) + local aa = fp.square(a) + local b = fp.sub(x1, z1) + local bb = fp.square(b) + local e = fp.sub(aa, bb) + local c = fp.add(x2, z2) + local d = fp.sub(x2, z2) + local da = fp.mul(d, a) + local cb = fp.mul(c, b) + local x4 = fp.square(fp.add(da, cb)) + local z4 = dxmul(fp.square(fp.sub(da, cb)), dx) + local x3 = fp.mul(aa, bb) + local z3 = fp.mul(e, fp.add(bb, fp.kmul(e, 121666))) + return x3, z3, x4, z4 +end + +local function bits(str) + -- Decode. + local bytes = {str:byte(1, 32)} + local out = {} + for i = 1, 32 do + local byte = bytes[i] + for j = -7, 0 do + local bit = byte % 2 + out[8 * i + j] = bit + byte = (byte - bit) / 2 + end + end + + -- Clamp. + out[256] = 0 + out[255] = 1 + + -- We remove the 3 lowest bits since the ladder already multiplies by 8. + return {unpack(out, 4)} +end + +local function ladder8(dxmul, dx, bits) + local x1 = fp.num(1) + local z1 = fp.num(0) + + local z2 = fp.decode(random.random(32)) + local x2 = dxmul(z2, dx) + + -- Standard ladder. + for i = #bits, 1, -1 do + if bits[i] == 0 then + x1, z1, x2, z2 = step(dxmul, dx, x1, z1, x2, z2) + else + x2, z2, x1, z1 = step(dxmul, dx, x2, z2, x1, z1) + end + end + + -- Multiply by 8 (double 3 times). + for _ = 1, 3 do + x1, z1 = double(x1, z1) + end + + return fp.mul(x1, fp.invert(z1)) +end + +return { + double = double, + bits = bits, + ladder8 = ladder8, +} diff --git a/internal/edwards25519.lua b/internal/edwards25519.lua new file mode 100644 index 0000000..370ae0b --- /dev/null +++ b/internal/edwards25519.lua @@ -0,0 +1,244 @@ +--- Point arithmetic on the Edwards25519 Edwards curve. +-- +-- :::note Internal Module +-- This module is meant for internal use within the library. Its API is unstable +-- and subject to change without major version bumps. +-- ::: +-- +--
+-- +-- @module[kind=internal] internal.edwards25519 +-- + +local fp = require "ccryptolib.internal.fp" + +local unpack = unpack or table.unpack + +local D = fp.mul(fp.num(-121665), fp.invert(fp.num(121666))) +local K = fp.kmul(D, 2) + +local O = {fp.num(0), fp.num(1), fp.num(1), fp.num(0)} +local G = nil + +local function double(P1) + -- Unsoundness: fp.sub(g, e), and fp.sub(d, i) break fp.sub's contract since + -- it doesn't accept an fp2. Although not ideal, in practice this doesn't + -- matter since fp.carry handles the larger sum. + local P1x, P1y, P1z = unpack(P1) + local a = fp.square(P1x) + local b = fp.square(P1y) + local c = fp.square(P1z) + local d = fp.add(c, c) + local e = fp.add(a, b) + local f = fp.add(P1x, P1y) + local g = fp.square(f) + local h = fp.carry(fp.sub(g, e)) + local i = fp.sub(b, a) + local j = fp.carry(fp.sub(d, i)) + local P3x = fp.mul(h, j) + local P3y = fp.mul(i, e) + local P3z = fp.mul(j, i) + local P3t = fp.mul(h, e) + return {P3x, P3y, P3z, P3t} +end + +local function add(P1, N1) + local P1x, P1y, P1z, P1t = unpack(P1) + local N1p, N1m, N1z, N1t = unpack(N1) + local a = fp.sub(P1y, P1x) + local b = fp.mul(a, N1m) + local c = fp.add(P1y, P1x) + local d = fp.mul(c, N1p) + local e = fp.mul(P1t, N1t) + local f = fp.mul(P1z, N1z) + local g = fp.sub(d, b) + local h = fp.sub(f, e) + local i = fp.add(f, e) + local j = fp.add(d, b) + local P3x = fp.mul(g, h) + local P3y = fp.mul(i, j) + local P3z = fp.mul(h, i) + local P3t = fp.mul(g, j) + return {P3x, P3y, P3z, P3t} +end + +local function sub(P1, N1) + local P1x, P1y, P1z, P1t = unpack(P1) + local N1p, N1m, N1z, N1t = unpack(N1) + local a = fp.sub(P1y, P1x) + local b = fp.mul(a, N1p) + local c = fp.add(P1y, P1x) + local d = fp.mul(c, N1m) + local e = fp.mul(P1t, N1t) + local f = fp.mul(P1z, N1z) + local g = fp.sub(d, b) + local h = fp.add(f, e) + local i = fp.sub(f, e) + local j = fp.add(d, b) + local P3x = fp.mul(g, h) + local P3y = fp.mul(i, j) + local P3z = fp.mul(h, i) + local P3t = fp.mul(g, j) + return {P3x, P3y, P3z, P3t} +end + +local function niels(P1) + local P1x, P1y, P1z, P1t = unpack(P1) + local N3p = fp.add(P1y, P1x) + local N3m = fp.sub(P1y, P1x) + local N3z = fp.add(P1z, P1z) + local N3t = fp.mul(P1t, K) + return {N3p, N3m, N3z, N3t} +end + +local function scale(P1) + local P1x, P1y, P1z = unpack(P1) + local zInv = fp.invert(P1z) + local P3x = fp.mul(P1x, zInv) + local P3y = fp.mul(P1y, zInv) + local P3z = fp.num(1) + local P3t = fp.mul(P3x, P3y) + return {P3x, P3y, P3z, P3t} +end + +local function encode(P1) + local P1x, P1y = unpack(P1) + local y = fp.encode(P1y) + local xBit = fp.canonicalize(P1x)[1] % 2 + return y:sub(1, -2) .. string.char(y:byte(-1) + xBit * 128) +end + +local function decode(str) + local P3y = fp.decode(str) + local a = fp.square(P3y) + local b = fp.sub(a, fp.num(1)) + local c = fp.mul(a, D) + local d = fp.add(c, fp.num(1)) + local P3x = fp.sqrtDiv(b, d) + if not P3x then return nil end + local xBit = fp.canonicalize(P3x)[1] % 2 + if xBit ~= bit32.extract(str:byte(-1), 7) then + P3x = fp.carry(fp.sub(fp.P, P3x)) + end + local P3z = fp.num(1) + local P3t = fp.mul(P3x, P3y) + return {P3x, P3y, P3z, P3t} +end + +G = decode("Xfffffffffffffffffffffffffffffff") + +local function signedRadixW(bits, w) + -- TODO Find a more elegant way of doing this. + local wPow = 2 ^ w + local wPowh = wPow / 2 + local out = {} + local acc = 0 + local mul = 1 + for i = 1, #bits do + acc = acc + bits[i] * mul + mul = mul * 2 + while i == #bits and acc > 0 or mul > wPow do + local rem = acc % wPow + if rem >= wPowh then rem = rem - wPow end + acc = (acc - rem) / wPow + mul = mul / wPow + out[#out + 1] = rem + end + end + return out +end + +local function radixWTable(P, w) + local out = {} + for i = 1, 255 / w do + local row = {niels(P)} + for j = 2, 2 ^ w / 2 do + P = add(P, row[1]) + row[j] = niels(P) + end + out[i] = row + P = double(P) + end + return out +end + +local G_W = 5 +local G_TABLE = radixWTable(G, G_W) + +local function WNAF(bits, w) + -- TODO Find a more elegant way of doing this. + local wPow = 2 ^ w + local wPowh = wPow / 2 + local out = {} + local acc = 0 + local mul = 1 + for i = 1, #bits do + acc = acc + bits[i] * mul + mul = mul * 2 + while i == #bits and acc > 0 or mul > wPow do + if acc % 2 == 0 then + acc = acc / 2 + mul = mul / 2 + out[#out + 1] = 0 + else + local rem = acc % wPow + if rem >= wPowh then rem = rem - wPow end + acc = acc - rem + out[#out + 1] = rem + end + end + end + while out[#out] == 0 do out[#out] = nil end + return out +end + +local function WNAFTable(P, w) + local dP = double(P) + local out = {niels(P)} + for i = 3, 2 ^ w, 2 do + out[i] = niels(add(dP, out[i - 2])) + end + return out +end + +local function mulG(bits) + local sw = signedRadixW(bits, G_W) + local R = O + for i = 1, #sw do + local b = sw[i] + if b > 0 then + R = add(R, G_TABLE[i][b]) + elseif b < 0 then + R = sub(R, G_TABLE[i][-b]) + end + end + return R +end + +local function mul(P, bits) + local naf = WNAF(bits, 5) + local tbl = WNAFTable(P, 5) + local R = O + for i = #naf, 1, -1 do + local b = naf[i] + if b == 0 then + R = double(R) + elseif b > 0 then + R = add(R, tbl[b]) + else + R = sub(R, tbl[-b]) + end + end + return R +end + +return { + double = double, + add = add, + niels = niels, + scale = scale, + encode = encode, + decode = decode, + mulG = mulG, + mul = mul, +} diff --git a/x25519.lua b/x25519.lua index 18ad851..c5b1d02 100644 --- a/x25519.lua +++ b/x25519.lua @@ -5,82 +5,7 @@ local expect = require "cc.expect".expect local fp = require "ccryptolib.internal.fp" -local random = require "ccryptolib.random" - -local unpack = unpack or table.unpack - -local function double(x1, z1) - local a = fp.add(x1, z1) - local aa = fp.square(a) - local b = fp.sub(x1, z1) - local bb = fp.square(b) - local c = fp.sub(aa, bb) - local x3 = fp.mul(aa, bb) - local z3 = fp.mul(c, fp.add(bb, fp.kmul(c, 121666))) - return x3, z3 -end - -local function step(dxmul, dx, x1, z1, x2, z2) - local a = fp.add(x1, z1) - local aa = fp.square(a) - local b = fp.sub(x1, z1) - local bb = fp.square(b) - local e = fp.sub(aa, bb) - local c = fp.add(x2, z2) - local d = fp.sub(x2, z2) - local da = fp.mul(d, a) - local cb = fp.mul(c, b) - local x4 = fp.square(fp.add(da, cb)) - local z4 = dxmul(fp.square(fp.sub(da, cb)), dx) - local x3 = fp.mul(aa, bb) - local z3 = fp.mul(e, fp.add(bb, fp.kmul(e, 121666))) - return x3, z3, x4, z4 -end - -local function bits(str) - -- Decode. - local bytes = {str:byte(1, 32)} - local out = {} - for i = 1, 32 do - local byte = bytes[i] - for j = -7, 0 do - local bit = byte % 2 - out[8 * i + j] = bit - byte = (byte - bit) / 2 - end - end - - -- Clamp. - out[256] = 0 - out[255] = 1 - - -- We remove the 3 lowest bits since the ladder already multiplies by 8. - return {unpack(out, 4)} -end - -local function ladder8(dxmul, dx, bits) - local x1 = fp.num(1) - local z1 = fp.num(0) - - local z2 = fp.decode(random.random(32)) - local x2 = dxmul(z2, dx) - - -- Standard ladder. - for i = #bits, 1, -1 do - if bits[i] == 0 then - x1, z1, x2, z2 = step(dxmul, dx, x1, z1, x2, z2) - else - x2, z2, x1, z1 = step(dxmul, dx, x2, z2, x1, z1) - end - end - - -- Multiply by 8 (double 3 times). - for _ = 1, 3 do - x1, z1 = double(x1, z1) - end - - return fp.mul(x1, fp.invert(z1)) -end +local mont = require "ccryptolib.internal.curve25519" local mod = {} @@ -92,7 +17,7 @@ local mod = {} function mod.publicKey(sk) expect(1, sk, "string") assert(#sk == 32, "secret key length must be 32") - return fp.encode(ladder8(fp.kmul, 9, bits(sk))) + return fp.encode(mont.ladder8(fp.kmul, 9, mont.bits(sk))) end --- Performs the key exchange. @@ -106,7 +31,7 @@ function mod.exchange(sk, pk) assert(#sk == 32, "secret key length must be 32") expect(2, pk, "string") assert(#pk == 32, "public key length must be 32") - return fp.encode(ladder8(fp.mul, fp.decode(pk), bits(sk))) + return fp.encode(mont.ladder8(fp.mul, fp.decode(pk), mont.bits(sk))) end return mod