#include "Base64.h"


// Standard Base64 alphabet (RFC 4648, Table 1): the character at index i
// encodes the 6-bit value i. Exactly 64 entries: upper-case letters (0-25),
// lower-case letters (26-51), digits (52-61), then '+' (62) and '/' (63).
static const std::string base64_chars =
  "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
  "abcdefghijklmnopqrstuvwxyz"
  "0123456789"
  "+/";


/// @brief Encode raw bytes as standard Base64 text with '=' padding.
///
/// Every 3 input bytes become 4 characters from base64_chars. A trailing
/// group of 1 or 2 bytes becomes 2 or 3 characters followed by "==" or "="
/// respectively, so the output length is always a multiple of 4.
///
/// @param input_data      Bytes to encode (may be empty).
/// @param encoded_string  Receives the encoded text; its previous contents are
///                        replaced. Left untouched if an allocation throws.
/// @return Always true.
bool Base64::Encode(const std::vector<uint8_t>& input_data, std::string& encoded_string)
{
  const uint8_t* bytes = input_data.data();
  // size_t, not unsigned int: avoids truncating the length of inputs >= 4 GiB.
  const size_t in_len = input_data.size();

  std::string ret;
  ret.reserve((in_len + 2) / 3 * 4);  // exact output length, one allocation

  size_t k = 0;
  // Whole 3-byte groups: pack into 24 bits and emit four 6-bit indices.
  for (; k + 3 <= in_len; k += 3) {
    const uint32_t triple = (static_cast<uint32_t>(bytes[k]) << 16)
                          | (static_cast<uint32_t>(bytes[k + 1]) << 8)
                          | static_cast<uint32_t>(bytes[k + 2]);
    ret += base64_chars[(triple >> 18) & 0x3F];
    ret += base64_chars[(triple >> 12) & 0x3F];
    ret += base64_chars[(triple >> 6) & 0x3F];
    ret += base64_chars[triple & 0x3F];
  }

  // Tail of 1 or 2 bytes: missing bytes are treated as zero, and the
  // characters they would have produced are replaced with '=' padding.
  const size_t rem = in_len - k;
  if (rem != 0) {
    uint32_t triple = static_cast<uint32_t>(bytes[k]) << 16;
    if (rem == 2)
      triple |= static_cast<uint32_t>(bytes[k + 1]) << 8;

    ret += base64_chars[(triple >> 18) & 0x3F];
    ret += base64_chars[(triple >> 12) & 0x3F];
    ret += (rem == 2) ? base64_chars[(triple >> 6) & 0x3F] : '=';
    ret += '=';
  }

  // Swap rather than copy; also gives the strong guarantee described above.
  encoded_string.swap(ret);
  return true;
}


/// @brief Decode standard Base64 text into raw bytes.
///
/// Characters are mapped through base64_chars four at a time; each full
/// quartet yields 3 bytes. Decoding stops at the first character that is not
/// in the alphabet (normally the '=' padding). A trailing partial quartet of
/// 2 or 3 characters yields 1 or 2 bytes, whether or not it was padded. A
/// lone trailing character yields nothing.
///
/// NOTE(review): any non-alphabet character, including whitespace or line
/// breaks, ends decoding. Line-wrapped input is therefore truncated at the
/// first break; confirm callers never pass wrapped text.
///
/// @param encoded_string  Base64 text to decode.
/// @param decoded_data    Cleared, then filled with the decoded bytes.
/// @return Always true.
bool Base64::Decode(const std::string& encoded_string, std::vector<uint8_t>& decoded_data)
{
  decoded_data.clear(); // If just want to append to decoded_data, then don't clear here
  // Upper bound: 3 bytes per full quartet plus at most 2 from a partial one.
  decoded_data.reserve(encoded_string.size() / 4 * 3 + 2);

  uint8_t sextets[4] = {0, 0, 0, 0};
  size_t count = 0;  // number of sextets buffered for the current quartet

  // Emit the (n - 1) bytes carried by the first n buffered sextets (n <= 4);
  // n < 2 carries no complete byte and emits nothing.
  auto flush = [&](size_t n) {
    if (n >= 2)
      decoded_data.push_back(static_cast<uint8_t>((sextets[0] << 2) | (sextets[1] >> 4)));
    if (n >= 3)
      decoded_data.push_back(static_cast<uint8_t>(((sextets[1] & 0x0F) << 4) | (sextets[2] >> 2)));
    if (n >= 4)
      decoded_data.push_back(static_cast<uint8_t>(((sextets[2] & 0x03) << 6) | sextets[3]));
  };

  for (const char c : encoded_string) {
    const size_t pos = base64_chars.find(c);
    if (pos == std::string::npos)
      break;  // '=' padding or any other non-alphabet character ends the payload
    sextets[count++] = static_cast<uint8_t>(pos);
    if (count == 4) {
      flush(4);
      count = 0;
    }
  }

  // Previously an unpadded tail (e.g. "QQ") was silently dropped because the
  // buffered sextets were only flushed on hitting a non-alphabet character.
  flush(count);
  return true;
}


