Update Poly1305

Update Poly1305 to follow DJB's approach to arithmetic in the prime field
F_p. This improves both performance and correctness. It also fixes the
incorrect output produced when an empty string is passed as the message.
This commit is contained in:
Miguel Oliveira 2022-03-24 10:03:19 -03:00
parent 2668139d96
commit 85fb035641
No known key found for this signature in database
GPG key ID: 2C2BE789E1377025

View file

@ -6,8 +6,6 @@
local expect = require "cc.expect".expect local expect = require "cc.expect".expect
local random = require "ccryptolib.random" local random = require "ccryptolib.random"
local band = bit32.band
local mod = {} local mod = {}
--- Computes a Poly1305 message authentication code. --- Computes a Poly1305 message authentication code.
@ -23,116 +21,114 @@ function mod.mac(key, message)
-- Pad message. -- Pad message.
local pbplen = #message - 15 local pbplen = #message - 15
if #message % 16 ~= 0 then if #message % 16 ~= 0 or #message == 0 then
message = message .. "\1" message = message .. "\1"
message = message .. ("\0"):rep(-#message % 16) message = message .. ("\0"):rep(-#message % 16)
end end
-- Decode r. -- Decode r.
local r0, t1, r2, r3, t4, r5 = ("<I3I3I2I3I3I2"):unpack(key) local R0, R1, R2, R3 = ("<I4I4I4I4"):unpack(key)
-- Clamp and shift. -- Clamp and shift.
t1 = band(t1, 0xfffc0f) * 2 ^ 24 R0 = R0 % 2 ^ 28
r2 = band(r2, 0x000fff) * 2 ^ 48 R1 = (R1 - R1 % 4) % 2 ^ 28 * 2 ^ 32
r3 = band(r3, 0xfffffc) * 2 ^ 64 R2 = (R2 - R2 % 4) % 2 ^ 28 * 2 ^ 64
t4 = band(t4, 0xfffc0f) * 2 ^ 88 R3 = (R3 - R3 % 4) % 2 ^ 28 * 2 ^ 96
r5 = band(r5, 0x000fff) * 2 ^ 112
-- Split some words to fit. -- Split.
local r1 = t1 % 2 ^ 44 r2 = r2 + (t1 - r1) local r0 = R0 % 2 ^ 18 local r1 = R0 - r0
local r4 = t4 % 2 ^ 109 r5 = r5 + (t4 - r4) local r2 = R1 % 2 ^ 50 local r3 = R1 - r2
local r4 = R2 % 2 ^ 82 local r5 = R2 - r4
local r6 = R3 % 2 ^ 112 local r7 = R3 - r6
-- Generate scaled key.
local S1 = 5 / 2 ^ 130 * R1
local S2 = 5 / 2 ^ 130 * R2
local S3 = 5 / 2 ^ 130 * R3
-- Split.
local s2 = S1 % 2 ^ -80 local s3 = S1 - s2
local s4 = S2 % 2 ^ -48 local s5 = S2 - s4
local s6 = S3 % 2 ^ -16 local s7 = S3 - s6
local h0, h1, h2, h3, h4, h5, h6, h7 = 0, 0, 0, 0, 0, 0, 0, 0
-- Digest.
local h0, h1, h2, h3, h4, h5 = 0, 0, 0, 0, 0, 0
for i = 1, #message, 16 do for i = 1, #message, 16 do
-- Decode message block. -- Decode message block.
local m0, m1, m2, m3, m4, m5 = ("<I3I3I3I2I3I2"):unpack(message, i) local m0, m1, m2, m3 = ("<I4I4I4I4"):unpack(message, i)
-- Shift and add to accumulator. -- Shift message and add.
h0 = h0 + m0 local x0 = h0 + h1 + m0
h1 = h1 + m1 * 2 ^ 24 local x2 = h2 + h3 + m1 * 2 ^ 32
h2 = h2 + m2 * 2 ^ 48 local x4 = h4 + h5 + m2 * 2 ^ 64
h3 = h3 + m3 * 2 ^ 72 local x6 = h6 + h7 + m3 * 2 ^ 96
h4 = h4 + m4 * 2 ^ 88
h5 = h5 + m5 * 2 ^ 112
-- Apply per-block padding when applicable. -- Apply per-block padding when applicable.
if i <= pbplen then if i <= pbplen then x6 = x6 + 2 ^ 128 end
h5 = h5 + 2 ^ 128
-- Multiply
h0 = x0 * r0 + x2 * s6 + x4 * s4 + x6 * s2
h1 = x0 * r1 + x2 * s7 + x4 * s5 + x6 * s3
h2 = x0 * r2 + x2 * r0 + x4 * s6 + x6 * s4
h3 = x0 * r3 + x2 * r1 + x4 * s7 + x6 * s5
h4 = x0 * r4 + x2 * r2 + x4 * r0 + x6 * s6
h5 = x0 * r5 + x2 * r3 + x4 * r1 + x6 * s7
h6 = x0 * r6 + x2 * r4 + x4 * r2 + x6 * r0
h7 = x0 * r7 + x2 * r5 + x4 * r3 + x6 * r1
-- Carry.
local y0 = h0 + 3 * 2 ^ 69 - 3 * 2 ^ 69 h0 = h0 - y0 h1 = h1 + y0
local y1 = h1 + 3 * 2 ^ 83 - 3 * 2 ^ 83 h1 = h1 - y1 h2 = h2 + y1
local y2 = h2 + 3 * 2 ^ 101 - 3 * 2 ^ 101 h2 = h2 - y2 h3 = h3 + y2
local y3 = h3 + 3 * 2 ^ 115 - 3 * 2 ^ 115 h3 = h3 - y3 h4 = h4 + y3
local y4 = h4 + 3 * 2 ^ 133 - 3 * 2 ^ 133 h4 = h4 - y4 h5 = h5 + y4
local y5 = h5 + 3 * 2 ^ 147 - 3 * 2 ^ 147 h5 = h5 - y5 h6 = h6 + y5
local y6 = h6 + 3 * 2 ^ 163 - 3 * 2 ^ 163 h6 = h6 - y6 h7 = h7 + y6
local y7 = h7 + 3 * 2 ^ 181 - 3 * 2 ^ 181 h7 = h7 - y7
-- Reduce carry overflow into first limb.
h0 = h0 + 5 / 2 ^ 130 * y7
end end
-- Multiply accumulator by r. -- Carry canonically.
local g00 = h0 * r0 local c0 = h0 % 2 ^ 16 h1 = h0 - c0 + h1
local g01 = h1 * r0 + h0 * r1 local c1 = h1 % 2 ^ 32 h2 = h1 - c1 + h2
local g02 = h2 * r0 + h1 * r1 + h0 * r2 local c2 = h2 % 2 ^ 48 h3 = h2 - c2 + h3
local g03 = h3 * r0 + h2 * r1 + h1 * r2 + h0 * r3 local c3 = h3 % 2 ^ 64 h4 = h3 - c3 + h4
local g04 = h4 * r0 + h3 * r1 + h2 * r2 + h1 * r3 + h0 * r4 local c4 = h4 % 2 ^ 80 h5 = h4 - c4 + h5
local g05 = h5 * r0 + h4 * r1 + h3 * r2 + h2 * r3 + h1 * r4 + h0 * r5 local c5 = h5 % 2 ^ 96 h6 = h5 - c5 + h6
local g06 = h5 * r1 + h4 * r2 + h3 * r3 + h2 * r4 + h1 * r5 local c6 = h6 % 2 ^ 112 h7 = h6 - c6 + h7
local g07 = h5 * r2 + h4 * r3 + h3 * r4 + h2 * r5 local c7 = h7 % 2 ^ 130
local g08 = h5 * r3 + h4 * r4 + h3 * r5
local g09 = h5 * r4 + h4 * r5
local g10 = h5 * r5
-- Carry and reduce. -- Reduce carry overflow.
h5 = g05 % 2 ^ 130 g06 = g06 + (g05 - h5) h0 = c0 + 5 / 2 ^ 130 * (h7 - c7)
c0 = h0 % 2 ^ 16
g00 = g00 + g06 * (5 / 2 ^ 130) c1 = h0 - c0 + c1
g01 = g01 + g07 * (5 / 2 ^ 130)
g02 = g02 + g08 * (5 / 2 ^ 130)
g03 = g03 + g09 * (5 / 2 ^ 130)
g04 = g04 + g10 * (5 / 2 ^ 130)
h0 = g00 % 2 ^ 22 g01 = g01 + (g00 - h0)
h1 = g01 % 2 ^ 44 g02 = g02 + (g01 - h1)
h2 = g02 % 2 ^ 65 g03 = g03 + (g02 - h2)
h3 = g03 % 2 ^ 87 g04 = g04 + (g03 - h3)
h4 = g04 % 2 ^ 109 g05 = h5 + (g04 - h4)
h5 = g05 % 2 ^ 130 h0 = h0 + (g05 - h5) * (5 / 2 ^ 130)
end
-- Canonicalize. -- Canonicalize.
if h5 == (2 ^ 21 - 1) * 2 ^ 109 if c7 == 0xffff * 2 ^ 112
and h4 == (2 ^ 22 - 1) * 2 ^ 87 and c6 == 0xffff * 2 ^ 96
and h3 == (2 ^ 22 - 1) * 2 ^ 65 and c5 == 0xffff * 2 ^ 80
and h2 == (2 ^ 21 - 1) * 2 ^ 44 and c4 == 0xffff * 2 ^ 64
and h1 == (2 ^ 22 - 1) * 2 ^ 22 and c3 == 0xffff * 2 ^ 48
and h0 >= 2 ^ 22 - 5 and c2 == 0xffff * 2 ^ 32
and c1 == 0xffff * 2 ^ 16
and c0 >= 0xfffa
then then
h5 = 0 c7, c6, c5, c4, c3, c2, c1, c0 = 0, 0, 0, 0, 0, 0, 0, c0 - 0xfffa
h4 = 0
h3 = 0
h2 = 0
h1 = 0
h0 = h0 - (2 ^ 22 - 5)
end end
-- Decode s. -- Decode s.
local s0, s1, s2, s3, s4, s5 = ("<I3I3I3I2I3I2"):unpack(key, 17) local s0, s1, s2, s3 = ("<I4I4I4I4"):unpack(key, 17)
-- Add s and carry. -- Add.
h0 = h0 + s0 local t0 = s0 + c0 + c1 local u0 = t0 % 2 ^ 32
h1 = h1 + s1 * 2 ^ 24 local t1 = t0 - u0 + s1 * 2 ^ 32 + c2 + c3 local u1 = t1 % 2 ^ 64
h2 = h2 + s2 * 2 ^ 48 local t2 = t1 - u1 + s2 * 2 ^ 64 + c4 + c5 local u2 = t2 % 2 ^ 96
h3 = h3 + s3 * 2 ^ 72 local t3 = t2 - u2 + s3 * 2 ^ 96 + c6 + c7 local u3 = t3 % 2 ^ 128
h4 = h4 + s4 * 2 ^ 88
h5 = h5 + s5 * 2 ^ 112
local t0 = h0 % 2 ^ 16 h1 = h1 + (h0 - t0)
local t1 = h1 % 2 ^ 40 h2 = h2 + (h1 - t1)
local t2 = h2 % 2 ^ 64 h3 = h3 + (h2 - t2)
local t3 = h3 % 2 ^ 80 h4 = h4 + (h3 - t3)
local t4 = h4 % 2 ^ 104 h5 = h5 + (h4 - t4)
local t5 = h5 % 2 ^ 128
-- Encode. -- Encode.
t1 = t1 * 2 ^ -16 return ("<I4I4I4I4"):pack(u0, u1 / 2 ^ 32, u2 / 2 ^ 64, u3 / 2 ^ 96)
t2 = t2 * 2 ^ -40
t3 = t3 * 2 ^ -64
t4 = t4 * 2 ^ -80
t5 = t5 * 2 ^ -104
return ("<I2I3I3I2I3I3"):pack(t0, t1, t2, t3, t4, t5)
end end
local mac = mod.mac local mac = mod.mac