//
//  BumpMapping.cpp
//  BumpMapping
//
//  Created by John Ryland on 2/10/17.
//  Copyright © 2017 John Ryland. All rights reserved.
//

#include "../Framework/Framework.h"
#include "../Platform/Paths.h"


// Uniform set for the bump-mapping shader program: three transform matrices
// followed by seven sampler2D slots.
//
// BumpMapping::update() fills the matrices every frame and BumpMapping::prepare()
// assigns the samplers the values 0..6 in the order they are declared here.
// NOTE(review): the sampler names do not all match the image bound to the same
// index in prepare() (e.g. index 0 is "-diffuse.png" but the slot is named
// texture__ao__Id) - confirm whether the names or the file list are stale.
// NOTE(review): the shader sources below use these names without declaring
// them, so presumably the framework generates the GLSL uniform declarations
// from this block - confirm in Framework.h.
DECLARE_PROGRAM_UNIFORMS(BumpProgram)
  // Transforms `pos` for the lighting calculations in the vertex shader.
  DECLARE_UNIFORM(mat4, modelViewMatrix)
  // Transforms `normal` and `tangent` in the vertex shader.
  DECLARE_UNIFORM(mat3, normalMatrix)
  // Transforms `pos` into gl_Position.
  DECLARE_UNIFORM(mat4, modelViewProjectionMatrix)
  // Sampled in the fragment shader, but the value is currently unused in the output.
  DECLARE_UNIFORM(sampler2D, texture__ao__Id)
  // Sampled in the fragment shader, but the value is currently unused in the output.
  DECLARE_UNIFORM(sampler2D, texture_col1_Id)
  // Provides the surface colour written to gl_FragColor.
  DECLARE_UNIFORM(sampler2D, texture_col2_Id)
  // Height map driving the parallax offset and the self-shadowing march.
  DECLARE_UNIFORM(sampler2D, texture_disp_Id)
  // Sampled in the fragment shader, but the value is currently unused in the output.
  DECLARE_UNIFORM(sampler2D, texture_glos_Id)
  // Tangent-space normal map (decoded with 2*n - 1 in the fragment shader).
  DECLARE_UNIFORM(sampler2D, texture_norm_Id)
  // Sampled in the fragment shader, but the value is currently unused in the output.
  DECLARE_UNIFORM(sampler2D, texture_refl_Id)
DECLARE_PROGRAM_UNIFORMS_END


// Per-vertex attributes of the bump-mapped mesh: position, normal, tangent and
// texture coordinate (3 + 3 + 3 + 2 = 11 components). The bi-normal is not
// stored; the vertex shader rebuilds it as cross(normal, tangent).
// NOTE(review): the third macro argument (GL_FALSE) is presumably the
// "normalized" flag of the attribute pointer - confirm in Framework.h.
//
// Packing idea: pos 3, tangent-space quat 4, uv 2  ->  possible to fit in 9 float values - is 8 possible?
DECLARE_VERTEX(BumpVertex)
  DECLARE_ATTRIB(vec3, pos, GL_FALSE)
  DECLARE_ATTRIB(vec3, normal, GL_FALSE)
  DECLARE_ATTRIB(vec3, tangent, GL_FALSE)
  //DECLARE_ATTRIB(vec3, binormal, GL_FALSE)
  DECLARE_ATTRIB(vec2, uv, GL_FALSE)
DECLARE_VERTEX_END


// Demo context that renders a spinning, parallax/bump-mapped cube.
//
//   prepare()  - compiles the shaders, builds the cube vertex data and uploads
//                the textures.
//   update()   - recomputes the model-view, model-view-projection and normal
//                matrices from the elapsed time.
//   draw()     - issues the draw through the program context.
//   onResize() - rebuilds the projection matrix for the new viewport shape.
class BumpMapping : public DemoContext
{
public:
  ~BumpMapping() override;
  void prepare() override;
  void update(float a_seconds) override;
  void draw() override;
  void onResize(const vec2f& a_shape) override;

private:
  // Shader program, uniform values, vertex array and textures for the cube.
  ProgramContext2<BumpProgram, BumpVertex> m_demoModelContext;

  // 4x4 projection matrix consumed by update(). Starts out as the identity so
  // that update() never reads indeterminate values should it run before the
  // first onResize() call.
  float m_projectionMatrix[16] =
  {
    1.0f, 0.0f, 0.0f, 0.0f,
    0.0f, 1.0f, 0.0f, 0.0f,
    0.0f, 0.0f, 1.0f, 0.0f,
    0.0f, 0.0f, 0.0f, 1.0f
  };
};


// GLSL `varying` declarations shared by both shader stages. The macro expands
// to a raw string literal that is placed directly in front of each shader's own
// raw string, so adjacent-literal concatenation prepends it to both sources and
// the vertex outputs / fragment inputs cannot drift apart.
//
//   texCoord               - texture coordinate passed through from the `uv` attribute
//   lightVec1 / lightVec2  - direction to each light, expressed in tangent space
//   camVec                 - direction to the camera, expressed in tangent space
//   lightDistSqr1 / 2      - squared distance from the vertex to each light
#define SHADER_VARS R"(
varying vec2 texCoord;
varying vec3 lightVec1;
varying vec3 lightVec2;
//varying vec3 halfVec;
//varying vec3 eyeVec;
varying vec3 camVec;
//varying vec3 col;
varying float lightDistSqr1;
varying float lightDistSqr2;
)"

// Vertex shader source (SHADER_VARS is prepended by literal concatenation).
//
//   calculateTagentSpace()         - transforms `normal` and `tangent` by
//                                    normalMatrix, derives the bi-normal with a
//                                    cross product and returns mat3(t, b, n).
//   calculateTagentSpaceLighting() - computes the vertex position with
//                                    modelViewMatrix, then projects the
//                                    direction to the camera (at the origin)
//                                    and to two hard-coded light positions onto
//                                    t/b/n, writing camVec, lightVec1 and
//                                    lightVec2. Also writes the squared
//                                    distance to each light.
//   main()                         - passes `uv` through as texCoord, writes
//                                    gl_Position and runs the two helpers.
//
// The light positions are constants compared against modelViewMatrix * pos, so
// they stay fixed relative to the camera while the model rotates.
// lightDistSqr1/2 are written here, but the fragment shader only references
// them in commented-out code.
// NOTE(review): the attributes (pos, normal, tangent, uv) and the uniforms are
// used without being declared in this string - presumably the framework injects
// those declarations from DECLARE_VERTEX / DECLARE_PROGRAM_UNIFORMS; confirm.
const char* s_bumpedVertexShader =
SHADER_VARS
R"(
mat3 calculateTagentSpace()
{
    // TBN -> 'tangent space' -> basically it is 3 ortho normal vectors which define the basis vectors or 3 axis -> eg: a rotation matrix
    // This 3x3 rotation matrix (or 3 vec3 vectors) can also be represented more simply using a quaternion apparently -> TODO: figure out as quots
    // Also apparently the tangent is best calculated by an exporter and is part of the model data
    vec3 n = normalize(normalMatrix * normal).xyz;                // normal
    vec3 t = normalize(normalMatrix * tangent).xyz;               // tangent
    vec3 b = cross(n, t);                                         // bi-normal
    //col = binormal;
    return mat3(t,b,n);
}

void calculateTagentSpaceLighting(mat3 tbn)
{
    vec3 t = tbn[0];
    vec3 b = tbn[1];
    vec3 n = tbn[2];
    vec3 vertexPos = vec3(modelViewMatrix * vec4(pos, 1.0));     // vertex position
    vec3 cameraPos = vec3(0.0, 0.0,    0.0);                     // camera position
    vec3 lightPos1 = vec3(4.0, 8.0,   -7.0);                     // light position
    vec3 lightPos2 = vec3(-8.0, -4.0,  5.0);                     // light position
    
    vec3 camDir     = normalize(cameraPos - vertexPos);
    //vec3 halfVector = normalize(vertexPos + lightDir);
    
    // transform light and half angle vectors by tangent basis
    //eyeVec   =  normalize( vec3( dot(vertexPos,  t), dot(vertexPos,  b), dot(vertexPos,  n) ) );
    //halfVec  =  normalize( vec3( dot(halfVector, t), dot(halfVector, b), dot(halfVector, n) ) );
    camVec   =             vec3( dot(camDir,     t), dot(camDir,     b), dot(camDir,     n) );
    
    vec3 lightDir1  = normalize(lightPos1 - vertexPos);           // light direction
    vec3 lightDir2  = normalize(lightPos2 - vertexPos);           // light direction
    lightVec1 =            vec3( dot(lightDir1,  t), dot(lightDir1,  b), dot(lightDir1,  n) );
    lightVec2 =            vec3( dot(lightDir2,  t), dot(lightDir2,  b), dot(lightDir2,  n) );
    lightDistSqr1 = length(lightPos1 - vertexPos);
    lightDistSqr1 = lightDistSqr1 * lightDistSqr1;
    lightDistSqr2 = length(lightPos2 - vertexPos);
    lightDistSqr2 = lightDistSqr2 * lightDistSqr2;
    
}

  void main()
  {
    texCoord = uv;
    gl_Position = modelViewProjectionMatrix * vec4(pos, 1.0);
    mat3 tbn = calculateTagentSpace();
    calculateTagentSpaceLighting(tbn);
  }
)";




// Fragment shader source (SHADER_VARS is prepended by literal concatenation).
//
//   parallaxMapping()              - marches the view vector through the height
//                                    map (texture_disp_Id) in 2..30 layers,
//                                    with more layers at grazing angles, until
//                                    the sampled height is no longer above the
//                                    current layer. It then linearly
//                                    interpolates between the last two samples
//                                    and returns the shifted texture
//                                    coordinate; the interpolated height is
//                                    returned through `parallaxHeight`.
//   parallaxSoftShadowMultiplier() - for surfaces facing the light (L.z > 0),
//                                    marches from the given height towards the
//                                    light in 15..30 layers and keeps the
//                                    strongest occlusion found. Returns 1.0
//                                    when nothing occludes, otherwise
//                                    1.0 - occlusion.
//   main()                         - offsets the texture coordinate with
//                                    parallaxMapping(), decodes the normal map
//                                    (2*n - 1), sums N.L of both lights into a
//                                    clamped diffuse term and outputs
//                                    col2 * (0.2 ambient + diffuse * shadow^4).
//
// Things worth knowing before editing:
//   * The `#if 1` only wraps the two parallax helpers; main() calls them
//     unconditionally, so switching it to 0 will not compile as-is.
//   * _ao_, col1, glos and refl are sampled but do not contribute to
//     gl_FragColor; specularFactor is always 0.0.
//   * Only light 1 is shadow-tested, yet the resulting multiplier scales the
//     combined diffuse term of both lights.
//   * NOTE(review): the shadow march starts at the unshifted `texCoord` while
//     the height passed in belongs to the parallax-shifted coordinate `T` -
//     this looks like it should start at `T`; confirm visually before changing.
//   * NOTE(review): `weight` divides by (nextH - prevH), which has no guard for
//     the two heights being equal - confirm whether that can occur in practice.
const char* s_bumpedFragmentShader =
SHADER_VARS
R"(
/*
  vec3 Gsub(vec3 v) // Sub Function of G
  {
    float k = ((roughness + 1) * (roughness + 1)) / 8;
    float fdotv = dot(fNormal, v);
    return vec3((fdotv) / ((fdotv) * (1.0 - k) + k));
  }
  vec3 G(vec3 l, vec3 v, vec3 h) // Geometric Attenuation Term - Schlick Modified (k = a/2)
  {
    return Gsub(l) * Gsub(v);
  }
*/

#if 1
  const float parallaxScale = 0.03; // ~0.1

  vec2 parallaxMapping(in vec3 V, in vec2 T, out float parallaxHeight)
  {
    
    // determine number of layers from angle between V and N
    const float minLayers =  2.0;
    const float maxLayers = 30.0;
    float numLayers = mix(maxLayers, minLayers, abs(dot(vec3(0.0, 0.0, 1.0), V)));
    
    // height of each layer
    float layerHeight = 1.0 / numLayers;
    // depth of current layer
    float currentLayerHeight = 0.0;
    // shift of texture coordinates for each iteration
    vec2 dtex = parallaxScale * V.xy /* / V.z */ / numLayers;
    
    // current texture coordinates
    vec2 currentTextureCoords = T;
    
    // get first depth from heightmap
    float heightFromTexture = texture2D(texture_disp_Id, currentTextureCoords).x;// texture(u_heightTexture, currentTextureCoords).r;
    
    // while point is above surface
    while (heightFromTexture > currentLayerHeight)
    {
      // to the next layer
      currentLayerHeight += layerHeight;
      // shift texture coordinates along vector V
      currentTextureCoords -= dtex;
      // get new depth from heightmap
      heightFromTexture = texture2D(texture_disp_Id, currentTextureCoords).r;
    }

    ///////////////////////////////////////////////////////////
    // previous texture coordinates
    vec2 prevTCoords = currentTextureCoords + dtex;
    // heights for linear interpolation
    float nextH  = heightFromTexture - currentLayerHeight;
    float prevH  = texture2D(texture_disp_Id, prevTCoords).r
                           - currentLayerHeight + layerHeight;

    // proportions for linear interpolation
    float weight = nextH / (nextH - prevH);
 
    // interpolation of texture coordinates
    vec2 finalTexCoords = prevTCoords * weight + currentTextureCoords * (1.0-weight);
    // interpolation of depth values
    parallaxHeight = currentLayerHeight + prevH * weight + nextH * (1.0 - weight);
    return finalTexCoords;
    // return results
    //parallaxHeight = currentLayerHeight;
    //return currentTextureCoords;
  }

float parallaxSoftShadowMultiplier(in vec3 L, in vec2 initialTexCoord, in float initialHeight)
{
   float shadowMultiplier = 1.0;

   const float minLayers = 15.0;
   const float maxLayers = 30.0;

   // calculate lighting only for surface oriented to the light source
   if(dot(vec3(0.0, 0.0, 1.0), L) > 0.0)
   {
      // calculate initial parameters
      float numSamplesUnderSurface  = 0.0;
      shadowMultiplier  = 0.0;
      float numLayers  = mix(maxLayers, minLayers, abs(dot(vec3(0.0, 0.0, 1.0), L)));
      float layerHeight  = initialHeight / numLayers;
      vec2 texStep  = parallaxScale * L.xy / L.z / numLayers;

      // current parameters
      float currentLayerHeight  = initialHeight - layerHeight;
      vec2 currentTextureCoords  = initialTexCoord + texStep;
      float heightFromTexture  = texture2D(texture_disp_Id, currentTextureCoords).r;
      float stepIndex  = 1.0;

      // while point is below depth 0.0 )
      while(currentLayerHeight > 0.0)
      {
         // if point is under the surface
         if(heightFromTexture < currentLayerHeight)
         {
            // calculate partial shadowing factor
            numSamplesUnderSurface  += 1.0;
            float newShadowMultiplier  = (currentLayerHeight - heightFromTexture) *
                                             (1.0 - stepIndex / numLayers);
            shadowMultiplier  = max(shadowMultiplier, newShadowMultiplier);
         }

         // offset to the next layer
         stepIndex  += 1.0;
         currentLayerHeight  -= layerHeight;
         currentTextureCoords  += texStep;
         heightFromTexture  = texture2D(texture_disp_Id, currentTextureCoords).r;
      }

      // Shadowing factor should be 1 if there were no points under the surface
      if(numSamplesUnderSurface < 1.0)
      {
         shadowMultiplier = 1.0;
      }
      else
      {
         shadowMultiplier = 1.0 - shadowMultiplier;
      }
   }
   return shadowMultiplier;
}
#endif

  void main()
  {
    vec3 V = normalize(camVec);
    vec3 L1 = normalize(lightVec1);
    vec3 L2 = normalize(lightVec2);
    float parallaxHeight;
    vec2 T = texCoord;
    float shadowMultiplier = 1.0;

    //T -= 0.1 * V.xy * texture2D(texture_disp_Id, T).x;
    T = parallaxMapping(V, T, parallaxHeight);
    shadowMultiplier = parallaxSoftShadowMultiplier(L1, texCoord, parallaxHeight - 0.01);

    
    float _ao_ = texture2D(texture__ao__Id, T).x;
    vec4  col1 = texture2D(texture_col1_Id, T);  // ambient
    vec4  col2 = texture2D(texture_col2_Id, T);  // albedo
    float glos = texture2D(texture_glos_Id, T).x;
    vec3  norm = texture2D(texture_norm_Id, T).xyz;
    float refl = texture2D(texture_refl_Id, T).x;
 
    //norm = (2.0 * normalize(norm)) - vec3(1.0);
    norm = normalize((2.0 * norm) - vec3(1.0));
      
    float NdotL1 = dot(norm, L1);// * 50.0 / lightDistSqr1;
    float NdotL2 = dot(norm, L2);// * 50.0 / lightDistSqr2;

    float ambientFactor = 0.2;
    float diffuseFactor = clamp(NdotL1 + NdotL2, 0.0, 1.0);
    float specularFactor = 0.0;

/*
    if (NdotL1 > 0.2)
    {
      //specularFactor = glos * pow(1.5 * dot(reflect(-L, norm), V), refl * 50.0);
      //  specularFactor = glos * clamp(pow(1.8 * dot(normalize(halfVec), norm), refl * 30.0), 0.0, 1.0);
      //col += vec4(1.0 - glos) * col2 * diffuseFactor * 100.0 / lightDistSqr;
    }
*/

    gl_FragColor = vec4(col2.xyz * /* _ao_ * */ (ambientFactor + (diffuseFactor + specularFactor) * pow(shadowMultiplier, 4.0)), 1.0);
  }
)";


// Shuts down the program context owned by this demo when the demo is destroyed.
BumpMapping::~BumpMapping()
{
  this->m_demoModelContext.shutdown();
}


void BumpMapping::prepare()
{
  // TODO: perhaps the texture list can be parameter to the setup
  const int textureCount = 7;
  const char* files[textureCount] =
  {
    "-diffuse.png",
    "-diffuse_var1.png",
    "-diffuse_var2.png",
    "-displacement.png",
    "-gloss.png",
    "-normals.png",
    "-reflection.png"
    /*
    "__ao_.png",
    "_col1.png",
    "_col2.png",
    "_disp.png",
    "_glos.png",
    "_norm.png",
    "_refl.png"
    */
  };
  
  OptionsType options;
  m_demoModelContext.setup(s_bumpedVertexShader, s_bumpedFragmentShader, Medium, options, textureCount);

  std::vector<vec3f> vertexes;
  CreateBumpCube(vertexes);
  m_demoModelContext.m_vertexArray.m_vertexData.clear();
  m_demoModelContext.m_vertexArray.m_vertexData.reserve(vertexes.size() / 2);
  for (int i = 0; i < vertexes.size(); i += 4)
  {
    vec3f v = vertexes[i+0];
    vec3f n = vertexes[i+1];
    vec3f t = vertexes[i+2];
    //vec3f b = vertexes[i+3];
    vec3f uv = vertexes[i+3];

    //vec3attrib v3{ 1.0, 2.0, 3.0, 4.0 };
    //BumpVertex vert{ v3 };
    
    m_demoModelContext.m_vertexArray.m_vertexData.push_back(BumpVertex{
     {v.x,v.y,v.z},  {n.x,n.y,n.z},  {t.x,t.y,t.z}, /* b.x,b.y,b.z, */  {uv.x*0.5f,uv.y*0.5f}   });
  }

  m_demoModelContext.update();
  m_demoModelContext.setFlag(clearColor | clearDepth | enableCullFace | enableDepthTest);
  m_demoModelContext.setBackgroundColor(0.0f, 0.0f, 0.0f);

  std::string basename = "brick01";
  basename = "48";//2287";//48";//2287";//"48";//;
  for (int i = 0; i < textureCount; i++)
  {
    std::vector<uint8_t> rawImageData;
    std::vector<uint8_t> decodedImageData;
    uint32_t imageW, imageH;
    std::string fileName = basename + files[i];
    loadFile(fileName.c_str(), rawImageData);
    decodePNG(decodedImageData, imageW, imageH, rawImageData.data(), rawImageData.size());
    m_demoModelContext.setTextureData(i, imageW, imageH, decodedImageData.data());
  }
  //m_demoModelContext.m_uniforms.m_textureId = 0;
  m_demoModelContext.m_uniforms.m_texture__ao__Id = 0;
  m_demoModelContext.m_uniforms.m_texture_col1_Id = 1;
  m_demoModelContext.m_uniforms.m_texture_col2_Id = 2;
  m_demoModelContext.m_uniforms.m_texture_disp_Id = 3;
  m_demoModelContext.m_uniforms.m_texture_glos_Id = 4;
  m_demoModelContext.m_uniforms.m_texture_norm_Id = 5;
  m_demoModelContext.m_uniforms.m_texture_refl_Id = 6;
}


// Rebuilds the projection matrix for a new viewport shape.
//
// a_shape: viewport extent; x / y is used as the aspect ratio.
//
// A degenerate shape (zero, negative or NaN in either component, e.g. a
// minimised window) would otherwise yield a division by zero and a non-finite
// projection matrix, so an aspect ratio of 1 is substituted in that case.
// NOTE(review): the remaining arguments are presumably field of view (45
// degrees, converted to radians) and the near/far planes (0.1 / 100) - confirm
// against Math::makePerspectiveMatrix4x4.
void BumpMapping::onResize(const vec2f& a_shape)
{
  const bool shapeIsUsable = (a_shape.x > 0.0f) && (a_shape.y > 0.0f);
  const float aspectRatio = shapeIsUsable ? (a_shape.x / a_shape.y) : 1.0f;
  Math::makePerspectiveMatrix4x4(m_projectionMatrix, Math::degreesToRadians(45.0f), aspectRatio, 0.1f, 100.0f);
}


// Recomputes the per-frame matrix uniforms.
//
// a_seconds: time value driving the rotation; each axis spins at its own rate
//            (23, 37 and 46 units per second, wrapped at 360).
//
// Writes m_modelViewMatrix, m_modelViewProjectionMatrix and m_normalMatrix of
// the program context's uniform block.
void BumpMapping::update(float a_seconds)
{
  // Model placement: pushed 20 units down -Z, scaled by 5, spinning on all axes.
  float trans[3] = { 0.0f, 0.0f, -20.0f };
  float rotate[3] = { fmod(a_seconds*23.f, 360.0f), fmod(a_seconds*37.f, 360.0f), fmod(a_seconds*46.0f, 360.0f)};

  float baseModelViewMatrix[16];
  Math::translationRotationScaleToMatrix4x4(baseModelViewMatrix, trans, rotate, 5.0f);

  float modelViewProjectionMatrix[16];
  Math::multiplyMatrix4x4(modelViewProjectionMatrix, m_projectionMatrix, baseModelViewMatrix);

  // The normal matrix is built from the rotation alone (no translation, no
  // scale) instead of an inverse-transpose of the model-view matrix.
  // NOTE(review): this presumably relies on the scale being uniform and on
  // rotateMatrix4x4 applying the same rotation as the call above - confirm.
  float normalMatrix[16];
  Math::makeIdentityMatrix4x4(normalMatrix);
  Math::rotateMatrix4x4(normalMatrix, rotate);

  // Only the first 12 floats of the 4x4 are copied into the mat3 uniform.
  // NOTE(review): this assumes m_normalMatrix.m stores three groups of four
  // floats (a padded 3x3) - confirm against the uniform type's definition.
  memcpy(m_demoModelContext.m_uniforms.m_normalMatrix.m, normalMatrix, sizeof(float)*12);

  memcpy(m_demoModelContext.m_uniforms.m_modelViewMatrix.m, baseModelViewMatrix, sizeof(float)*16);
  memcpy(m_demoModelContext.m_uniforms.m_modelViewProjectionMatrix.m, modelViewProjectionMatrix, sizeof(float)*16);
}


void BumpMapping::draw()
{
  m_demoModelContext.draw();
}


// Associates the display name "Bump Mapping" with the BumpMapping class.
// NOTE(review): presumably this adds the demo to the framework's list of
// selectable demos - confirm against the macro definition in Framework.h.
REGISTER_DEMO_CONTEXT("Bump Mapping", BumpMapping)


