/*
uint2048_t Class
Written by John Ryland
(C) Copyright 2015
*/
#include "Integer.h"
#include <stdint.h>
#include <stdlib.h>
#include <stdio.h>
#include <memory.h>
#include <vector>
#include <string>
#include <sstream>
#include <iostream>
#include <iomanip>
#include "../../Util.h"
// Three-way outcome of comparing two big integers (first operand relative to the second).
enum ComparisonResult
{
    CR_Equal = 0,       // both operands hold the same value
    CR_LessThan = 1,    // first operand is smaller than the second
    CR_GreaterThan = 2  // first operand is larger than the second
};
// Private implementation of uint2048_t. Holds the raw bytes of the number and the
// static helpers that implement every arithmetic / bitwise operator of the public class.
struct uint2048_t::Pimpl
{
// Parses the hexadecimal digit string `str` into m_data; the last two digits of the
// string end up in m_data[0], i.e. the least significant byte is stored first.
void Init(const char* str);
// Overwrites m_data with the bytes of `other`.
void Copy(const uint2048_t& other);
// Currently a no-op (empty body in this file).
void Normalize();
// Returns the index of the highest non-zero byte plus one (never less than 1).
size_t GetSize();
// Compares a and b byte-wise from the most significant byte downwards.
static ComparisonResult Compare(const uint2048_t& a, const uint2048_t& b);
// Arithmetic helpers; each returns a new value and leaves its operands untouched.
static uint2048_t Add(const uint2048_t& a, const uint2048_t& b);
static uint2048_t Subtract(const uint2048_t& a, const uint2048_t& b);
static uint2048_t Multiply(const uint2048_t& a, const uint2048_t& b);
static uint2048_t Divide(const uint2048_t& a, const uint2048_t& b);
static uint2048_t Modulus(const uint2048_t& a, const uint2048_t& b);
// Bit shifts by `shift` bits.
static uint2048_t ShiftLeft(const uint2048_t& a, int shift);
static uint2048_t ShiftRight(const uint2048_t& a, int shift);
// Byte-wise bitwise operations over the whole buffer.
static uint2048_t Or(const uint2048_t& a, const uint2048_t& b);
static uint2048_t And(const uint2048_t& a, const uint2048_t& b);
static uint2048_t Xor(const uint2048_t& a, const uint2048_t& b);
static uint2048_t Not(const uint2048_t& a);
// Helper function to optimize the ExpMod function: multiplies the raw byte arrays a and b
// into the double-width buffer dest and then reduces dest modulo `mod` in place.
static void MulMod(uint8_t* a, uint8_t* b, uint8_t* mod, size_t an, size_t bn, size_t mn, uint8_t* dest);
// Number of bytes in m_data. This is a macro, so it stays visible for the rest of the file.
// NOTE(review): 257 is one byte more than the 256 bytes that 2048 bits require - presumably
// headroom for carries and shifted intermediates; confirm before changing.
#define s_size 257 // Fixed at 2048 bits (256*8)
// The value itself, least significant byte at index 0.
uint8_t m_data[s_size];
};
// Default constructor: allocates the private storage and initialises the value to zero.
uint2048_t::uint2048_t()
    : m_pimpl(std::make_unique<Pimpl>())
{
    Pimpl& impl = *m_pimpl;
    impl.Init("0");
}
// Copy constructor: allocates fresh private storage and duplicates the bytes of `other`.
uint2048_t::uint2048_t(const uint2048_t& other)
    : m_pimpl(std::make_unique<Pimpl>())
{
    Pimpl& impl = *m_pimpl;
    impl.Copy(other);
}
// Constructs the value from a 64-bit unsigned integer.
// The number is rendered as hexadecimal text and then parsed by Pimpl::Init,
// so both construction paths share the same parser.
uint2048_t::uint2048_t(uint64_t a_number)
    : m_pimpl(std::make_unique<Pimpl>())
{
    std::ostringstream hexStream;
    hexStream << std::hex << a_number;
    const std::string hexDigits = hexStream.str();
    m_pimpl->Init(hexDigits.c_str());
}
// Constructs the value from a string of hexadecimal digits (most significant digit first).
uint2048_t::uint2048_t(const char* str)
    : m_pimpl(std::make_unique<Pimpl>())
{
    Pimpl& impl = *m_pimpl;
    impl.Init(str);
}
// Destructor: nothing to do by hand, m_pimpl releases the private storage.
// Defined in this file (not the header) because Pimpl is only complete here.
uint2048_t::~uint2048_t() = default;
// Assignment operator: replaces this value with a copy of `other`.
// Self-assignment is a no-op; without the guard Pimpl::Copy would hand memcpy
// identical source and destination buffers, which memcpy does not permit.
void uint2048_t::operator=(const uint2048_t& other)
{
    if (this == &other)
        return;
    m_pimpl->Copy(other);
}
// Logical not: returns true when this value is zero.
bool uint2048_t::operator!() const
{
    const uint2048_t zero;
    const ComparisonResult order = Pimpl::Compare(*this, zero);
    return order == CR_Equal;
}
// Inequality: true when the two values differ.
bool uint2048_t::operator!=(const uint2048_t& other) const
{
    const ComparisonResult order = Pimpl::Compare(*this, other);
    return order != CR_Equal;
}
// Equality: true when the two values are identical.
bool uint2048_t::operator==(const uint2048_t& other) const
{
    const ComparisonResult order = Pimpl::Compare(*this, other);
    return order == CR_Equal;
}
// Less-than: true when this value is strictly smaller than `other`.
bool uint2048_t::operator<(const uint2048_t& other) const
{
    const ComparisonResult order = Pimpl::Compare(*this, other);
    return order == CR_LessThan;
}
// Greater-than: true when this value is strictly larger than `other`.
bool uint2048_t::operator>(const uint2048_t& other) const
{
    const ComparisonResult order = Pimpl::Compare(*this, other);
    return order == CR_GreaterThan;
}
// Less-than-or-equal: true when this value is not larger than `other`.
bool uint2048_t::operator<=(const uint2048_t& other) const
{
    switch (Pimpl::Compare(*this, other))
    {
    case CR_Equal:
    case CR_LessThan:
        return true;
    default:
        return false;
    }
}
// Greater-than-or-equal: true when this value is not smaller than `other`.
bool uint2048_t::operator>=(const uint2048_t& other) const
{
    switch (Pimpl::Compare(*this, other))
    {
    case CR_Equal:
    case CR_GreaterThan:
        return true;
    default:
        return false;
    }
}
// Logical and: true when both this value and `other` are non-zero.
// `other` is only inspected when this value is non-zero.
bool uint2048_t::operator&&(const uint2048_t& other) const
{
    if (operator!())
        return false;   // this value is zero
    return !other.operator!();
}
// Logical or: true when this value or `other` is non-zero.
// `other` is only inspected when this value is zero.
bool uint2048_t::operator||(const uint2048_t& other) const
{
    if (!operator!())
        return true;    // this value is non-zero
    return !other.operator!();
}
// Prefix decrement: subtracts one in place and returns this object.
uint2048_t& uint2048_t::operator--()
{
    uint2048_t& self = *this;
    self = Pimpl::Subtract(self, uint2048_t(1));
    return self;
}
// Prefix increment: adds one in place and returns this object.
uint2048_t& uint2048_t::operator++()
{
    uint2048_t& self = *this;
    self = Pimpl::Add(self, uint2048_t(1));
    return self;
}
// Postfix increment: adds one in place and returns the value held before the increment.
uint2048_t uint2048_t::operator++(int)
{
    uint2048_t previous(*this);
    ++(*this);
    return previous;
}
// Postfix decrement: subtracts one in place and returns the value held before the decrement.
uint2048_t uint2048_t::operator--(int)
{
    uint2048_t previous(*this);
    --(*this);
    return previous;
}
// Compound subtraction: this = this - other.
uint2048_t& uint2048_t::operator-=(const uint2048_t& other)
{
    uint2048_t& self = *this;
    self = Pimpl::Subtract(self, other);
    return self;
}
// Compound addition: this = this + other.
uint2048_t& uint2048_t::operator+=(const uint2048_t& other)
{
    uint2048_t& self = *this;
    self = Pimpl::Add(self, other);
    return self;
}
// Compound multiplication: this = this * other.
uint2048_t& uint2048_t::operator*=(const uint2048_t& other)
{
    uint2048_t& self = *this;
    self = Pimpl::Multiply(self, other);
    return self;
}
// Binary minus: returns this - other without modifying either operand.
uint2048_t uint2048_t::operator-(const uint2048_t& other) const
{
    const uint2048_t& lhs = *this;
    return Pimpl::Subtract(lhs, other);
}
// Binary plus: returns this + other without modifying either operand.
uint2048_t uint2048_t::operator+(const uint2048_t& other) const
{
    const uint2048_t& lhs = *this;
    return Pimpl::Add(lhs, other);
}
// Multiplication: returns this * other without modifying either operand.
uint2048_t uint2048_t::operator*(const uint2048_t& other) const
{
    const uint2048_t& lhs = *this;
    return Pimpl::Multiply(lhs, other);
}
// Division: returns the quotient this / other without modifying either operand.
uint2048_t uint2048_t::operator/(const uint2048_t& other) const
{
    const uint2048_t& lhs = *this;
    return Pimpl::Divide(lhs, other);
}
// Remainder: returns this % other without modifying either operand.
uint2048_t uint2048_t::operator%(const uint2048_t& other) const
{
    const uint2048_t& lhs = *this;
    return Pimpl::Modulus(lhs, other);
}
// Bitwise and: returns this & other without modifying either operand.
uint2048_t uint2048_t::operator&(const uint2048_t& other) const
{
    const uint2048_t& lhs = *this;
    return Pimpl::And(lhs, other);
}
// Bitwise or: returns this | other without modifying either operand.
uint2048_t uint2048_t::operator|(const uint2048_t& other) const
{
    const uint2048_t& lhs = *this;
    return Pimpl::Or(lhs, other);
}
// Bitwise exclusive-or: returns this ^ other without modifying either operand.
uint2048_t uint2048_t::operator^(const uint2048_t& other) const
{
    const uint2048_t& lhs = *this;
    return Pimpl::Xor(lhs, other);
}
// Bitwise complement: returns a value with every bit of this one inverted.
uint2048_t uint2048_t::operator~() const
{
    const uint2048_t& value = *this;
    return Pimpl::Not(value);
}
// Right shift: returns this value shifted right by `shift` bits.
uint2048_t uint2048_t::operator>>(int shift) const
{
    const uint2048_t& value = *this;
    return Pimpl::ShiftRight(value, shift);
}
// Left shift: returns this value shifted left by `shift` bits.
uint2048_t uint2048_t::operator<<(int shift) const
{
    const uint2048_t& value = *this;
    return Pimpl::ShiftLeft(value, shift);
}
// Compound right shift: shifts this value right by `shift` bits in place.
uint2048_t& uint2048_t::operator>>=(int shift)
{
    uint2048_t& self = *this;
    self = Pimpl::ShiftRight(self, shift);
    return self;
}
// Compound left shift: shifts this value left by `shift` bits in place.
uint2048_t& uint2048_t::operator<<=(int shift)
{
    uint2048_t& self = *this;
    self = Pimpl::ShiftLeft(self, shift);
    return self;
}
// Compound division: this = this / other.
uint2048_t& uint2048_t::operator/=(const uint2048_t& other)
{
    uint2048_t& self = *this;
    self = Pimpl::Divide(self, other);
    return self;
}
// Compound remainder: this = this % other.
uint2048_t& uint2048_t::operator%=(const uint2048_t& other)
{
    uint2048_t& self = *this;
    self = Pimpl::Modulus(self, other);
    return self;
}
// Compound bitwise and: this = this & other.
uint2048_t& uint2048_t::operator&=(const uint2048_t& other)
{
    uint2048_t& self = *this;
    self = Pimpl::And(self, other);
    return self;
}
// Compound bitwise or: this = this | other.
uint2048_t& uint2048_t::operator|=(const uint2048_t& other)
{
    uint2048_t& self = *this;
    self = Pimpl::Or(self, other);
    return self;
}
// Compound bitwise exclusive-or: this = this ^ other.
uint2048_t& uint2048_t::operator^=(const uint2048_t& other)
{
    uint2048_t& self = *this;
    self = Pimpl::Xor(self, other);
    return self;
}
// Conversion to a 64-bit unsigned integer.
// The value is rendered as "0x..." hexadecimal text and parsed back with strtoull (base 16).
uint2048_t::operator uint64_t() const
{
    const std::string hexText = std::string(*this);
    const unsigned long long parsed = strtoull(hexText.c_str(), nullptr, 16);
    return (uint64_t)parsed;
}
// Writes the hexadecimal text form of the value ("0x...") to standard output, without a newline.
void uint2048_t::Print() const
{
    const std::string text = std::string(*this);
    std::cout << text;
}
// Computes (base ^ exponent) % modulus, where ^ denotes exponentiation (not xor)
// Modular exponentiation: returns (base raised to the power exponent) modulo modulus.
//
// Uses square-and-multiply over the bits of the exponent, lowest bit first:
// whenever the current exponent bit is set the accumulator is multiplied by the
// running square of the base, and the running square is squared on every step.
// All multiplications are reduced immediately by Pimpl::MulMod on raw byte arrays.
//
//   base     - value being raised to a power (reduced modulo `modulus` first)
//   exponent - power to raise to; used as-is. It must NOT be reduced modulo
//              `modulus`: in general base^e mod m != base^(e mod m) mod m.
//   modulus  - modulus of the result
// NOTE(review): the working buffers are s_size*2 bytes wide; this appears to assume
// the modulus fits in 2048 bits (256 bytes) - confirm with callers.
uint2048_t uint2048_t::ExpMod(const uint2048_t& base, const uint2048_t& exponent, const uint2048_t& modulus)
{
    // Start from 1 reduced by the modulus so that a modulus of 1 yields 0 even for exponent 0.
    uint2048_t dst = uint2048_t(1) % modulus;
    uint2048_t a = base % modulus;
    uint2048_t b = exponent;
    size_t mn = modulus.m_pimpl->GetSize();
    const uint8_t *modBits = modulus.m_pimpl->m_data;
    uint8_t *aBits = a.m_pimpl->m_data;
    size_t an = a.m_pimpl->GetSize();

    // Pre-compute the modulus shifted left by 0 to 7 bits. Each shifted copy occupies
    // s_size*2 bytes so MulMod can index past s_size when it applies a virtual byte shift.
    uint8_t tmpModBits[s_size*16];
    for (int shift = 0; shift < 8; shift++)
    {
        uint32_t overflow = 0;
        uint8_t* mb = &tmpModBits[s_size*2*shift];
        memset(mb, 0, s_size*2);
        for (int i = 0; i < s_size; i++)
        {
            uint32_t var = modBits[i] << shift;
            mb[i] = var | overflow;
            overflow = var >> 8;
        }
    }

    uint8_t tmp1Bits[s_size*2];     // double-width scratch buffer for each product
    uint8_t *dstBits = dst.m_pimpl->m_data;
    while (!!b)
    {
        if (b[0] & 1)
        {
            // dst = (dst * a) % modulus
            memset(tmp1Bits, 0, s_size*2);  // MulMod accumulates into its destination
            size_t dn = dst.m_pimpl->GetSize();
            Pimpl::MulMod(dstBits, aBits, tmpModBits, dn, an, mn, tmp1Bits);
            memcpy(dstBits, tmp1Bits, s_size);
        }
        // a = (a * a) % modulus
        memset(tmp1Bits, 0, s_size*2);
        Pimpl::MulMod(aBits, aBits, tmpModBits, an, an, mn, tmp1Bits);
        memcpy(aBits, tmp1Bits, s_size);
        an = a.m_pimpl->GetSize();
        b >>= 1;    // move on to the next exponent bit
    }
    return dst;
}
// Conversion to text: "0x" followed by two lowercase hexadecimal digits per byte,
// most significant byte first. Only the bytes reported by GetSize() are printed,
// so leading zero bytes are omitted (zero prints as "0x00").
uint2048_t::operator std::string() const
{
    const uint8_t* bytes = m_pimpl->m_data;
    std::ostringstream text;
    text << "0x" << std::hex << std::setfill('0');
    for (size_t pos = m_pimpl->GetSize(); pos-- > 0; )
        text << std::setw(2) << (int)bytes[pos];
    return text.str();
}
// Returns the number of elements in the underlying byte buffer (its full capacity,
// not the number of significant bytes).
size_t uint2048_t::Size() const
{
    const size_t totalBytes = sizeof(m_pimpl->m_data);
    const size_t bytesPerElement = sizeof(m_pimpl->m_data[0]);
    return totalBytes / bytesPerElement;
}
// Returns the byte at `index`, where index 0 is the least significant byte.
// No bounds check is performed.
uint8_t uint2048_t::operator[](size_t index) const
{
    const uint8_t* bytes = m_pimpl->m_data;
    return bytes[index];
}
////////////////// PRIVATE //////////////////////////////////////////////////
// Converts one hexadecimal digit character ('0'-'9', 'a'-'f', 'A'-'F') to its value 0..15.
// Any other character yields 0, so stray characters (for example the 'x' of a "0x" prefix)
// can never produce a value above 15 and spill into the neighbouring nibble.
static unsigned int HexCharToInt(unsigned char ch)
{
    if (ch >= '0' && ch <= '9')
        return ch - '0';
    const unsigned int lower = tolower(ch);
    if (lower >= 'a' && lower <= 'f')
        return lower - 'a' + 10;
    return 0;
}
// Initialises m_data from a string of hexadecimal digits.
//
// The string is read from its last character towards its first; every pair of digits
// forms one byte, and the bytes are stored starting at m_data[0] (least significant first).
// A null pointer or empty string produces zero. If the string holds more digits than the
// buffer can store, the excess most-significant digits are dropped instead of being
// written past the end of m_data.
void uint2048_t::Pimpl::Init(const char* str)
{
    memset(m_data, 0, s_size);
    if (str == nullptr)
        return;
    size_t remaining = strlen(str);     // digits not consumed yet, counted from the front
    for (size_t idx = 0; idx < s_size && remaining > 0; idx++)
    {
        // low nibble comes from the rightmost unread digit
        unsigned int val = HexCharToInt(str[--remaining]);
        // high nibble from the digit before it, when there is one
        if (remaining > 0)
            val |= HexCharToInt(str[--remaining]) << 4;
        m_data[idx] = (uint8_t)val;
    }
}
// Overwrites this buffer with the complete byte buffer of `other`.
void uint2048_t::Pimpl::Copy(const uint2048_t& other)
{
    const uint8_t* source = other.m_pimpl->m_data;
    memcpy(m_data, source, sizeof m_data);
}
// Intentionally does nothing: the buffer has a fixed size, so there is no state to adjust.
void uint2048_t::Pimpl::Normalize()
{
    // no-op
}
// Returns the number of significant bytes: the index of the highest non-zero byte plus one.
// The result is never less than 1, so zero reports a size of 1.
size_t uint2048_t::Pimpl::GetSize()
{
    size_t top = s_size - 1;
    // walk down from the most significant byte, skipping zero bytes (index 0 is never skipped)
    while (top != 0 && m_data[top] == 0)
        --top;
    return top + 1;
}
// Three-way comparison of a and b.
// Bytes are compared from the most significant end; the first differing byte decides.
// The scan stops at that byte, so the running time depends on the operands
// (i.e. this is not a constant-time comparison).
ComparisonResult uint2048_t::Pimpl::Compare(const uint2048_t& a, const uint2048_t& b)
{
    for (size_t pos = s_size; pos-- > 0; )
    {
        const uint8_t lhs = a[pos];
        const uint8_t rhs = b[pos];
        if (lhs > rhs)
            return CR_GreaterThan;  // a > b
        if (lhs < rhs)
            return CR_LessThan;     // a < b
    }
    return CR_Equal;                // every byte matched
}
// Returns the byte-wise bitwise OR of a and b.
uint2048_t uint2048_t::Pimpl::Or(const uint2048_t& a, const uint2048_t& b)
{
    uint2048_t combined(a);
    uint8_t* out = combined.m_pimpl->m_data;
    for (size_t pos = 0; pos < s_size; ++pos)
        out[pos] = (uint8_t)(out[pos] | b[pos]);
    return combined;
}
// Returns the byte-wise bitwise AND of a and b.
uint2048_t uint2048_t::Pimpl::And(const uint2048_t& a, const uint2048_t& b)
{
    uint2048_t combined(a);
    uint8_t* out = combined.m_pimpl->m_data;
    for (size_t pos = 0; pos < s_size; ++pos)
        out[pos] = (uint8_t)(out[pos] & b[pos]);
    return combined;
}
// Returns the byte-wise bitwise exclusive-OR of a and b.
uint2048_t uint2048_t::Pimpl::Xor(const uint2048_t& a, const uint2048_t& b)
{
    uint2048_t combined(a);
    uint8_t* out = combined.m_pimpl->m_data;
    for (size_t pos = 0; pos < s_size; ++pos)
        out[pos] = (uint8_t)(out[pos] ^ b[pos]);
    return combined;
}
// Returns the bitwise complement of a. Every byte of the buffer is inverted,
// including the bytes above the most significant non-zero one.
uint2048_t uint2048_t::Pimpl::Not(const uint2048_t& a)
{
    uint2048_t inverted;
    uint8_t* out = inverted.m_pimpl->m_data;
    for (size_t pos = 0; pos < s_size; ++pos)
        out[pos] = (uint8_t)(~a[pos]);
    return inverted;
}
// Returns a + b using byte-wise addition with carry propagation from the least
// significant byte upwards. A carry out of the last byte is discarded.
uint2048_t uint2048_t::Pimpl::Add(const uint2048_t& a, const uint2048_t& b)
{
    uint2048_t sum;
    uint8_t* out = sum.m_pimpl->m_data;
    uint32_t carry = 0;
    for (size_t pos = 0; pos < s_size; ++pos)
    {
        // column sum including the carry from the previous column
        const uint32_t column = (uint32_t)a[pos] + (uint32_t)b[pos] + carry;
        out[pos] = (uint8_t)column;
        carry = column >> 8;
    }
    return sum;
}
// Returns a - b using byte-wise subtraction with borrow propagation from the least
// significant byte upwards. A borrow out of the last byte is discarded, so the
// result wraps around when b is larger than a.
uint2048_t uint2048_t::Pimpl::Subtract(const uint2048_t& a, const uint2048_t& b)
{
    uint2048_t difference;
    uint8_t* out = difference.m_pimpl->m_data;
    uint32_t borrow = 0;
    for (size_t pos = 0; pos < s_size; ++pos)
    {
        // unsigned wrap-around sets bit 8 exactly when this column needs a borrow
        const uint32_t column = (uint32_t)a[pos] - borrow - (uint32_t)b[pos];
        out[pos] = (uint8_t)column;
        borrow = (column >> 8) & 1;
    }
    return difference;
}
// big number multiply algorithms
// https://en.wikipedia.org/wiki/Multiplication_algorithm#Long_multiplication
// Returns a * b using schoolbook long multiplication on base-256 digits.
//
// For every byte of `a` (one "row"), the byte is multiplied with each significant byte
// of `b` and the partial products are accumulated into a double-width scratch buffer.
// Two carries are tracked per row: one out of the byte product and one out of the
// accumulation. Only the low s_size bytes of the product are returned; anything
// above is discarded.
uint2048_t uint2048_t::Pimpl::Multiply(const uint2048_t& a, const uint2048_t& b)
{
    const size_t lhsLen = a.m_pimpl->GetSize();
    const size_t rhsLen = b.m_pimpl->GetSize();

    uint8_t wide[s_size*2] = { 0 };     // double-width accumulator, starts at zero
    for (size_t row = 0; row < lhsLen; ++row)
    {
        const uint32_t digit = a[row];
        uint32_t mulCarry = 0;          // carry out of digit * b[col]
        uint32_t addCarry = 0;          // carry out of adding into the accumulator
        for (size_t col = 0; col < rhsLen; ++col)
        {
            const uint32_t partial = digit * (uint32_t)b[col] + mulCarry;
            const uint32_t column = (uint32_t)wide[row + col] + addCarry + (partial & 0xFF);
            mulCarry = partial >> 8;
            addCarry = column >> 8;
            wide[row + col] = (uint8_t)column;
        }
        // whatever is left of both carries lands in the byte just above this row
        wide[row + rhsLen] = (uint8_t)(wide[row + rhsLen] + mulCarry + addCarry);
    }

    uint2048_t product;
    memcpy(product.m_pimpl->m_data, wide, s_size);
    return product;
}
// Returns the quotient a / b using binary shift-and-subtract long division.
//
// The divisor is shifted left far enough to sit above the dividend, then moved back
// one bit at a time; whenever it still fits into the running remainder it is
// subtracted and the matching quotient bit is set.
// No check is made for a zero divisor.
uint2048_t uint2048_t::Pimpl::Divide(const uint2048_t& a, const uint2048_t& b)
{
    const size_t topA = a.m_pimpl->GetSize() - 1;   // index of a's highest significant byte
    const size_t topB = b.m_pimpl->GetSize() - 1;   // index of b's highest significant byte
    if (topB > topA)
        return uint2048_t();                        // b has more bytes than a: quotient is 0

    const size_t shiftBits = (topA - topB + 1) << 3;
    uint2048_t quotientBit = uint2048_t(1) << shiftBits;
    uint2048_t shifted = b << shiftBits;
    uint2048_t remainder = a;
    uint2048_t quotient = uint2048_t();

    // shiftBits + 1 rounds: one for every shift amount from shiftBits down to 0
    for (size_t step = 0; step <= shiftBits; ++step)
    {
        if (remainder >= shifted)
        {
            remainder -= shifted;
            quotient += quotientBit;
        }
        shifted >>= 1;
        quotientBit >>= 1;
    }
    return quotient;
}
// Returns the remainder a % b using binary shift-and-subtract reduction.
//
// When b is larger than a the remainder is a itself. Otherwise b is shifted left far
// enough to sit above a and moved back one bit at a time, being subtracted from the
// running remainder whenever it still fits.
// No check is made for a zero divisor.
uint2048_t uint2048_t::Pimpl::Modulus(const uint2048_t& a, const uint2048_t& b)
{
    const size_t topA = a.m_pimpl->GetSize() - 1;   // index of a's highest significant byte
    const size_t topB = b.m_pimpl->GetSize() - 1;   // index of b's highest significant byte

    // b > a (decided by length first, by value when the lengths match): nothing to reduce
    if (topB > topA)
        return a;
    if (topB == topA && b > a)
        return a;

    const size_t shiftBits = (topA - topB + 1) << 3;
    uint2048_t shifted = b;
    shifted <<= shiftBits;
    uint2048_t remainder = a;

    // shiftBits + 1 rounds: one for every shift amount from shiftBits down to 0
    for (size_t step = 0; step <= shiftBits; ++step)
    {
        if (remainder >= shifted)
            remainder -= shifted;
        shifted >>= 1;
    }
    return remainder;
}
// Computes (a * b) % modulus on raw little-endian byte arrays, leaving the result in `dest`.
//
//   a, an - first factor and its number of significant bytes
//   b, bn - second factor and its number of significant bytes
//   mod   - eight consecutive copies of the modulus, each s_size*2 bytes wide; copy k
//           (at offset s_size*2*k) holds the modulus shifted left by k bits
//   mn    - number of significant bytes of the unshifted modulus
//   dest  - double-width (s_size*2 bytes) output buffer. It must be zeroed by the caller,
//           because the product is accumulated into it before being reduced in place.
//
// The reduction is a shift-and-subtract loop. The shifted modulus ("tmp") is never
// materialised: its bit shift is selected by picking one of the eight pre-shifted copies
// and its byte shift `byt` is applied through index arithmetic.
void uint2048_t::Pimpl::MulMod(uint8_t* a, uint8_t* b, uint8_t* mod, size_t an, size_t bn, size_t mn, uint8_t* dest)
{
    // Step 1: schoolbook multiplication of a and b into dest (which is double width)
    for (size_t j = 0; j < an; j++)
    {
        uint32_t c1 = 0, c2 = 0;    // carry of the byte product / carry of the accumulation
        for (size_t i = 0; i < bn; i++)
        {
            uint32_t t1 = ((uint32_t)a[j] * (uint32_t)b[i]) + (uint32_t)c1;
            uint32_t t2 = (uint32_t)dest[i+j] + c2 + (t1 & 0xFF);
            c1 = t1 >> 8;
            c2 = t2 >> 8;
            dest[i+j] = t2;
        }
        dest[bn+j] += c1 + c2;
    }

    // 'an' is now the size of dest
    an = an + bn;

    // If the modulus has more bytes than the product there is nothing to reduce
    if (mn > an)
        return;

    // Step 2: dest %= modulus
    const size_t byteShift = an - mn + 1;
    size_t bitsLarger = byteShift << 3;         // current total shift of tmp, in bits
    // siz is the byte width of the numbers we need to deal with
    const size_t siz = s_size + byteShift;
    const uint8_t* tmpPtr = mod;                // pre-shifted copy for (bitsLarger & 7)
    size_t byt = byteShift;                     // whole-byte part of the shift
    for (;;)
    {
        // Decide whether dest >= tmp by scanning from the high-order bytes down to the
        // lowest byte that tmp occupies (index byt). The scan counts down with
        // "idx > byt" before decrementing, so idx can never wrap below zero - not even
        // when byt is 0 and dest equals tmp in every byte.
        bool highBytesEqual = true;
        size_t idx = siz;
        while (idx > byt)
        {
            --idx;
            if (dest[idx] != tmpPtr[idx - byt])
            {
                highBytesEqual = false;
                break;
            }
        }
        // If all bytes down to byt match, dest >= tmp holds regardless of the remaining
        // low bytes of dest, because the low bytes of tmp are the zeros shifted in.
        // Otherwise the first differing byte decides.
        if (highBytesEqual || dest[idx] > tmpPtr[idx - byt])
        {
            // dest -= tmp. Bytes below byt are untouched (tmp is zero there).
            uint32_t borrow = 0;
            for (size_t i = byt; i < siz; i++)
            {
                // unsigned wrap-around sets bit 8 exactly when this byte needs a borrow
                const uint32_t x = (uint32_t)dest[i] - (borrow + (uint32_t)tmpPtr[i - byt]);
                dest[i] = (uint8_t)x;
                borrow = (x >> 8) & 1;
            }
        }

        if (bitsLarger == 0)
            break;

        // tmp >>= 1, expressed by selecting the matching pre-shifted copy and byte offset
        --bitsLarger;
        tmpPtr = mod + s_size*2*(bitsLarger & 7);
        byt = bitsLarger >> 3;
    }
}
// Returns a shifted left by `shift` bits.
//
// Because the least significant byte is stored first, a left shift moves bytes towards
// higher indices; within each byte the bits are still shifted left and the bits pushed
// out of a byte are carried into the next one. Bits shifted beyond the end of the buffer
// are lost. A negative shift, or a shift of at least the whole buffer width, yields zero
// rather than indexing outside the buffer.
uint2048_t uint2048_t::Pimpl::ShiftLeft(const uint2048_t& a, int shift)
{
    uint2048_t res;     // default-constructed, i.e. zero
    if (shift < 0)
        return res;
    const size_t byteShift = (size_t)shift >> 3;
    if (byteShift >= s_size)
        return res;
    const unsigned int bitShift = (unsigned int)shift & 7;

    uint8_t* out = res.m_pimpl->m_data;
    uint32_t overflow = 0;  // bits carried over from the previous (lower) byte
    for (size_t i = byteShift; i < s_size; i++)
    {
        const uint32_t var = (uint32_t)a[i - byteShift] << bitShift;
        out[i] = (uint8_t)(var | overflow);
        overflow = var >> 8;
    }
    return res;
}
// Returns a shifted right by `shift` bits.
//
// Because the least significant byte is stored first, a right shift moves bytes towards
// lower indices; within each byte the bits are still shifted right and the low bits of
// the next higher byte are brought in at the top. Vacated high-order bytes stay zero.
// A negative shift, or a shift of at least the whole buffer width, yields zero rather
// than indexing outside the buffer.
uint2048_t uint2048_t::Pimpl::ShiftRight(const uint2048_t& a, int shift)
{
    uint2048_t res;     // default-constructed, i.e. zero
    if (shift < 0)
        return res;
    const size_t byteShift = (size_t)shift >> 3;
    if (byteShift >= s_size)
        return res;
    const unsigned int bitShift = (unsigned int)shift & 7;

    uint8_t* out = res.m_pimpl->m_data;
    for (size_t i = byteShift; i < s_size - 1; i++)
    {
        const uint8_t low = (uint8_t)(a[i] >> bitShift);
        // bits falling out of the next higher byte; truncates to 0 when bitShift is 0
        const uint8_t high = (uint8_t)(a[i+1] << (8 - bitShift));
        out[i - byteShift] = (uint8_t)(low | high);
    }
    // the top byte has no higher neighbour to borrow bits from
    out[s_size - 1 - byteShift] = (uint8_t)(a[s_size-1] >> bitShift);
    return res;
}