//  BlockyFroggy
//  Copyright © 2017 John Ryland.
//  All rights reserved.
#include "Maths.h"


namespace details
{

  // Element accessor for a 4x4 matrix stored as 16 consecutive floats:
  // returns m[(j mod 4) * 4 + (i mod 4)]. The wrap-around lets invf() step
  // past the edge of the matrix. Both indices must be non-negative, because a
  // negative operand gives a negative remainder and would index out of bounds.
  inline float e(const float *m, int j, int i)
  {
    const int outer = j % 4;   // which group of four floats
    const int inner = i % 4;   // which float inside that group
    return m[outer * 4 + inner];
  }

  // Returns one signed 3x3 cofactor of the 4x4 matrix `m`.
  // inverseMatrix4x4() stores the result for (i, j) at inv[j*4+i]; i and j
  // are expected to be in [0, 3].
  //
  // No explicit 3x3 minor is built. The indices are shifted so that the three
  // rows/columns that remain after removing (i, j) are reached with offsets
  // -1, 0, +1, and e() wraps those offsets modulo 4. The six triple products
  // below are the rule-of-Sarrus expansion of that minor.
  inline float invf(int i, int j, const float *m)
  {
    // o = 2 + (j - i); its parity equals the parity of i + j.
    int o = 2 + (j - i);
    // After these two lines i == 6 + j_in and j == 2 + i_in, so every index
    // passed to e() below (offset by at most -1) stays non-negative.
    i += 4 + o;
    j += 4 - o;
    float inv =
      + e(m, j-1, i+1) * e(m, j+0, i+0) * e(m, j+1, i-1)
      + e(m, j+1, i+1) * e(m, j-1, i+0) * e(m, j+0, i-1)
      + e(m, j-1, i-1) * e(m, j+0, i+1) * e(m, j+1, i+0)
      - e(m, j-1, i-1) * e(m, j+0, i+0) * e(m, j+1, i+1)
      - e(m, j+1, i-1) * e(m, j-1, i+0) * e(m, j+0, i+1)
      - e(m, j-1, i+1) * e(m, j+0, i-1) * e(m, j+1, i+0);
    // Choose the sign from the parity of o (i.e. of i + j). This accounts for
    // both the cofactor sign and the orientation of the cyclically ordered minor.
    return (o % 2) ? inv : -inv;
  }

  bool inverseMatrix4x4(float *restrict out, const float *restrict m)
  {
    float inv[16];
    for (int i = 0; i < 4; i++)
      for (int j = 0; j < 4; j++)
        inv[j*4+i] = invf(i,j,m);
    double determinant = 0;
    for (int k = 0; k < 4; k++)
      determinant += m[k] * inv[k*4];
    if (determinant == 0)
      return false;
    float invdet = float(1.0 / determinant);
    for (int i = 0; i < 16; i++)
      out[i] = inv[i] * invdet;
    return true;
  }

  // Writes the inverse of the upper-left 3x3 block of the 4x4 matrix `m` into
  // the 9 floats at `out`. The storage order matches `m`: treating
  // M[a][b] = m[a*4+b], out[a*3+b] is inverse(M)[a][b]. Elements 3, 7 and 11-15
  // of `m` are never read.
  // Returns false, leaving `out` untouched, if that 3x3 block has a determinant
  // of exactly 0.
  //
  // NOTE(review): a normal matrix is usually transpose(inverse(M3)), but this
  // computes inverse(M3) with no transpose. Presumably the uniform upload or the
  // shader (e.g. `n * normalMatrix`) accounts for that; confirm against callers.
  bool matrix4x4ToNormalMatrix3x3(float *restrict out, const float *restrict m)
  {
    // 3x3 determinant, by cofactor expansion along e(m,0,0..2).
    double determinant = 
             e(m,0,0) * ( e(m,1,1)*e(m,2,2) - e(m,2,1)*e(m,1,2) )
           - e(m,0,1) * ( e(m,1,0)*e(m,2,2) - e(m,1,2)*e(m,2,0) )
           + e(m,0,2) * ( e(m,1,0)*e(m,2,1) - e(m,1,1)*e(m,2,0) );
    if (determinant == 0)
      return false;
    float invdet = float(1.0 / determinant);
    // inverse = adjugate / det: each entry is a signed 2x2 cofactor, with the
    // cofactor matrix transposed (e.g. out[1] uses the cofactor of element (1,0)).
    out[0] =  (e(m,1,1)*e(m,2,2)-e(m,2,1)*e(m,1,2))*invdet;
    out[1] = -(e(m,0,1)*e(m,2,2)-e(m,0,2)*e(m,2,1))*invdet;
    out[2] =  (e(m,0,1)*e(m,1,2)-e(m,0,2)*e(m,1,1))*invdet;
    out[3] = -(e(m,1,0)*e(m,2,2)-e(m,1,2)*e(m,2,0))*invdet;
    out[4] =  (e(m,0,0)*e(m,2,2)-e(m,0,2)*e(m,2,0))*invdet;
    out[5] = -(e(m,0,0)*e(m,1,2)-e(m,1,0)*e(m,0,2))*invdet;
    out[6] =  (e(m,1,0)*e(m,2,1)-e(m,2,0)*e(m,1,1))*invdet;
    out[7] = -(e(m,0,0)*e(m,2,1)-e(m,2,0)*e(m,0,1))*invdet;
    out[8] =  (e(m,0,0)*e(m,1,1)-e(m,1,0)*e(m,0,1))*invdet;
    return true;
  }

  // Copies sixteen values into `out` in argument order:
  // out[0] = m11, out[1] = m12, ..., out[15] = m44.
  // Each group of four arguments fills four consecutive floats. With the
  // column-major layout the projection builders here appear to use, each group
  // is one column.
  void setMatrix4x4(float* out, 
      float m11, float m12, float m13, float m14,
      float m21, float m22, float m23, float m24,
      float m31, float m32, float m33, float m34,
      float m41, float m42, float m43, float m44)
  {
    const float values[16] = {
      m11, m12, m13, m14,
      m21, m22, m23, m24,
      m31, m32, m33, m34,
      m41, m42, m43, m44
    };
    for (int n = 0; n < 16; ++n)
      out[n] = values[n];
  }

}


namespace Math
{

  // Strided dot product: the sum over n in [0, count) of
  // v1[n*stride1] * v2[n*stride2], accumulated in float in index order.
  // Returns 0 when count <= 0.
  float dotProduct(const float *v1, const float *v2, int stride1, int stride2, int count)
  {
    float sum = 0;
    int offset1 = 0;
    int offset2 = 0;
    for (int n = 0; n < count; ++n)
    {
      sum += v1[offset1] * v2[offset2];
      offset1 += stride1;
      offset2 += stride2;
    }
    return sum;
  }

  // 3D cross product: out = v1 x v2. `out` must not overlap either input.
  void crossProduct(float *restrict out, const float *restrict v1, const float *restrict v2)
  {
    const float ax = v1[0], ay = v1[1], az = v1[2];
    const float bx = v2[0], by = v2[1], bz = v2[2];
    out[0] = ay * bz - az * by;
    out[1] = az * bx - ax * bz;
    out[2] = ax * by - ay * bx;
  }

  // vout = m * vin for a 4x4 matrix and a 4-vector. Component `row` of the
  // result is the dot product of m[row], m[row+4], m[row+8], ... (stride 4)
  // with vin. The remaining dotProduct arguments come from its declared
  // defaults. `vout` must not overlap `m` or `vin`.
  void transformVector(float *restrict vout, const float *restrict m, const float *restrict vin)
  {
    for (int row = 0; row < 4; ++row)
    {
      const float *rowStart = m + row;
      vout[row] = dotProduct(rowStart, vin, 4);
    }
  }

  // Normalises the vector at vInOut in place.
  //
  // The length is measured over at most the first three components, but all
  // `count` components are scaled by its reciprocal. With count == 4 this
  // divides a plane equation (a, b, c, d) by |(a, b, c)|, which is what
  // extractClipPlanes needs. With count == 3 it is an ordinary 3-vector
  // normalisation.
  //
  // Only min(count, 3) elements are read to compute the length, so vectors with
  // fewer than three components are no longer read out of bounds. A zero-length
  // input is left unchanged instead of being overwritten with inf/NaN.
  void normalizeVector(float *vInOut, int count)
  {
    const int lengthComponents = (count < 3) ? count : 3;
    float magSq = 0.0f;
    for (int i = 0; i < lengthComponents; i++)
      magSq += vInOut[i] * vInOut[i];
    if (magSq == 0.0f)
      return;
    const float invMag = 1.0f / sqrt(magSq);
    for (int i = 0; i < count; i++)
      vInOut[i] *= invMag;
  }

  // Writes the 4x4 identity matrix into the 16 floats at a_out.
  void makeIdentityMatrix4x4(float* a_out)
  {
    for (int outer = 0; outer < 4; ++outer)
      for (int inner = 0; inner < 4; ++inner)
        a_out[outer * 4 + inner] = (outer == inner) ? 1.0f : 0.0f;
  }

  // Builds a perspective projection matrix in a_out, laid out like the
  // gluPerspective matrix in OpenGL's column-major order.
  //   a_fov    - full vertical field of view, in radians
  //   a_aspect - width / height
  //   a_near, a_far - clip distances (must differ)
  // All 16 elements are written.
  void makePerspectiveMatrix4x4(float* a_out, float a_fov, float a_aspect, float a_near, float a_far)
  {
    const float focal = 1.0f / tanf(a_fov / 2.0f);           // cot(fov / 2)
    const float invNearMinusFar = 1.0f / (a_near - a_far);

    for (int n = 0; n < 16; ++n)
      a_out[n] = 0.0f;
    a_out[0]  = focal / a_aspect;
    a_out[5]  = focal;
    a_out[10] = (a_far + a_near) * invNearMinusFar;
    a_out[11] = -1.0f;                                        // w' = -z
    a_out[14] = (2.0f * a_far * a_near) * invNearMinusFar;
  }

  // Builds an orthographic projection in a_out, laid out like the glOrtho
  // matrix in column-major order: the translation goes in a_out[12..14].
  // left/right, bottom/top and near/far must each differ.
  // All 16 elements are written.
  void makeOrthographicMatrix4x4(float* a_out, float a_left, float a_right, float a_bottom, float a_top, float a_near, float a_far)
  {
    const float invWidth  = 1.0f / (a_right - a_left);
    const float invHeight = 1.0f / (a_top - a_bottom);
    const float invDepth  = 1.0f / (a_far - a_near);

    for (int n = 0; n < 16; ++n)
      a_out[n] = 0.0f;
    a_out[0]  = 2.0f * invWidth;
    a_out[5]  = 2.0f * invHeight;
    a_out[10] = -2.0f * invDepth;
    a_out[12] = -(a_right + a_left) * invWidth;
    a_out[13] = -(a_top + a_bottom) * invHeight;
    a_out[14] = -(a_far + a_near) * invDepth;
    a_out[15] = 1.0f;
  }

  // Converts quaternion q into a 4x4 rotation matrix in `out`. The formula
  // treats q[3] as the scalar part (q = x, y, z, w). The result is divided by
  // |q|^2, so q does not need to be normalised. out[0..2], out[4..6] and
  // out[8..10] hold the rotation, out[15] = 1, and everything else is 0.
  //
  // NOTE(review): for q = (n*sin(t/2), cos(t/2)) this appears to give the
  // transpose of what rotateAxisMatrix4x4 builds for the same axis and angle
  // (compare the sign of the sine term in out[1]). Confirm which convention
  // callers expect.
  void quatToMatrix4x4(float *restrict out, const float *restrict q)
  {
    const float xx = q[0] * q[0], yy = q[1] * q[1], zz = q[2] * q[2], ww = q[3] * q[3];
    const float xy = q[0] * q[1], zw = q[2] * q[3];
    const float xz = q[0] * q[2], yw = q[1] * q[3];
    const float yz = q[1] * q[2], xw = q[0] * q[3];
    // Dividing by the squared length is only needed for non-unit quaternions.
    const float invs = 1.0f / (xx + yy + zz + ww);

    out[0]  = (xx - yy - zz + ww) * invs;
    out[1]  = 2.0f * (xy - zw) * invs;
    out[2]  = 2.0f * (xz + yw) * invs;
    out[3]  = 0.0f;
    out[4]  = 2.0f * (xy + zw) * invs;
    out[5]  = (-xx + yy - zz + ww) * invs;
    out[6]  = 2.0f * (yz - xw) * invs;
    out[7]  = 0.0f;
    out[8]  = 2.0f * (xz - yw) * invs;
    out[9]  = 2.0f * (yz + xw) * invs;
    out[10] = (-xx - yy + zz + ww) * invs;
    out[11] = 0.0f;
    out[12] = 0.0f;
    out[13] = 0.0f;
    out[14] = 0.0f;
    out[15] = 1.0f;
  }

  // Builds out = T(translate) * R(rot) * S(uniformScale). It starts from the
  // identity and post-multiplies each step in turn. `rot` holds Euler angles in
  // degrees (see rotateMatrix4x4).
  void translationRotationScaleToMatrix4x4(float *restrict out, const float *restrict translate, const float *restrict rot, float uniformScale)
  {
    const float scale[3] = { uniformScale, uniformScale, uniformScale };
    makeIdentityMatrix4x4(out);
    translateMatrix4x4(out, translate);
    rotateMatrix4x4(out, rot);
    scaleMatrix4x4(out, scale);
  }

  // Post-multiplies a_inOut by a translation of (vec3[0], vec3[1], vec3[2]).
  // Treating a_inOut as column-major, the last column becomes
  // a_inOut * (x, y, z, 1). This is the same row-times-vector shape as
  // transformVector, but with only three terms.
  void translateMatrix4x4(float *restrict a_inOut, const float *restrict vec3)
  {
    for (int row = 0; row < 4; ++row)
    {
      float offset = 0;
      for (int axis = 0; axis < 3; ++axis)
        offset += a_inOut[axis * 4 + row] * vec3[axis];
      a_inOut[12 + row] += offset;
    }
  }

  void translateMatrix4x4(float* a_inOut, float x, float y, float z)
  {
    float xyz[3] = { x, y, z };
    translateMatrix4x4(a_inOut, xyz);
  }

  // Post-multiplies a_inOut by a non-uniform scale. The first three groups of
  // four floats (columns, in this file's layout) are multiplied by vec3[0],
  // vec3[1] and vec3[2] respectively. a_inOut[12..15] is unchanged.
  void scaleMatrix4x4(float *restrict a_inOut, const float *restrict vec3)
  {
    for (int axis = 0; axis < 3; ++axis)
    {
      const float factor = vec3[axis];
      float *column = a_inOut + axis * 4;
      for (int k = 0; k < 4; ++k)
        column[k] *= factor;
    }
  }

  void scaleMatrix4x4(float* a_inOut, float x, float y, float z)
  {
    float xyz[3] = { x, y, z };
    scaleMatrix4x4(a_inOut, xyz);
  }

  // Applies Euler rotations given in degrees: vec3[0] about X, then vec3[1]
  // about Y, then vec3[2] about Z. Each rotation is post-multiplied onto inOut.
  void rotateMatrix4x4(float *restrict inOut, const float *restrict vec3)
  {
    static const float axes[3][3] = {
      { 1.0f, 0.0f, 0.0f },
      { 0.0f, 1.0f, 0.0f },
      { 0.0f, 0.0f, 1.0f }
    };
    for (int n = 0; n < 3; ++n)
      rotateAxisMatrix4x4(inOut, degreesToRadians(vec3[n]), axes[n]);
  }

  void rotateMatrix4x4(float* inOut, float x, float y, float z)
  {
    float xyz[3] = { x, y, z };
    rotateMatrix4x4(inOut, xyz);
  }

  void rotateAxisMatrix4x4(float *restrict a_inOut, float radians, const float *restrict vec3)
  {
    float v[3] = { vec3[0], vec3[1], vec3[2] };
    normalizeVector(v, 3);
    float cos = cosf(radians);
    float cosp = 1.0f - cos;
    float sin = sinf(radians);
    float rotm[16];
    details::setMatrix4x4(rotm,
       cos + cosp * v[0] * v[0],         cosp * v[0] * v[1] + v[2] * sin,   cosp * v[0] * v[2] - v[1] * sin,    0.0f,
       cosp * v[0] * v[1] - v[2] * sin,  cos + cosp * v[1] * v[1],          cosp * v[1] * v[2] + v[0] * sin,    0.0f,
       cosp * v[0] * v[2] + v[1] * sin,  cosp * v[1] * v[2] - v[0] * sin,   cos + cosp * v[2] * v[2],           0.0f,
       0.0f, 0.0f, 0.0f, 1.0f);
    float m[16];
    for (int i = 0; i < 16; i++)
      m[i] = a_inOut[i];
    multiplyMatrix4x4(a_inOut, m, rotm); // is it possible to do the multiply in-place?
  }

  void rotateAxisMatrix4x4(float* a_inOut, float radians, float x, float y, float z)
  {
    float xyz[3] = { x, y, z };
    rotateAxisMatrix4x4(a_inOut, radians, xyz);
  }

  // Post-multiplies a_inOut by the rotation matrix built from a_quat
  // (see quatToMatrix4x4).
  void quatRotateMatrix4x4(float *restrict a_inOut, const float *restrict a_quat)
  {
    // Snapshot first: multiplyMatrix4x4's output must not alias its inputs.
    float before[16];
    for (int n = 0; n < 16; ++n)
      before[n] = a_inOut[n];
    float rotation[16];
    quatToMatrix4x4(rotation, a_quat);
    multiplyMatrix4x4(a_inOut, before, rotation);
  }

  // Helpers
  // out = m1 * m2. The element for (row, col) is stored at out[col*4 + row]. It
  // is row `row` of m1 (stride 4) dotted with column `col` of m2 (contiguous);
  // the remaining dotProduct arguments come from its declared defaults.
  // `out` must not overlap m1 or m2.
  void multiplyMatrix4x4(float *restrict out, const float *restrict m1, const float *restrict m2)
  {
    for (int col = 0; col < 4; ++col)
    {
      const float *m2Column = m2 + col * 4;
      for (int row = 0; row < 4; ++row)
        out[col * 4 + row] = dotProduct(m1 + row, m2Column, 4);
    }
  }

  // Public entry point for details::matrix4x4ToNormalMatrix3x3: writes 9 floats
  // derived from the upper-left 3x3 of `m`. Returns false, leaving `out`
  // untouched, when that 3x3 block is singular.
  bool matrix4x4ToNormalMatrix3x3(float *restrict out, const float *restrict m)
  {
    const bool inverted = details::matrix4x4ToNormalMatrix3x3(out, m);
    return inverted;
  }

  // Public entry point for details::inverseMatrix4x4: out = inverse(m).
  // Returns false, leaving `out` untouched, when the determinant is exactly 0.
  bool inverseMatrix4x4(float *restrict out, const float *restrict m)
  {
    const bool inverted = details::inverseMatrix4x4(out, m);
    return inverted;
  }

  // Writes the transpose of the 4x4 matrix `m` to `out`:
  // out[r*4 + c] = m[c*4 + r].
  // Reversing the element order (out[i] = m[15 - i]) is a 180-degree rotation
  // of the matrix, not a transpose, so the elements are swapped across the
  // diagonal explicitly. `out` and `m` must not overlap (restrict).
  void transposeMatrix4x4(float *restrict out, const float *restrict m)
  {
    for (int r = 0; r < 4; ++r)
      for (int c = 0; c < 4; ++c)
        out[r * 4 + c] = m[c * 4 + r];
  }
  
  // Converts an angle in degrees to radians (multiplies by pi / 180).
  float degreesToRadians(float degrees)
  {
    return 0.01745329251f * degrees;
  }
  
  // Converts an angle in radians to degrees (multiplies by 180 / pi).
  float radiansToDegrees(float radians)
  {
    return 57.2957795131f * radians;
  }

  // Builds out = projection * modelView.
  // The projection is either a perspective projection (fov in degrees, `aspect`
  // = width / height, near 0.1, far 100) or, when `ortho` is set, a fixed
  // [-2, 2] x [-2, 2] orthographic box with near 0 and far 10 (fov and aspect
  // are then ignored).
  // The model-view is T(translate) * R(rotate, Euler degrees) * S(scale).
  void makeCameraMatrix(float *restrict out, float fov, float aspect, const float restrict translate[3], const float restrict rotate[3], float scale, bool ortho)
  {
    float projection[16];
    if (!ortho)
      makePerspectiveMatrix4x4(projection, degreesToRadians(fov), aspect, 0.1f, 100.0f);
    else
      makeOrthographicMatrix4x4(projection, -2.0f, 2.0f, -2.0f, 2.0f, 0.0f, 10.0f);

    float modelView[16];
    translationRotationScaleToMatrix4x4(modelView, translate, rotate, scale);

    multiplyMatrix4x4(out, projection, modelView);
  }

  //
  // This is a cheap way to extract the clip planes directly from the MVP matrix,
  // without an expensive matrix inversion or other heavier maths.
  //
  // References:
  //      http://gamedevs.org/uploads/fast-extraction-viewing-frustum-planes-from-world-view-projection-matrix.pdf
  //      http://www.lighthouse3d.com/tutorials/view-frustum-culling/clip-space-approach-implementation-details/
  //
  // Extracts the six frustum planes (a, b, c, d) from the combined
  // model-view-projection matrix `m` (Gribb/Hartmann method). Row r of `m` is
  // read as m[r], m[r+4], m[r+8], m[r+12]. Each plane is normalised so that
  // plane . (x, y, z, 1) is a signed distance, positive on the inside.
  void extractClipPlanes(float restrict clipPlanes[6][4], const float *restrict m)
  {
    for (int i = 0; i < 4; i++)
    {
      const float r0 = m[0+i*4];
      const float r1 = m[1+i*4];
      const float r2 = m[2+i*4];
      const float r3 = m[3+i*4];
      clipPlanes[0][i] = r3 + r0; // Left clipping plane
      clipPlanes[1][i] = r3 - r0; // Right clipping plane
      clipPlanes[2][i] = r3 - r1; // Top clipping plane
      clipPlanes[3][i] = r3 + r1; // Bottom clipping plane
      clipPlanes[4][i] = r3 + r2; // Near clipping plane
      clipPlanes[5][i] = r3 - r2; // Far clipping plane
    }
    // isClipped compares plane . point against an object size, so d has to be
    // scaled along with the normal. normalizeVector measures |(a, b, c)| and
    // scales `count` components, so pass 4 explicitly rather than relying on
    // the default declared in the header.
    for (int i = 0; i < 6; i++)
      normalizeVector(clipPlanes[i], 4);
  }

  // Conservative culling test for an object centred at (x, y, z) with extents
  // (w, h, d), against the six normalised planes from extractClipPlanes.
  // Returns true if the object should be culled.
  bool isClipped(float clipPlanes[6][4], float x, float y, float z, float w, float h, float d)
  {
    // Scene-specific early-out ("test specific" in the original): reject
    // anything whose x lies within w of 0 or of the hard-coded 1000 limit.
    if (x <= w || x >= (1000-w))
      return true;

    // Use the largest extent as a bounding radius. The previous ternary
    // returned h whenever w < h, ignoring d even when d was the largest.
    float maxDist = w; // could pre-compute this once
    if (h > maxDist)
      maxDist = h;
    if (d > maxDist)
      maxDist = d;
    maxDist *= 1.3; // Hack to fix things getting clipped too early on the edges of the screen

    // Culled if the centre lies more than maxDist outside any plane.
    const float pnt[4] = { x, y, z, 1.0f };
    for (int i = 0; i < 6; i++)
    {
      float dist = 0;
      for (int k = 0; k < 4; k++)
        dist += pnt[k] * clipPlanes[i][k];
      if (-maxDist > dist)
        return true;
    }
    return false;
  }
  
  // Computes the eight world-space corners of the frustum described by `m`
  // (intended for debug drawing). It maps NDC cube corners through inverse(m)
  // and applies the perspective divide to x, y, z.
  // If `m` is singular, all points are set to zero. Previously the corners
  // were transformed through an uninitialised matrix in that case.
  void calcFrustum(float restrict frustumPts[8][4], const float *restrict m)
  {
    // Kind of expensive with the matrix inverse, so best used only for debugging
    float inv[16];
    if (!inverseMatrix4x4(inv, m))
    {
      for (int i = 0; i < 8; i++)
        for (int j = 0; j < 4; j++)
          frustumPts[i][j] = 0.0f;
      return;
    }
    for (int i = 0; i < 8; i++)
    {
      // Corner i of the NDC cube: bit 0 selects x, bit 1 selects y, bit 2
      // selects the z end. The 0.95 / -0.94 z values and the 1.01 scale are
      // hand-tuned offsets from the original code.
      float w[4] = { (i&1)?1.0f:-1.0f, (i&2)?1.0f:-1.0f, (i&4)?0.95f:-0.94f, 1.0f };
      for (int j = 0; j < 3; j++)
        w[j] *= 1.01f;
      transformVector(frustumPts[i], inv, w);
      for (int j = 0; j < 3; j++)
        frustumPts[i][j] /= frustumPts[i][3];
    }
  }

  // Table-driven version of the unrolled gluInvertMatrix kept in the comment at
  // the end of this file: invOut = inverse(m).
  // Returns false, leaving invOut untouched, when the determinant is exactly 0.
  bool compactInvertMatrix(const float m[16], float invOut[16])
  {
    // For each output element there are six index triples into m. The first
    // three products are added and the last three subtracted. The table is
    // static so it is not rebuilt on the stack on every call.
    static const unsigned char idx[16][18] = {
      {  5, 10, 15, 9, 7, 14, 13, 6, 11, 5, 11, 14, 9, 6, 15, 13, 7, 10 },
      {  1, 11, 14, 9, 2, 15, 13, 3, 10, 1, 10, 15, 9, 3, 14, 13, 2, 11 },
      {  1, 6, 15, 5, 3, 14, 13, 2, 7, 1, 7, 14, 5, 2, 15, 13, 3, 6 },
      {  1, 7, 10, 5, 2, 11, 9, 3, 6, 1, 6, 11, 5, 3, 10, 9, 2, 7 },
      {  4, 11, 14, 8, 6, 15, 12, 7, 10, 4, 10, 15, 8, 7, 14, 12, 6, 11 },
      {  0, 10, 15, 8, 3, 14, 12, 2, 11, 0, 11, 14, 8, 2, 15, 12, 3, 10 },
      {  0, 7, 14, 4, 2, 15, 12, 3, 6, 0, 6, 15, 4, 3, 14, 12, 2, 7 },
      {  0, 6, 11, 4, 3, 10, 8, 2, 7, 0, 7, 10, 4, 2, 11, 8, 3, 6 },
      {  4, 9, 15, 8, 7, 13, 12, 5, 11, 4, 11, 13, 8, 5, 15, 12, 7, 9 },
      {  0, 11, 13, 8, 1, 15, 12, 3, 9, 0, 9, 15, 8, 3, 13, 12, 1, 11 },
      {  0, 5, 15, 4, 3, 13, 12, 1, 7, 0, 7, 13, 4, 1, 15, 12, 3, 5 },
      {  0, 7, 9, 4, 1, 11, 8, 3, 5, 0, 5, 11, 4, 3, 9, 8, 1, 7 },
      {  4, 10, 13, 8, 5, 14, 12, 6, 9, 4, 9, 14, 8, 6, 13, 12, 5, 10 },
      {  0, 9, 14, 8, 2, 13, 12, 1, 10, 0, 10, 13, 8, 1, 14, 12, 2, 9 },
      {  0, 6, 13, 4, 1, 14, 12, 2, 5, 0, 5, 14, 4, 2, 13, 12, 1, 6 },
      {  0, 5, 10, 4, 2, 9, 8, 1, 6, 0, 6, 9, 4, 1, 10, 8, 2, 5 }
    };
    float inv[16];
    for (int i = 0; i < 16; i++)
    {
      inv[i] = 0.0f;
      for (int j = 0; j < 6; j++)
        inv[i] += (j >= 3 ? -1.0f : 1.0f) * m[idx[i][j*3+0]] * m[idx[i][j*3+1]] * m[idx[i][j*3+2]];
    }
    // Laplace expansion along m[0..3] using the matching cofactors.
    float det = m[0] * inv[0] + m[1] * inv[4] + m[2] * inv[8] + m[3] * inv[12];
    if (det == 0)
      return false;
    det = 1.0f / det;
    for (int i = 0; i < 16; i++)
      invOut[i] = inv[i] * det;
    return true;
  }

}


/*

  //
  // If more speed is needed, it may be worth benchmarking (with optimizations enabled) inverseMatrix4x4 against this unrolled version:
  //
  bool gluInvertMatrix(const float m[16], float invOut[16])
  {
    float inv[16], det;
    int i;

    inv[0] = m[5]  * m[10] * m[15] - 
      m[5]  * m[11] * m[14] - 
      m[9]  * m[6]  * m[15] + 
      m[9]  * m[7]  * m[14] +
      m[13] * m[6]  * m[11] - 
      m[13] * m[7]  * m[10];

    inv[4] = -m[4]  * m[10] * m[15] + 
      m[4]  * m[11] * m[14] + 
      m[8]  * m[6]  * m[15] - 
      m[8]  * m[7]  * m[14] - 
      m[12] * m[6]  * m[11] + 
      m[12] * m[7]  * m[10];

    inv[8] = m[4]  * m[9] * m[15] - 
      m[4]  * m[11] * m[13] - 
      m[8]  * m[5] * m[15] + 
      m[8]  * m[7] * m[13] + 
      m[12] * m[5] * m[11] - 
      m[12] * m[7] * m[9];

    inv[12] = -m[4]  * m[9] * m[14] + 
      m[4]  * m[10] * m[13] +
      m[8]  * m[5] * m[14] - 
      m[8]  * m[6] * m[13] - 
      m[12] * m[5] * m[10] + 
      m[12] * m[6] * m[9];

    inv[1] = -m[1]  * m[10] * m[15] + 
      m[1]  * m[11] * m[14] + 
      m[9]  * m[2] * m[15] - 
      m[9]  * m[3] * m[14] - 
      m[13] * m[2] * m[11] + 
      m[13] * m[3] * m[10];

    inv[5] = m[0]  * m[10] * m[15] - 
      m[0]  * m[11] * m[14] - 
      m[8]  * m[2] * m[15] + 
      m[8]  * m[3] * m[14] + 
      m[12] * m[2] * m[11] - 
      m[12] * m[3] * m[10];

    inv[9] = -m[0]  * m[9] * m[15] + 
      m[0]  * m[11] * m[13] + 
      m[8]  * m[1] * m[15] - 
      m[8]  * m[3] * m[13] - 
      m[12] * m[1] * m[11] + 
      m[12] * m[3] * m[9];

    inv[13] = m[0]  * m[9] * m[14] - 
      m[0]  * m[10] * m[13] - 
      m[8]  * m[1] * m[14] + 
      m[8]  * m[2] * m[13] + 
      m[12] * m[1] * m[10] - 
      m[12] * m[2] * m[9];

    inv[2] = m[1]  * m[6] * m[15] - 
      m[1]  * m[7] * m[14] - 
      m[5]  * m[2] * m[15] + 
      m[5]  * m[3] * m[14] + 
      m[13] * m[2] * m[7] - 
      m[13] * m[3] * m[6];

    inv[6] = -m[0]  * m[6] * m[15] + 
      m[0]  * m[7] * m[14] + 
      m[4]  * m[2] * m[15] - 
      m[4]  * m[3] * m[14] - 
      m[12] * m[2] * m[7] + 
      m[12] * m[3] * m[6];

    inv[10] = m[0]  * m[5] * m[15] - 
      m[0]  * m[7] * m[13] - 
      m[4]  * m[1] * m[15] + 
      m[4]  * m[3] * m[13] + 
      m[12] * m[1] * m[7] - 
      m[12] * m[3] * m[5];

    inv[14] = -m[0]  * m[5] * m[14] + 
      m[0]  * m[6] * m[13] + 
      m[4]  * m[1] * m[14] - 
      m[4]  * m[2] * m[13] - 
      m[12] * m[1] * m[6] + 
      m[12] * m[2] * m[5];

    inv[3] = -m[1] * m[6] * m[11] + 
      m[1] * m[7] * m[10] + 
      m[5] * m[2] * m[11] - 
      m[5] * m[3] * m[10] - 
      m[9] * m[2] * m[7] + 
      m[9] * m[3] * m[6];

    inv[7] = m[0] * m[6] * m[11] - 
      m[0] * m[7] * m[10] - 
      m[4] * m[2] * m[11] + 
      m[4] * m[3] * m[10] + 
      m[8] * m[2] * m[7] - 
      m[8] * m[3] * m[6];

    inv[11] = -m[0] * m[5] * m[11] + 
      m[0] * m[7] * m[9] + 
      m[4] * m[1] * m[11] - 
      m[4] * m[3] * m[9] - 
      m[8] * m[1] * m[7] + 
      m[8] * m[3] * m[5];

    inv[15] = m[0] * m[5] * m[10] - 
      m[0] * m[6] * m[9] - 
      m[4] * m[1] * m[10] + 
      m[4] * m[2] * m[9] + 
      m[8] * m[1] * m[6] - 
      m[8] * m[2] * m[5];

    det = m[0] * inv[0] + m[1] * inv[4] + m[2] * inv[8] + m[3] * inv[12];

    if (det == 0)
      return false;

    det = 1.0 / det;

    for (i = 0; i < 16; i++)
      invOut[i] = inv[i] * det;

    return true;
  }
*/


