// SPDX-License-Identifier: Unlicense

#include "base64.h"

#ifdef SHR3D_BASE64

#include <vector>

// Base64 alphabet (RFC 4648 standard alphabet, not the URL-safe variant):
// the character at index i encodes the 6-bit value i.
static constexpr char chars[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
// Fills the final 4-character group when the encoded byte count is not a multiple of 3.
static constexpr char padChar = '=';

std::string Base64::encode(const u8* in, const u64 len)
{
  std::string out;
  out.reserve(((len / 3) + (len % 3 > 0)) * 4);
  u32 temp;
  for (size_t i = 0; i < len / 3; ++i)
  {
    temp = (*in++) << 16;
    temp += (*in++) << 8;
    temp += (*in++);
    out += chars[(temp & 0x00FC0000) >> 18];
    out += chars[(temp & 0x0003F000) >> 12];
    out += chars[(temp & 0x00000FC0) >> 6];
    out += chars[(temp & 0x0000003F)];
  }
  switch (len % 3)
  {
  case 1:
    temp = (*in++) << 16;
    out += chars[(temp & 0x00FC0000) >> 18];
    out += chars[(temp & 0x0003F000) >> 12];
    out += padChar;
    out += padChar;
    break;
  case 2:
    temp = (*in++) << 16;
    temp += (*in++) << 8;
    out += chars[(temp & 0x00FC0000) >> 18];
    out += chars[(temp & 0x0003F000) >> 12];
    out += chars[(temp & 0x00000FC0) >> 6];
    out += padChar;
    break;
  }
  return out;
}

std::vector<u8> Base64::decode(const char* in, const u64 len)
{
  ASSERT(len % 4 == 0); // CRLF line endings? => change to unix LF

  size_t padding = 0;
  if (len)
  {
    if (in[len - 1] == padChar)
      padding++;
    if (in[len - 2] == padChar)
      padding++;
  }

  std::vector<u8> out;
  out.reserve(((len / 4) * 3) - padding);

  u32 temp = 0;
  u64 i = 0;
  while (i < len)
  {
    for (u64 quantumPosition = 0; quantumPosition < 4; ++quantumPosition)
    {
      temp <<= 6;
      if (in[i] >= 0x41 && in[i] <= 0x5A)
        temp |= in[i] - 0x41;
      else if (in[i] >= 0x61 && in[i] <= 0x7A)
        temp |= in[i] - 0x47;
      else if (in[i] >= 0x30 && in[i] <= 0x39)
        temp |= in[i] + 0x04;
      else if (in[i] == 0x2B)
        temp |= 0x3E;
      else if (in[i] == 0x2F)
        temp |= 0x3F;
      else if (in[i] == padChar)
      {
        switch (len - i)
        {
        case 1:
          out.push_back((temp >> 16) & 0x000000FF);
          out.push_back((temp >> 8) & 0x000000FF);
          return out;
        case 2:
          out.push_back((temp >> 10) & 0x000000FF);
          return out;
        default:
          unreachable();
        }
      }
      else
        ASSERT(false);
      i++;
    }
    out.push_back((temp >> 16) & 0x000000FF);
    out.push_back((temp >> 8) & 0x000000FF);
    out.push_back(temp & 0x000000FF);
  }
  return out;
}

#endif // SHR3D_BASE64
