From d3f04036c79fc375183e643559d913c938ad79d7 Mon Sep 17 00:00:00 2001 From: Miguel Oliveira Date: Fri, 4 Mar 2022 12:15:37 -0300 Subject: [PATCH] Split Fq masking internals --- internal/maddq.lua | 47 +++++++++++++++++++++++++++++++++++ x25519c.lua | 61 ++++++++++++---------------------------------- 2 files changed, 62 insertions(+), 46 deletions(-) create mode 100644 internal/maddq.lua diff --git a/internal/maddq.lua b/internal/maddq.lua new file mode 100644 index 0000000..51e7e79 --- /dev/null +++ b/internal/maddq.lua @@ -0,0 +1,47 @@ +local fq = require "ccryptolib.internal.fq" +local random = require "ccryptolib.random" + +local function new(val, order) + local out = {} + local sum = fq.num(0) + for i = 1, order - 1 do + out[i] = fq.decodeWide(random.random(64)) + sum = fq.add(sum, out[i]) + end + + out[order] = fq.add(val, fq.neg(sum)) + + return out +end + +local function encode(arr) + local out = {} + for i = 1, #arr do out[i] = fq.encode(arr[i]) end + return table.concat(out) +end + +local function decode(str) + local out = {} + for i = 1, #str / 32 do out[i] = fq.decode(str:sub(i * 32 - 31, i * 32)) end + return out +end + +local function remask(arr) + local out = new(fq.num(0), #arr) + for i = 1, #arr do out[i] = fq.add(out[i], arr[i]) end + return out +end + +local function reduce(arr, k) + local out = fq.num(0) + for i = 1, #arr do out = fq.add(out, fq.mul(arr[i], k)) end + return out +end + +return { + new = new, + encode = encode, + decode = decode, + remask = remask, + reduce = reduce, +} diff --git a/x25519c.lua b/x25519c.lua index 32ca452..be7add5 100644 --- a/x25519c.lua +++ b/x25519c.lua @@ -2,6 +2,7 @@ local expect = require "cc.expect".expect local fp = require "ccryptolib.internal.fp" local fq = require "ccryptolib.internal.fq" local x25519 = require "ccryptolib.internal.x25519" +local maddq = require "ccryptolib.internal.maddq" local random = require "ccryptolib.random" local ORDER = 4 @@ -21,10 +22,6 @@ local INV8Q = { 4095, } -local function fqRandom() - return fq.decodeWide(random.random(64)) -end - local function ladder8(dx, bits) local x1 = fp.num(1) local z1 = fp.num(0) @@ -55,49 +52,26 @@ end local mod = {} -function mod.secretKeyInit(sk) - sk = fq.decodeClamped(sk) +function mod.new(sk) + expect(1, sk, "string") + assert(#sk == 32, "secret key length must be 32") - -- Set up the mask. - local sks = {} - local sum = fq.num(0) - for i = 1, ORDER - 1 do - sks[i] = fqRandom() - sum = fq.add(sum, sks[i]) - end - sks[ORDER] = fq.add(sk, fq.neg(sum)) - - return sks + return maddq.new(fq.decodeClamped(sk), ORDER) end -function mod.secretKeyEncode(sks) - local out = {} - for i = 1, ORDER do out[i] = fq.encode(sks[i]) end - return table.concat(out) +function mod.encode(sks) + return maddq.encode(sks) end -function mod.secretKeyDecode(str) +function mod.decode(str) expect(1, str, "string") - assert(#str == ORDER * 32, ("secret key length must be %d"):format(ORDER * 32)) + assert(#str == 128, "encoded sks length must be 128") - local out = {} - for i = 1, ORDER do out[i] = fq.decode(str:sub(i * 32 - 31, i * 32)) end - return out + return maddq.decode(str) end -function mod.secretKeyRemask(sks) - local sum = fq.num(0) - local out = {} - - for i = 1, ORDER - 1 do - local element = fqRandom() - out[i] = fq.add(sks[i], element) - sum = fq.add(sum, element) - end - - out[ORDER] = fq.add(sks[ORDER], fq.neg(sum)) - - return out +function mod.remask(sks) + return maddq.remask(sks) end function mod.exchange(sks, pk, mc) @@ -106,18 +80,13 @@ function mod.exchange(sks, pk, mc) expect(3, mc, "string") assert(#mc == 32, "multiplier length must be 32") - -- Get the multiplier in Fq. - mc = fq.decodeClamped(mc) - - -- Multiply secret key members and add them together. - -- This unwraps into the "true" secret key times the multiplier (mod q). - local skmt = fq.num(0) - for i = 1, #sks do skmt = fq.add(skmt, fq.mul(sks[i], mc)) end + -- Reduce secret key using the multiplier. + local skmc = maddq.reduce(sks, fq.decodeClamped(mc)) -- Get bits. -- We have our exponent modulo q. We also know that its value is 0 modulo 8. -- Use the Chinese Remainder Theorem to find its value modulo 8q. - local bits = fq.bits(fq.mul(skmt, INV8Q)) + local bits = fq.bits(fq.mul(skmc, INV8Q)) return fp.encode(ladder8(fp.decode(pk), bits)) end