/*
====================================================================================================

    Copyright (C) 2021 RRe36

    All Rights Reserved unless otherwise explicitly stated.


    By downloading this you have agreed to the license and terms of use.
    These can be found inside the included license-file
    or here: https://rre36.com/copyright-license

    Violating these terms may be penalized with actions according to the Digital Millennium
    Copyright Act (DMCA), the Information Society Directive and/or similar laws
    depending on your country.

====================================================================================================
*/

/* RENDERTARGETS: 0,6 */
// Fragment outputs, bound in the order listed by the RENDERTARGETS directive above:
//   location 0 -> colortex0: scene colour after the TAA resolve
//   location 1 -> colortex6: colour read back next frame as TAA history (alpha carried over unchanged)
layout(location = 0) out vec4 sceneColor;
layout(location = 1) out vec4 temporalData;

/*
temporal anti aliasing based on
- bsl shaders
- chocapic13 shaders
- unreal 4
*/

#include "/lib/head.glsl"

// Loader directive: do not clear colortex6 between frames, so the colour written to it
// by this pass is still available as history when the next frame samples it.
const bool colortex6Clear   = false;

// Screen-space coordinate of the current fragment, in [0,1].
in vec2 uv;

// colortex0: current-frame scene colour; colortex6: previous frame's TAA result (history).
uniform sampler2D colortex0;
uniform sampler2D colortex6;

// Depth buffers; temporalAntiAliasing() compares them to decide whether camera motion
// is applied during reprojection (NOTE(review): presumably depthtex2 excludes hand geometry — confirm).
uniform sampler2D depthtex1;
uniform sampler2D depthtex2;

// NOTE(review): frameTime, viewHeight, viewWidth, nightVision and rainStrength are not
// referenced by the code in this file.
uniform float frameTime;
uniform float viewHeight;
uniform float viewWidth;
uniform float nightVision;
uniform float rainStrength;

// pixelSize: size of one screen pixel in uv units; taaOffset: sub-pixel jitter of the current frame.
uniform vec2 pixelSize, viewSize;
uniform vec2 taaOffset;

// World-space camera positions of this and the previous frame (used for reprojection).
uniform vec3 cameraPosition, previousCameraPosition;

// Current-frame matrices (inverses used to unproject) and previous-frame matrices (used to reproject).
uniform mat4 gbufferModelView, gbufferModelViewInverse;
uniform mat4 gbufferProjection, gbufferProjectionInverse;
uniform mat4 gbufferPreviousModelView, gbufferPreviousProjection;

/*
 * Reprojects a current-frame screen position into the previous frame.
 *   uv, depth          : current-frame screen coordinate and depth, both in [0,1]
 *   applyCameraMotion  : when true, the camera displacement between the two frames
 *                        (cameraPosition - previousCameraPosition) is added to the
 *                        world-space position before reprojecting.
 * NOTE(review): the caller passes (depthtex2 == depthtex1) here; this parameter was
 * previously named `hand`, which reads inverted — true means "offset by camera motion".
 * Returns the previous-frame screen coordinate; it may lie outside [0,1] when the
 * point was off-screen last frame.
 */
vec2 temporalReprojection(vec2 uv, float depth, bool applyCameraMotion) {
    vec4 ndc        = vec4(vec3(uv, depth) * 2.0 - 1.0, 1.0);

    vec4 viewPos    = gbufferProjectionInverse * ndc;
    viewPos        /= viewPos.w;

    vec4 worldPos   = gbufferModelViewInverse * viewPos;
    if (applyCameraMotion) worldPos.xyz += cameraPosition - previousCameraPosition;

    vec4 prevClip   = gbufferPreviousProjection * (gbufferPreviousModelView * worldPos);

    return (prevClip.xy / prevClip.w) * 0.5 + 0.5;
}

/*
 * 3x3 closest-depth search around uv (based on chocapic13's TAA).
 *   depth : depth texture to sample
 * Returns vec3(screen uv of the chosen texel, its depth): the neighbour with the
 * smallest depth value. On ties the earlier candidate in the scan order below wins.
 *
 * Each candidate's uv is built from the exact offset it samples, so the returned
 * position always matches the depth it reports (previously the middle-right tap
 * recorded a bottom-centre position).
 */
vec3 screenpos3x3(sampler2D depth) {
    vec2 dx     = vec2(pixelSize.x, 0.0);
    vec2 dy     = vec2(0.0, pixelSize.y);

    vec2 ptl    = uv - dx - dy;
    vec2 ptc    = uv - dy;
    vec2 ptr    = uv - dy + dx;

    vec2 pml    = uv - dx;
    vec2 pmr    = uv + dx;

    vec2 pbl    = uv + dy - dx;
    vec2 pbc    = uv + dy;
    vec2 pbr    = uv + dy + dx;

    vec3 dtl    = vec3(ptl, texture(depth, ptl).x);
    vec3 dtc    = vec3(ptc, texture(depth, ptc).x);
    vec3 dtr    = vec3(ptr, texture(depth, ptr).x);

    vec3 dml    = vec3(pml, texture(depth, pml).x);
    vec3 dmc    = vec3(uv,  texture(depth, uv).x);
    vec3 dmr    = vec3(pmr, texture(depth, pmr).x);

    vec3 dbl    = vec3(pbl, texture(depth, pbl).x);
    vec3 dbc    = vec3(pbc, texture(depth, pbc).x);
    vec3 dbr    = vec3(pbr, texture(depth, pbr).x);

    // Scan order kept from the original so tie-breaking is unchanged.
    vec3 dmin   = dmc;

    dmin    = dmin.z > dtc.z ? dtc : dmin;
    dmin    = dmin.z > dtr.z ? dtr : dmin;

    dmin    = dmin.z > dml.z ? dml : dmin;
    dmin    = dmin.z > dtl.z ? dtl : dmin;
    dmin    = dmin.z > dmr.z ? dmr : dmin;

    dmin    = dmin.z > dbl.z ? dbl : dmin;
    dmin    = dmin.z > dbc.z ? dbc : dmin;
    dmin    = dmin.z > dbr.z ? dbr : dmin;

    return dmin;
}

/*
 * Bicubic Catmull-Rom sample of `tex` at `uv`, evaluated with 9 bilinear taps:
 * the 4x4 texel footprint is reduced by merging the two middle texels of each
 * axis into one bilinear tap.
 * The texel size is derived from the texture's own level-0 size, so `tex` does
 * not have to match the screen resolution.
 * The negative lobes can overshoot, so the result is clamped to [0, 65535].
 * Relatively expensive (original author's note: ~5fps).
 */
vec4 textureCatmullRom(sampler2D tex, vec2 uv) {
    vec2 res        = vec2(textureSize(tex, 0));
    vec2 texelSize  = 1.0 / res;

    vec2 coord  = uv * res;
    // Centre of the texel at or before the sample point; f is the fractional offset in [0,1).
    vec2 coord1 = floor(coord - 0.5) + 0.5;

    vec2 f      = coord - coord1;

    // Catmull-Rom weights for the texels at coord1-1, coord1, coord1+1 and coord1+2.
    vec2 w0     = f * (-0.5 + f * (1.0 - (0.5 * f)));
    vec2 w1     = 1.0 + sqr(f) * (-2.5 + (1.5 * f));
    vec2 w2     = f * (0.5 + f * (2.0 - (1.5 * f)));
    vec2 w3     = sqr(f) * (-0.5 + (0.5 * f));

    // Merge the two middle texels: one bilinear tap placed between them reproduces w1/w2.
    vec2 w12     = w1 + w2;
    vec2 delta12 = w2 * rcp(w12);

    vec2 uv0    = (coord1 - 1.0) * texelSize;
    vec2 uv3    = (coord1 + 2.0) * texelSize;   // outermost texel is two texels past coord1
    vec2 uv12   = (coord1 + delta12) * texelSize;

    vec4 col    = vec4(0.0);
        col    += textureLod(tex, vec2(uv0.x,  uv0.y),  0) * w0.x  * w0.y;
        col    += textureLod(tex, vec2(uv12.x, uv0.y),  0) * w12.x * w0.y;
        col    += textureLod(tex, vec2(uv3.x,  uv0.y),  0) * w3.x  * w0.y;

        col    += textureLod(tex, vec2(uv0.x,  uv12.y), 0) * w0.x  * w12.y;
        col    += textureLod(tex, vec2(uv12.x, uv12.y), 0) * w12.x * w12.y;
        col    += textureLod(tex, vec2(uv3.x,  uv12.y), 0) * w3.x  * w12.y;

        col    += textureLod(tex, vec2(uv0.x,  uv3.y),  0) * w0.x  * w3.y;
        col    += textureLod(tex, vec2(uv12.x, uv3.y),  0) * w12.x * w3.y;
        col    += textureLod(tex, vec2(uv3.x,  uv3.y),  0) * w3.x  * w3.y;

    return clamp(col, 0.0, 65535.0);
}


// TAA tuning options. The trailing //[...] lists are the values selectable in the
// shader options menu, so those lines must keep their exact format.
//   taaBlendWeight     : base weight of the current frame in the history blend (forced to 0.1 where landMask() is false)
//   taaMotionRejection : scales how strongly pixel velocity rejects history
//   taaAntiGhosting    : scales the neighbourhood-clamp distance term of the motion rejection
//   taaAntiFlicker     : scales the luma/hue-difference term that reduces the current-frame weight
//   taaLumaRejection   : scales the squared relative luma difference used by the anti-flicker term
//   taaHueRejection    : scales hue-based history rejection
//   taaCatmullRom      : when defined, current and history colour are sampled with a Catmull-Rom filter
#define taaBlendWeight 0.1          //[0.0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0]
#define taaMotionRejection 1.0      //[0.0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1 1.2 1.3 1.4 1.5 1.6 1.7 1.8 1.9 2.0]
#define taaAntiGhosting 1.0         //[0.0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0]
#define taaAntiFlicker 0.5          //[0.0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0]
#define taaLumaRejection 1.0        //[0.0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1 1.2 1.3 1.4 1.5 1.6 1.7 1.8 1.9 2.0]
#define taaHueRejection 1.0         //[0.0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1 1.2 1.3 1.4 1.5 1.6 1.7 1.8 1.9 2.0]
#define taaCatmullRom

/*
 * Geometric hue angle of an RGB colour, in degrees within [0, 360].
 * Same formulation as the ACES reference rgb_2_hue:
 *     hue = atan2(sqrt(3) * (g - b), 2r - g - b)
 * GLSL's atan(y, x) takes y first. Achromatic input (r == g == b), where the angle
 * is undefined, returns 0.
 */
float rgbHue(vec3 rgb) {
    if (rgb.x == rgb.y && rgb.y == rgb.z) return 0.0;

    float hue = (180.0 * rcp(pi)) * atan(sqrt(3.0) * (rgb.y - rgb.z), 2.0 * rgb.x - rgb.y - rgb.z);

    if (hue < 0.0) hue += 360.0;

    return clamp(hue, 0.0, 360.0);
}
/*
 * Signed hue difference (hue - center), wrapped into [-180, 180] degrees.
 * With both inputs in [0, 360] (as produced by rgbHue) a single wrap is enough.
 */
float centerHue(float hue, float center) {
    float delta = hue - center;

    if (delta > 180.0) {
        delta -= 360.0;
    } else if (delta < -180.0) {
        delta += 360.0;
    }

    return delta;
}
/*
 * Approximate saturation (max - min) / max of the RGB channels.
 * Both channel extremes are floored at 1e-10 in the numerator and the maximum at 1e-2
 * in the denominator. This keeps the result finite and scales it down for very dark colours.
 */
float rgbSaturation(vec3 rgb) {
    float hi    = maxOf(rgb);
    float lo    = minOf(rgb);

    float range = max(hi, 1e-10) - max(lo, 1e-10);

    return range / max(hi, 1e-2);
}

/*
 * Temporal anti-aliasing resolve for the current fragment.
 *   sceneColor : current-frame colour at uv
 *   sceneDepth : depthtex1 depth at uv
 * Returns the blend of the (re-sampled) current colour and the neighbourhood-clamped
 * history from colortex6. If the reprojected history position falls outside the
 * screen, sceneColor is returned unchanged.
 */
vec3 temporalAntiAliasing(vec3 sceneColor, float sceneDepth) {
    // Camera motion is applied only where depthtex1 and depthtex2 agree.
    // NOTE(review): screenpos3x3() (closest-depth dilation) exists but is not used here;
    // the unused call and the unused pixelDist local were removed.
    float depth2        = texture(depthtex2, uv).x;
    vec2 historyPos     = temporalReprojection(uv, sceneDepth, depth2 == sceneDepth);

    if (clamp(historyPos, 0.0, 1.0) != historyPos) return sceneColor;

    // 3x3 neighbourhood of the current frame, used to clamp the history colour.
    vec3 coltl  = textureLod(colortex0, uv + vec2(-pixelSize.x, -pixelSize.y), 0).rgb;
    vec3 coltm  = textureLod(colortex0, uv + vec2( 0.0,         -pixelSize.y), 0).rgb;
    vec3 coltr  = textureLod(colortex0, uv + vec2( pixelSize.x, -pixelSize.y), 0).rgb;
    vec3 colml  = textureLod(colortex0, uv + vec2(-pixelSize.x,  0.0        ), 0).rgb;
    vec3 colmr  = textureLod(colortex0, uv + vec2( pixelSize.x,  0.0        ), 0).rgb;
    vec3 colbl  = textureLod(colortex0, uv + vec2(-pixelSize.x,  pixelSize.y), 0).rgb;
    vec3 colbm  = textureLod(colortex0, uv + vec2( 0.0,          pixelSize.y), 0).rgb;
    vec3 colbr  = textureLod(colortex0, uv + vec2( pixelSize.x,  pixelSize.y), 0).rgb;

    vec3 minCol = min(sceneColor, min(min(min(coltl, coltm), min(coltr, colml)), min(min(colmr, colbl), min(colbm, colbr))));
    vec3 maxCol = max(sceneColor, max(max(max(coltl, coltm), max(coltr, colml)), max(max(colmr, colbl), max(colbm, colbr))));

    // Re-sample the current frame with the jitter offset compensated, and fetch history.
    #ifdef taaCatmullRom
        sceneColor          = textureCatmullRom(colortex0, uv - taaOffset * 0.5).rgb;
        vec3 historyColor   = textureCatmullRom(colortex6, historyPos).rgb;
    #else
        sceneColor          = textureLod(colortex0, uv - taaOffset * 0.5, 0).rgb;
        vec3 historyColor   = texture(colortex6, historyPos).rgb;
    #endif

    vec3 historyClamped     = clamp(historyColor, minCol, maxCol);

    // Relative terms divide by history luminance. Floor it so a black history pixel
    // does not produce 0/0 = NaN, which would propagate into the output and the history.
    float historyLuma       = max(getLuma(historyColor), 1e-6);

    // How far the neighbourhood clamp moved the history, relative to its brightness.
    float clamped           = distance(historyColor, historyClamped) / historyLuma * 0.5;

    // Screen-space motion in pixels.
    vec2 velocity           = (uv - historyPos) / pixelSize;

    // Flicker reduction.
    float lumaDiff          = distance(historyColor, sceneColor) / historyLuma;

    float sceneHue          = rgbHue(sceneColor);
    float hueDiff           = centerHue(rgbHue(historyColor), sceneHue);

    float hueAntiflicker    = max0(abs(hueDiff) - 90.0) / 90.0;
    float hueRejection      = 1.0 - (max0(90.0 - abs(hueDiff)) / 90.0);
        hueRejection       *= taaHueRejection / (1.0 + cube(lumaDiff) * 2.0);

        lumaDiff            = sqr(lumaDiff) * taaLumaRejection;

    float hueDiffCoeff      = saturate(rgbSaturation(sceneColor));

    float movementRejection = (0.12 + clamped * taaAntiGhosting) * saturate(length(velocity) * taaMotionRejection);

    float baseBlendWeight   = taaBlendWeight;
    if (!landMask(sceneDepth)) baseBlendWeight = 0.1;

    // temporalWeight is the history weight: 1 keeps history, 0 takes the current frame.
    float temporalWeight    = 1.0 - saturate((baseBlendWeight + movementRejection) / (1.0 + (lumaDiff + hueAntiflicker * hueDiffCoeff) * taaAntiFlicker));
        temporalWeight     *= max(1.0 / (1.0 + hueRejection * sqr(hueDiffCoeff)), 0.1);

    historyClamped.rgb      = mix(sceneColor.rgb, historyClamped.rgb, temporalWeight);

    return historyClamped;
}

/*
 * Decodes an integer material ID stored as a normalised value x = id / 65535.
 * Rounds to the nearest integer instead of truncating. Otherwise a product that lands
 * just below the integer (e.g. 41.99998 from floating-point error) would decode as id - 1.
 * Assumes x >= 0.
 */
int decodeMatID16(float x) {
    return int(x * 65535.0 + 0.5);
}

/*
 * Fragment entry point.
 * Writes the TAA-resolved colour (or the untouched scene colour when taaEnabled is not
 * defined) to colortex0. Writes the same rgb to colortex6 as next frame's history, with
 * colortex6's existing alpha carried over. Without TAA the history rgb is zero.
 * Both outputs pass through clamp16F.
 */
void main() {
    float sceneDepth    = texture(depthtex1, uv).x;
    vec4 currentColor   = textureLod(colortex0, uv, 0);

    #ifdef taaEnabled
        vec3 temporal       = temporalAntiAliasing(currentColor.rgb, sceneDepth);
        currentColor.rgb    = temporal;
    #else
        const vec3 temporal = vec3(0.0);
    #endif

    float historyAlpha  = texture(colortex6, uv).a;

    sceneColor          = clamp16F(currentColor);
    temporalData        = clamp16F(vec4(temporal, historyAlpha));
}