ccryptolib/chacha20.lua
2022-03-05 13:01:30 -03:00

126 lines
4.7 KiB
Lua

--- The ChaCha20 stream cipher.
--
-- @module chacha20
--
local expect = require "cc.expect".expect
local bxor = bit32.bxor
local rol = bit32.lrotate
local mod = {}
--- Encrypts/Decrypts data using ChaCha20.
--
-- @tparam string key A 32-byte random key.
-- @tparam string nonce A 12-byte per-message unique nonce.
-- @tparam string message A plaintext or ciphertext.
-- @tparam[opt=20] number rounds The number of ChaCha20 rounds to use.
-- @tparam[opt=1] number offset The block offset to generate the keystream at.
-- @treturn string The resulting ciphertext or plaintext.
--
function mod.crypt(key, nonce, message, rounds, offset)
expect(1, key, "string")
assert(#key == 32, "key length must be 32")
expect(2, nonce, "string")
assert(#nonce == 12, "nonce length must be 12")
expect(3, message, "string")
expect(4, rounds, "number", "nil")
rounds = rounds or 20
assert(rounds % 2 == 0, "round number must be even")
assert(rounds >= 8 and rounds <= 20, "round number must be in 8..20")
expect(5, offset, "number", "nil")
offset = offset or 1
assert(offset % 1 == 0 and offset >= 0, "offset must be an integer >= 0")
assert(#message + 64 * offset < 2 ^ 37, "offset too large")
-- Build the state block.
local i0, i1, i2, i3 = 0x61707865, 0x3320646e, 0x79622d32, 0x6b206574
local k0, k1, k2, k3, k4, k5, k6, k7 = ("<I4I4I4I4I4I4I4I4"):unpack(key)
local cr, n0, n1, n2 = offset, ("<I4I4I4"):unpack(nonce)
-- Pad the message.
local padded = message .. ("\0"):rep(-#message % 64)
-- Expand and combine.
local out = {}
local idx = 1
for i = 1, #padded / 64 do
-- Copy the block.
local s00, s01, s02, s03 = i0, i1, i2, i3
local s04, s05, s06, s07 = k0, k1, k2, k3
local s08, s09, s10, s11 = k4, k5, k6, k7
local s12, s13, s14, s15 = cr, n0, n1, n2
-- Iterate.
for _ = 1, rounds, 2 do
s00 = s00 + s04 s12 = rol(bxor(s12, s00), 16)
s08 = s08 + s12 s04 = rol(bxor(s04, s08), 12)
s00 = s00 + s04 s12 = rol(bxor(s12, s00), 8)
s08 = s08 + s12 s04 = rol(bxor(s04, s08), 7)
s01 = s01 + s05 s13 = rol(bxor(s13, s01), 16)
s09 = s09 + s13 s05 = rol(bxor(s05, s09), 12)
s01 = s01 + s05 s13 = rol(bxor(s13, s01), 8)
s09 = s09 + s13 s05 = rol(bxor(s05, s09), 7)
s02 = s02 + s06 s14 = rol(bxor(s14, s02), 16)
s10 = s10 + s14 s06 = rol(bxor(s06, s10), 12)
s02 = s02 + s06 s14 = rol(bxor(s14, s02), 8)
s10 = s10 + s14 s06 = rol(bxor(s06, s10), 7)
s03 = s03 + s07 s15 = rol(bxor(s15, s03), 16)
s11 = s11 + s15 s07 = rol(bxor(s07, s11), 12)
s03 = s03 + s07 s15 = rol(bxor(s15, s03), 8)
s11 = s11 + s15 s07 = rol(bxor(s07, s11), 7)
s00 = s00 + s05 s15 = rol(bxor(s15, s00), 16)
s10 = s10 + s15 s05 = rol(bxor(s05, s10), 12)
s00 = s00 + s05 s15 = rol(bxor(s15, s00), 8)
s10 = s10 + s15 s05 = rol(bxor(s05, s10), 7)
s01 = s01 + s06 s12 = rol(bxor(s12, s01), 16)
s11 = s11 + s12 s06 = rol(bxor(s06, s11), 12)
s01 = s01 + s06 s12 = rol(bxor(s12, s01), 8)
s11 = s11 + s12 s06 = rol(bxor(s06, s11), 7)
s02 = s02 + s07 s13 = rol(bxor(s13, s02), 16)
s08 = s08 + s13 s07 = rol(bxor(s07, s08), 12)
s02 = s02 + s07 s13 = rol(bxor(s13, s02), 8)
s08 = s08 + s13 s07 = rol(bxor(s07, s08), 7)
s03 = s03 + s04 s14 = rol(bxor(s14, s03), 16)
s09 = s09 + s14 s04 = rol(bxor(s04, s09), 12)
s03 = s03 + s04 s14 = rol(bxor(s14, s03), 8)
s09 = s09 + s14 s04 = rol(bxor(s04, s09), 7)
end
-- Decode message block.
local m00, m01, m02, m03, m04, m05, m06, m07
local m08, m09, m10, m11, m12, m13, m14, m15
m00, m01, m02, m03, m04, m05, m06, m07,
m08, m09, m10, m11, m12, m13, m14, m15, idx =
("<I4I4I4I4I4I4I4I4I4I4I4I4I4I4I4I4"):unpack(padded, idx)
-- Feed-forward and combine.
out[i] = ("<I4I4I4I4I4I4I4I4I4I4I4I4I4I4I4I4"):pack(
bxor(m00, s00 + i0), bxor(m01, s01 + i1),
bxor(m02, s02 + i2), bxor(m03, s03 + i3),
bxor(m04, s04 + k0), bxor(m05, s05 + k1),
bxor(m06, s06 + k2), bxor(m07, s07 + k3),
bxor(m08, s08 + k4), bxor(m09, s09 + k5),
bxor(m10, s10 + k6), bxor(m11, s11 + k7),
bxor(m12, s12 + cr), bxor(m13, s13 + n0),
bxor(m14, s14 + n1), bxor(m15, s15 + n2)
)
-- Increment counter.
cr = cr + 1
end
return table.concat(out):sub(1, #message)
end
return mod