import * as THREE from 'three';
import { WebGLRenderTarget } from 'three';

/**
 * MeshPhysicalMaterial variant that swaps three.js's transmission shader chunks for a
 * multi-sample refraction shader (per-channel IOR offsets for chromatic aberration,
 * roughness/anisotropic thickness smearing and optional noise-based normal distortion).
 *
 * Unless `transmissionSampler` is set, the refracted image is read from the `buffer`
 * uniform, which `render()` fills by capturing the scene into an internal render target.
 *
 * Every key of `this.uniforms` is exposed as an accessor on the instance, so e.g.
 * `material.thickness = 2` writes straight into the shader uniform value.
 */
class MeshTransmissionMaterial extends THREE.MeshPhysicalMaterial {
  /**
   * @param {object} [options]
   * @param {number} [options.samples=6] - Refraction samples per fragment. Baked into the shader
   *   source, so it is floored and clamped to at least 1.
   * @param {boolean} [options.transmissionSampler=false] - When truthy, defines USE_SAMPLER so the
   *   shader samples `transmissionSamplerMap`; otherwise USE_TRANSMISSION is forced and `buffer`
   *   is sampled. NOTE(review): in sampler mode USE_TRANSMISSION is not forced here, so presumably
   *   `transmission` must be > 0 for three.js to define it — confirm.
   * @param {number} [options.chromaticAberration=0.05] - Relative IOR spread applied to G and B.
   * @param {number} [options.transmission=0] - Shadows three.js's own `transmission` property.
   * @param {number} [options._transmission=1] - Transmission strength actually used by the
   *   injected shader (`material.transmission = _transmission`).
   * @param {number} [options.roughness=0]
   * @param {number} [options.thickness=0]
   * @param {number} [options.attenuationDistance=Infinity] - Infinity disables attenuation.
   * @param {THREE.Color} [options.attenuationColor=white]
   * @param {number} [options.anisotropicBlur=0.1] - Lower bound of the per-sample thickness smear.
   * @param {number} [options.time=0] - Drives `temporalDistortion`.
   * @param {number} [options.distortion=0] - Strength of the noise added to the sample normals.
   * @param {number} [options.distortionScale=0.5] - World-space frequency of that noise.
   * @param {number} [options.temporalDistortion=0] - How fast the noise moves with `time`.
   * @param {THREE.Texture|null} [options.buffer=null] - Texture sampled for the refraction.
   * @param {boolean} [options.backside=false] - Let `render()` also do a back-face pass.
   * @param {number} [options.side=THREE.FrontSide]
   * @param {number} [options.backsideThickness=0] - Thickness used during the back-face pass.
   * @param {number} [options.backsideEnvMapIntensity=1] - envMapIntensity during that pass.
   * @param {number} [options.resolution=256] - Size of the main capture target.
   * @param {number} [options.backsideResolution=256] - Size of the back-face capture target.
   * @param {THREE.Color|THREE.Texture|null} [options.background=null] - Scene background to
   *   use while capturing, if set.
   * @param {number} [options.anisotropy=0]
   */
  constructor({
    samples = 6,
    transmissionSampler = false,
    chromaticAberration = 0.05,
    transmission = 0,
    _transmission = 1,
    transmissionMap = null,
    roughness = 0,
    thickness = 0,
    thicknessMap = null,
    attenuationDistance = Infinity,
    attenuationColor = new THREE.Color('white'),
    anisotropicBlur = 0.1,
    time = 0,
    distortion = 0.0,
    distortionScale = 0.5,
    temporalDistortion = 0.0,
    buffer = null,
    backside = false,
    side = THREE.FrontSide,
    backsideThickness = 0,
    backsideEnvMapIntensity = 1,
    resolution = 256,
    backsideResolution = 256,
    background = null,
    anisotropy = 0,
  } = {}) {
    super();

    // The sample count is pasted into GLSL as `${sampleCount}.0` and used as a divisor, so
    // anything but a positive integer yields invalid source (`6.5.0`) or a divide-by-zero.
    const requestedSamples = Math.floor(Number(samples));
    const sampleCount = Number.isFinite(requestedSamples) && requestedSamples > 0 ? requestedSamples : 1;
    const useSampler = Boolean(transmissionSampler);

    this.uniforms = {
      chromaticAberration: { value: chromaticAberration },
      transmission: { value: transmission },
      _transmission: { value: _transmission },
      transmissionMap: { value: transmissionMap },
      roughness: { value: roughness },
      thickness: { value: thickness },
      thicknessMap: { value: thicknessMap },
      attenuationDistance: { value: attenuationDistance },
      attenuationColor: { value: attenuationColor },
      anisotropicBlur: { value: anisotropicBlur },
      time: { value: time },
      distortion: { value: distortion },
      distortionScale: { value: distortionScale },
      temporalDistortion: { value: temporalDistortion },
      buffer: { value: buffer },
    };
    this.backside = backside;
    this.side = side;
    this.backsideThickness = backsideThickness;
    this.backsideEnvMapIntensity = backsideEnvMapIntensity;
    this.resolution = resolution;
    this.backsideResolution = backsideResolution;
    this.background = background;
    this.anisotropy = anisotropy;

    this.fboBack = new WebGLRenderTarget(backsideResolution, backsideResolution);
    this.fboMain = new WebGLRenderTarget(resolution, resolution);

    // three.js keys compiled programs on customProgramCacheKey(), which by default returns
    // onBeforeCompile.toString() — the same text for every instance. Include the values that
    // onBeforeCompile bakes into the shader so differently configured instances don't share
    // one program.
    this.customProgramCacheKey = () => `MeshTransmissionMaterial:${sampleCount}:${useSampler ? 1 : 0}`;

    this.onBeforeCompile = shader => {
      shader.uniforms = {
        ...shader.uniforms,
        ...this.uniforms,
      };

      if (useSampler) shader.defines.USE_SAMPLER = '';
      else shader.defines.USE_TRANSMISSION = '';

      // Prepend the extra uniforms plus the hash-based per-pixel RNG (`rand`) and the
      // simplex-noise helpers used for the distortion term.
      shader.fragmentShader = `
        uniform float chromaticAberration;
        uniform float anisotropicBlur;
        uniform float time;
        uniform float distortion;
        uniform float distortionScale;
        uniform float temporalDistortion;
        uniform sampler2D buffer;

        vec3 random3(vec3 c) {
          float j = 4096.0*sin(dot(c,vec3(17.0, 59.4, 15.0)));
          vec3 r;
          r.z = fract(512.0*j);
          j *= .125;
          r.x = fract(512.0*j);
          j *= .125;
          r.y = fract(512.0*j);
          return r-0.5;
        }

        uint hash( uint x ) {
          x += ( x << 10u );
          x ^= ( x >>  6u );
          x += ( x <<  3u );
          x ^= ( x >> 11u );
          x += ( x << 15u );
          return x;
        }

        uint hash( uvec2 v ) { return hash( v.x ^ hash(v.y) ); }
        uint hash( uvec3 v ) { return hash( v.x ^ hash(v.y) ^ hash(v.z) ); }
        uint hash( uvec4 v ) { return hash( v.x ^ hash(v.y) ^ hash(v.z) ^ hash(v.w) ); }

        float floatConstruct( uint m ) {
          const uint ieeeMantissa = 0x007FFFFFu;
          const uint ieeeOne = 0x3F800000u;
          m &= ieeeMantissa;
          m |= ieeeOne;
          float f = uintBitsToFloat( m );
          return f - 1.0;
        }

        float randomBase( float x ) { return floatConstruct(hash(floatBitsToUint(x))); }
        float randomBase( vec2  v ) { return floatConstruct(hash(floatBitsToUint(v))); }
        float randomBase( vec3  v ) { return floatConstruct(hash(floatBitsToUint(v))); }
        float randomBase( vec4  v ) { return floatConstruct(hash(floatBitsToUint(v))); }
        float rand(float seed) {
          float result = randomBase(vec3(gl_FragCoord.xy, seed));
          return result;
        }

        const float F3 =  0.3333333;
        const float G3 =  0.1666667;

        float snoise(vec3 p) {
          vec3 s = floor(p + dot(p, vec3(F3)));
          vec3 x = p - s + dot(s, vec3(G3));
          vec3 e = step(vec3(0.0), x - x.yzx);
          vec3 i1 = e*(1.0 - e.zxy);
          vec3 i2 = 1.0 - e.zxy*(1.0 - e);
          vec3 x1 = x - i1 + G3;
          vec3 x2 = x - i2 + 2.0*G3;
          vec3 x3 = x - 1.0 + 3.0*G3;
          vec4 w, d;
          w.x = dot(x, x);
          w.y = dot(x1, x1);
          w.z = dot(x2, x2);
          w.w = dot(x3, x3);
          w = max(0.6 - w, 0.0);
          d.x = dot(random3(s), x);
          d.y = dot(random3(s + i1), x1);
          d.z = dot(random3(s + i2), x2);
          d.w = dot(random3(s + 1.0), x3);
          w *= w;
          w *= w;
          d *= w;
          return dot(d, vec4(52.0));
        }

        float snoiseFractal(vec3 m) {
          return 0.5333333* snoise(m)
                +0.2666667* snoise(2.0*m)
                +0.1333333* snoise(4.0*m)
                +0.0666667* snoise(8.0*m);
        }\n` + shader.fragmentShader;

      // Replace the transmission declarations: `_transmission` becomes the strength uniform and
      // getTransmissionSample reads either three.js's sampler (USE_SAMPLER) or `buffer`.
      shader.fragmentShader = shader.fragmentShader.replace('#include <transmission_pars_fragment>', `
        #ifdef USE_TRANSMISSION
          uniform float _transmission;
          uniform float thickness;
          uniform float attenuationDistance;
          uniform vec3 attenuationColor;
          #ifdef USE_TRANSMISSIONMAP
            uniform sampler2D transmissionMap;
          #endif
          #ifdef USE_THICKNESSMAP
            uniform sampler2D thicknessMap;
          #endif
          uniform vec2 transmissionSamplerSize;
          uniform sampler2D transmissionSamplerMap;
          uniform mat4 modelMatrix;
          uniform mat4 projectionMatrix;
          varying vec3 vWorldPosition;
          vec3 getVolumeTransmissionRay( const in vec3 n, const in vec3 v, const in float thickness, const in float ior, const in mat4 modelMatrix ) {
            vec3 refractionVector = refract( - v, normalize( n ), 1.0 / ior );
            vec3 modelScale;
            modelScale.x = length( vec3( modelMatrix[ 0 ].xyz ) );
            modelScale.y = length( vec3( modelMatrix[ 1 ].xyz ) );
            modelScale.z = length( vec3( modelMatrix[ 2 ].xyz ) );
            return normalize( refractionVector ) * thickness * modelScale;
          }
          float applyIorToRoughness( const in float roughness, const in float ior ) {
            return roughness * clamp( ior * 2.0 - 2.0, 0.0, 1.0 );
          }
          vec4 getTransmissionSample( const in vec2 fragCoord, const in float roughness, const in float ior ) {
            float framebufferLod = log2( transmissionSamplerSize.x ) * applyIorToRoughness( roughness, ior );
            #ifdef USE_SAMPLER
              #ifdef texture2DLodEXT
                return texture2DLodEXT(transmissionSamplerMap, fragCoord.xy, framebufferLod);
              #else
                return texture2D(transmissionSamplerMap, fragCoord.xy, framebufferLod);
              #endif
            #else
              return texture2D(buffer, fragCoord.xy);
            #endif
          }
          vec3 applyVolumeAttenuation( const in vec3 radiance, const in float transmissionDistance, const in vec3 attenuationColor, const in float attenuationDistance ) {
            if ( isinf( attenuationDistance ) ) {
              return radiance;
            } else {
              vec3 attenuationCoefficient = -log( attenuationColor ) / attenuationDistance;
              vec3 transmittance = exp( - attenuationCoefficient * transmissionDistance );
              return transmittance * radiance;
            }
          }
          vec4 getIBLVolumeRefraction( const in vec3 n, const in vec3 v, const in float roughness, const in vec3 diffuseColor,
            const in vec3 specularColor, const in float specularF90, const in vec3 position, const in mat4 modelMatrix,
            const in mat4 viewMatrix, const in mat4 projMatrix, const in float ior, const in float thickness,
            const in vec3 attenuationColor, const in float attenuationDistance ) {
            vec3 transmissionRay = getVolumeTransmissionRay( n, v, thickness, ior, modelMatrix );
            vec3 refractedRayExit = position + transmissionRay;
            vec4 ndcPos = projMatrix * viewMatrix * vec4( refractedRayExit, 1.0 );
            vec2 refractionCoords = ndcPos.xy / ndcPos.w;
            refractionCoords += 1.0;
            refractionCoords /= 2.0;
            vec4 transmittedLight = getTransmissionSample( refractionCoords, roughness, ior );
            vec3 attenuatedColor = applyVolumeAttenuation( transmittedLight.rgb, length( transmissionRay ), attenuationColor, attenuationDistance );
            vec3 F = EnvironmentBRDF( n, v, specularColor, specularF90, roughness );
            return vec4( ( 1.0 - F ) * attenuatedColor * diffuseColor, transmittedLight.a );
          }
        #endif\n`);

      // Multi-sample refraction. Each sample jitters the normal by roughness (plus optional noise
      // distortion), pushes thickness further along the smear, and offsets the IOR of G and B
      // for chromatic aberration. Guarded like the declarations above: everything it references
      // (_transmission, vWorldPosition, getIBLVolumeRefraction) only exists under USE_TRANSMISSION.
      shader.fragmentShader = shader.fragmentShader.replace('#include <transmission_fragment>', `
        #ifdef USE_TRANSMISSION
        material.transmission = _transmission;
        material.transmissionAlpha = 1.0;
        material.thickness = thickness;
        material.attenuationDistance = attenuationDistance;
        material.attenuationColor = attenuationColor;
        #ifdef USE_TRANSMISSIONMAP
          material.transmission *= texture2D( transmissionMap, vUv ).r;
        #endif
        #ifdef USE_THICKNESSMAP
          material.thickness *= texture2D( thicknessMap, vUv ).g;
        #endif

        vec3 pos = vWorldPosition;
        float runningSeed = 0.0;
        vec3 v = normalize( cameraPosition - pos );
        vec3 n = inverseTransformDirection( normal, viewMatrix );
        vec3 transmission = vec3(0.0);
        float transmissionR, transmissionB, transmissionG;
        float randomCoords = rand(runningSeed++);
        float thickness_smear = thickness * max(pow(roughnessFactor, 0.33), anisotropicBlur);
        vec3 distortionNormal = vec3(0.0);
        vec3 temporalOffset = vec3(time, -time, -time) * temporalDistortion;
        if (distortion > 0.0) {
          distortionNormal = distortion * vec3(snoiseFractal(vec3((pos * distortionScale + temporalOffset))), snoiseFractal(vec3(pos.zxy * distortionScale - temporalOffset)), snoiseFractal(vec3(pos.yxz * distortionScale + temporalOffset)));
        }
        for (float i = 0.0; i < ${sampleCount}.0; i ++) {
          vec3 sampleNorm = normalize(n + roughnessFactor * roughnessFactor * 2.0 * normalize(vec3(rand(runningSeed++) - 0.5, rand(runningSeed++) - 0.5, rand(runningSeed++) - 0.5)) * pow(rand(runningSeed++), 0.33) + distortionNormal);
          transmissionR = getIBLVolumeRefraction(
            sampleNorm, v, material.roughness, material.diffuseColor, material.specularColor, material.specularF90,
            pos, modelMatrix, viewMatrix, projectionMatrix, material.ior, material.thickness + thickness_smear * (i + randomCoords) / float(${sampleCount}),
            material.attenuationColor, material.attenuationDistance
          ).r;
          transmissionG = getIBLVolumeRefraction(
            sampleNorm, v, material.roughness, material.diffuseColor, material.specularColor, material.specularF90,
            pos, modelMatrix, viewMatrix, projectionMatrix, material.ior * (1.0 + chromaticAberration * (i + randomCoords) / float(${sampleCount})), material.thickness + thickness_smear * (i + randomCoords) / float(${sampleCount}),
            material.attenuationColor, material.attenuationDistance
          ).g;
          transmissionB = getIBLVolumeRefraction(
            sampleNorm, v, material.roughness, material.diffuseColor, material.specularColor, material.specularF90,
            pos, modelMatrix, viewMatrix, projectionMatrix, material.ior * (1.0 + 2.0 * chromaticAberration * (i + randomCoords) / float(${sampleCount})), material.thickness + thickness_smear * (i + randomCoords) / float(${sampleCount}),
            material.attenuationColor, material.attenuationDistance
          ).b;
          transmission.r += transmissionR;
          transmission.g += transmissionG;
          transmission.b += transmissionB;
        }
        transmission /= ${sampleCount}.0;
        totalDiffuse = mix( totalDiffuse, transmission.rgb, material.transmission );
        #endif\n`);
    };

    // Expose every uniform as an instance accessor. This deliberately shadows same-named
    // MeshPhysicalMaterial properties (roughness, thickness, transmission, ...) so that writes
    // land directly in the shader uniforms.
    for (const name of Object.keys(this.uniforms)) {
      Object.defineProperty(this, name, {
        get: () => this.uniforms[name].value,
        set: v => { this.uniforms[name].value = v; },
      });
    }
  }

  /**
   * Captures the scene into the internal render target(s) and points `buffer` at the result,
   * so the refraction shader samples what lies behind the object.
   *
   * Meshes using this material are hidden (`visible = false`) during each capture. The object
   * must not occlude the scene it refracts, and drawing it into `fboMain` while it samples
   * `fboMain.texture` would form a framebuffer/texture feedback loop. With `backside` enabled,
   * the scene is first captured into `fboBack`. The material's back faces are then drawn into
   * `fboMain`, refracting that capture with `backsideThickness` / `backsideEnvMapIntensity`.
   *
   * All renderer, scene and material state touched here is restored, even if rendering throws.
   *
   * @param {THREE.WebGLRenderer} renderer
   * @param {THREE.Scene} scene
   * @param {THREE.Camera} camera
   */
  render(renderer, scene, camera) {
    // Keep the targets in sync with the public resolution properties; setSize only
    // reallocates when the dimensions actually change.
    this.fboMain.setSize(this.resolution, this.resolution);
    this.fboBack.setSize(this.backsideResolution, this.backsideResolution);

    const oldRenderTarget = renderer.getRenderTarget();
    const oldToneMapping = renderer.toneMapping;
    const oldBackground = scene.background;
    const oldVisible = this.visible;
    // Read through the uniform accessor *before* anything overwrites it.
    const oldThickness = this.thickness;
    const oldSide = this.side;
    const oldEnvMapIntensity = this.envMapIntensity;

    // Presumably disabled so the captured image is not tone mapped a second time when it is
    // shown through this material — confirm against the output pipeline.
    renderer.toneMapping = THREE.NoToneMapping;
    if (this.background) scene.background = this.background;

    try {
      // NOTE(review): hiding the material also removes the object's own shadow from the
      // captured image if shadow maps are re-rendered during these passes.
      this.visible = false;

      if (this.backside) {
        renderer.setRenderTarget(this.fboBack);
        renderer.render(scene, camera);

        // Draw the back faces, refracting the scene captured above.
        this.visible = oldVisible;
        this.buffer = this.fboBack.texture;
        this.thickness = this.backsideThickness;
        this.side = THREE.BackSide;
        this.envMapIntensity = this.backsideEnvMapIntensity;
      }

      renderer.setRenderTarget(this.fboMain);
      renderer.render(scene, camera);
    } finally {
      this.visible = oldVisible;
      this.thickness = oldThickness;
      this.side = oldSide;
      this.envMapIntensity = oldEnvMapIntensity;
      scene.background = oldBackground;
      renderer.toneMapping = oldToneMapping;
      renderer.setRenderTarget(oldRenderTarget);
    }

    this.buffer = this.fboMain.texture;
  }

  /**
   * Releases the internal capture render targets, then performs the regular material disposal.
   */
  dispose() {
    this.fboBack.dispose();
    this.fboMain.dispose();
    super.dispose();
  }
}

export { MeshTransmissionMaterial };
