Newer
Older
invertedlogic / InvertedLogic / iLUtilities / bignumclass.cpp
@John Ryland John Ryland on 10 Nov 2019 12 KB rename
//
// Big Number Class
// Written by John Ryland (C) Copyright 2015
//
#ifndef BIGNUM_CLASS_H
#define BIGNUM_CLASS_H

#include <stdint.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <memory.h>
#include <string>
#include <vector>

/*

Based in part on the work linked below, which is a more C-style version of such code:

Source:  http://dmitry.gr/index.php?r=05.Projects&proj=10.%20Shoving%20RSA%20into%20small%20places

License: you may use this code for any non-commercial purpose as long as you give me credit in your write-up/source. For commercial uses, contact me. One final note: trying this on something like an ATtiny, which lacks hardware multiplier will be very slow (you'll notice) but it will in fact work. Comments, suggestions, grievances, complaints, requests? [dmitrygr@gmail.com].

*/

// Arbitrary-precision unsigned integer.
//
// Storage: m_data holds the value as bytes in little-endian order (m_data[0] is the least
// significant byte). m_data is never empty. It may carry high-order zero bytes (for example
// "0001" keeps two bytes) until Normalize() strips them; ToString(), Print() and Data()
// show every stored byte.
//
// Copying, assignment, moving and destruction are all provided by std::vector (Rule of
// Zero), so a BigNumber is safe to copy and to self-assign.
class BigNumber
{
public:
  // The value zero, stored as a single 0x00 byte.
  BigNumber()
    : m_data(1, uint8_t(0))
  {
  }

  // Create from a hexadecimal string written most significant digit first, e.g. "0102" is 258.
  // An optional "0x"/"0X" prefix is accepted (so ToString() output round-trips), an odd number
  // of digits behaves as if a leading '0' were present, a null or empty string yields zero,
  // and any non-hex character is read as the digit 0. Leading zero bytes are preserved.
  BigNumber(const char* str)
  {
    Init(str);
  }

  // True when the value is zero (every stored byte is 0).
  bool operator!() const
  {
    for (size_t i = 0; i < m_data.size(); ++i)
      if (m_data[i])
        return false;
    return true;
  }

  // Three-way comparison: returns 1 if a > b, -1 if a < b, 0 if they are equal.
  // Neither number has to be normalized; high-order zero bytes are ignored.
  static int Compare(const BigNumber& a, const BigNumber& b)
  {
    size_t an = a.m_data.size();
    size_t bn = b.m_data.size();
    // A non-zero byte above the other number's most significant byte decides immediately.
    for (; an > bn; --an)
      if (a.m_data[an - 1])
        return 1;
    for (; bn > an; --bn)
      if (b.m_data[bn - 1])
        return -1;
    // Equal widths now: the first differing byte, scanning from the top, decides.
    for (size_t i = an; i-- > 0;)
      if (a.m_data[i] != b.m_data[i])
        return (a.m_data[i] > b.m_data[i]) ? 1 : -1;
    return 0;
  }

  bool operator==(const BigNumber& other) const { return Compare(*this, other) == 0; }
  bool operator<(const BigNumber& other) const { return Compare(*this, other) < 0; }
  bool operator>(const BigNumber& other) const { return Compare(*this, other) > 0; }
  bool operator<=(const BigNumber& other) const { return Compare(*this, other) <= 0; }
  bool operator>=(const BigNumber& other) const { return Compare(*this, other) >= 0; }

  // a -= b. Precondition: a >= b (not checked). The byte width of a is unchanged. If the
  // precondition is violated, the result wraps modulo 256^(number of bytes in a).
  // Bytes of b beyond a's width are ignored; they are necessarily zero when a >= b.
  void operator-=(const BigNumber& other)
  {
    int borrow = 0;
    for (size_t i = 0; i < m_data.size(); ++i)
    {
      int diff = static_cast<int>(m_data[i]) - borrow;
      if (i < other.m_data.size())
        diff -= other.m_data[i];
      borrow = (diff < 0) ? 1 : 0;
      m_data[i] = static_cast<uint8_t>(diff);  // conversion to unsigned wraps modulo 256
    }
  }

  // a += b. a grows as needed, so the full sum is kept and no final carry is lost.
  void operator+=(const BigNumber& other)
  {
    if (m_data.size() < other.m_data.size())
      m_data.resize(other.m_data.size(), uint8_t(0));
    unsigned carry = 0;
    for (size_t i = 0; i < m_data.size(); ++i)
    {
      unsigned sum = m_data[i] + carry;
      if (i < other.m_data.size())
        sum += other.m_data[i];
      m_data[i] = static_cast<uint8_t>(sum & 0xFF);
      carry = sum >> 8;
    }
    if (carry)
      m_data.push_back(static_cast<uint8_t>(carry));
  }

  // Logical right shift by `shift` bits (divide by 2^shift, discarding the remainder).
  // The byte width is unchanged; vacated high bytes become zero. shift <= 0 is a no-op.
  void operator>>=(int shift)
  {
    if (shift <= 0)
      return;
    const size_t size = m_data.size();
    const size_t byteShift = static_cast<size_t>(shift) / 8;
    const unsigned bitShift = static_cast<unsigned>(shift) % 8;
    // Each output byte is built from the two source bytes that straddle it. Sources are at
    // indices >= i, so walking upward never reads a byte that was already overwritten.
    for (size_t i = 0; i < size; ++i)
    {
      const size_t src = i + byteShift;
      const unsigned lo = (src < size) ? m_data[src] : 0u;
      const unsigned hi = (src + 1 < size) ? m_data[src + 1] : 0u;
      m_data[i] = static_cast<uint8_t>(((hi << 8) | lo) >> bitShift);
    }
  }

  // Logical left shift by `shift` bits (multiply by 2^shift). Storage grows by one byte
  // per whole 8 bits, plus one more byte only if the remaining bit shift would otherwise
  // push set bits off the top. shift <= 0 is a no-op.
  void operator<<=(int shift)
  {
    if (shift <= 0)
      return;
    const size_t byteShift = static_cast<size_t>(shift) / 8;
    const unsigned bitShift = static_cast<unsigned>(shift) % 8;
    // Whole bytes: insert zero bytes at the low end (index 0 is least significant).
    if (byteShift)
      m_data.insert(m_data.begin(), byteShift, uint8_t(0));
    // Remaining bits: move every byte up, carrying its overflow into the next byte.
    if (bitShift)
    {
      unsigned carry = 0;
      for (size_t i = 0; i < m_data.size(); ++i)
      {
        const unsigned v = (static_cast<unsigned>(m_data[i]) << bitShift) | carry;
        m_data[i] = static_cast<uint8_t>(v & 0xFF);
        carry = v >> 8;
      }
      if (carry)
        m_data.push_back(static_cast<uint8_t>(carry));
    }
  }

  // a %= b: replace a with the remainder of a / b, using binary shift-and-subtract long
  // division. b itself is not modified.
  // - If b (ignoring its high-order zero bytes) is wider than a, then a < b already and a
  //   is left exactly as it is.
  // - Otherwise the result is normalized.
  // - A zero divisor leaves the value unchanged (but normalized).
  void operator%=(const BigNumber& other)
  {
    BigNumber divisor(other);
    divisor.Normalize();
    if (divisor.m_data.size() > m_data.size())
      return;
    // Shift the divisor up until it is wider than *this, then walk it back down one bit at
    // a time, subtracting it wherever it still fits.
    const size_t bits = (m_data.size() - divisor.m_data.size() + 1) * 8;
    divisor <<= static_cast<int>(bits);
    for (size_t i = 0; i <= bits; ++i)
    {
      if (*this >= divisor)
        *this -= divisor;
      divisor >>= 1;
    }
    Normalize();
  }

  // Returns a * b, normalized, by schoolbook long multiplication.
  // The inputs are not modified.
  // https://en.wikipedia.org/wiki/Multiplication_algorithm#Long_multiplication
  static BigNumber Multiply(const BigNumber& a, const BigNumber& b)
  {
    const size_t an = a.m_data.size();
    const size_t bn = b.m_data.size();
    BigNumber product;
    product.m_data.assign(an + bn, uint8_t(0));  // an + bn bytes always hold the product
    for (size_t j = 0; j < an; ++j)
    {
      const uint32_t digit = a.m_data[j];
      if (!digit)
        continue;  // a zero row adds nothing
      uint32_t carry = 0;
      for (size_t i = 0; i < bn; ++i)
      {
        // At most 255*255 + 255 + 255 == 0xFFFF, so the carry always fits in one byte.
        const uint32_t t = digit * b.m_data[i] + product.m_data[i + j] + carry;
        product.m_data[i + j] = static_cast<uint8_t>(t & 0xFF);
        carry = t >> 8;
      }
      // Earlier rows wrote no higher than index j - 1 + bn, so this slot is still zero.
      product.m_data[j + bn] = static_cast<uint8_t>(carry);
    }
    product.Normalize();
    return product;
  }

  // base ^ exponent % modulus, by right-to-left binary exponentiation.
  // - The exponent is deliberately NOT reduced modulo the modulus; doing so gives wrong
  //   answers (e.g. 2^5 mod 3).
  // - A zero modulus performs no reduction, so the result is plain base^exponent.
  static BigNumber ExpMod(const BigNumber& base, const BigNumber& exponent, const BigNumber& modulus)
  {
    BigNumber result("01");
    result %= modulus;          // makes x^0 mod 1 come out as 0
    BigNumber power(base);
    power %= modulus;           // base^(2^k) mod modulus for the current exponent bit k
    BigNumber bits(exponent);   // consumed one bit per iteration, lowest bit first
    while (!!bits)
    {
      if (bits.m_data[0] & 1)
      {
        result = Multiply(result, power);
        result %= modulus;
      }
      bits >>= 1;
      if (!!bits)  // skip the final square that no remaining bit would use
      {
        power = Multiply(power, power);
        power %= modulus;
      }
    }
    return result;
  }

  // Drop high-order zero bytes, keeping at least one byte (zero becomes a single 0x00).
  void Normalize()
  {
    while (m_data.size() > 1 && !m_data.back())
      m_data.pop_back();
  }

  // Print every stored byte to stdout, most significant first, each as "%02X " (no newline).
  void Print() const
  {
    for (size_t i = m_data.size(); i-- > 0;)
      printf("%02X ", m_data[i]);
  }

  // "0x" followed by every stored byte as two upper-case hex digits, most significant first.
  std::string ToString() const
  {
    static const char hexChrs[] = "0123456789ABCDEF";
    std::string str = "0x";
    str.reserve(2 + m_data.size() * 2);
    for (size_t i = m_data.size(); i-- > 0;)
    {
      str += hexChrs[(m_data[i] >> 4) & 0xF];
      str += hexChrs[m_data[i] & 0xF];
    }
    return str;
  }

  // Every stored byte, most significant first (big-endian).
  std::vector<uint8_t> Data() const
  {
    return std::vector<uint8_t>(m_data.rbegin(), m_data.rend());
  }

private:
  // Parse a hex string (format described at the BigNumber(const char*) constructor) into m_data.
  void Init(const char* str)
  {
    if (!str)
      str = "";
    if (str[0] == '0' && (str[1] == 'x' || str[1] == 'X'))
      str += 2;
    const size_t len = strlen(str);
    const size_t nbytes = (len + 1) / 2;  // an odd digit count gets an implicit leading '0'
    m_data.assign(nbytes ? nbytes : 1, uint8_t(0));
    // Walk the digits from the end of the string (least significant) to the start.
    for (size_t d = 0; d < len; ++d)
    {
      const unsigned nibble = HexCharToInt(static_cast<unsigned char>(str[len - 1 - d]));
      m_data[d / 2] |= static_cast<uint8_t>(nibble << ((d & 1) * 4));
    }
  }

  // Value of a hex digit ('0'-'9', 'a'-'f', 'A'-'F'); any other character yields 0.
  static unsigned int HexCharToInt(unsigned char ch)
  {
    if (ch >= '0' && ch <= '9')
      return ch - '0';
    if (ch >= 'a' && ch <= 'f')
      return ch - 'a' + 10;
    if (ch >= 'A' && ch <= 'F')
      return ch - 'A' + 10;
    return 0;
  }

  std::vector<uint8_t> m_data;  // little-endian bytes; never empty
};

#endif // BIGNUM_CLASS_H