#include "/program/shaders1/common/atmosphere/projection.glsl"

// Intersects the ray origin + t * dir with a sphere of the given radius
// centred at the coordinate origin.
//
// The quadratic is solved without a dot(dir, dir) coefficient, so the result
// is only a true distance when dir is unit length.
//
// Returns vec2(t_near, t_far); either value may be negative when the hit lies
// behind the ray origin. Returns vec2(-1.0) when the discriminant is negative,
// i.e. the ray's line never touches the sphere.
vec2 ray_sphere_intersection(vec3 origin, vec3 dir, float radius){
	float half_b = dot(origin, dir);
	float c_term = dot(origin, origin) - radius * radius;
	float discriminant = half_b * half_b - c_term;

	// No real root: report the miss with the (-1, -1) sentinel.
	if (discriminant < 0.0){
		return vec2(-1.0, -1.0);
	}

	float root = sqrt(discriminant);

	return -half_b + vec2(-root, root);
}

// Rayleigh phase function: 3 / (16 * pi) * (1 + x^2).
// x is the value callers pass as the cosine of the scattering angle
// (in this file: dot(view direction, light direction)).
float phase_rayleigh(float x){
    float normalization = 3.0 / (16.0 * pi);
    float angular_term = 1.0 + x * x;

    return normalization * angular_term;
}

// Mie phase function in the form
//   3 / (8 * pi) * (1 - g^2) / (2 + g^2) * (1 + x^2) / (1 + g^2 - 2 * g * x)^1.5
// x: cosine of the scattering angle as supplied by the caller.
// g: anisotropy parameter (callers in this file pass ac.mie_g).
float phase_mie(float x, float g){
    float g_squared = g * g;

    // g-dependent normalisation, independent of the angle.
    float normalization = 3.0 / (8.0 * pi) * (1.0 - g_squared) / (2.0 + g_squared);

    float numerator = normalization * (1.0 + x * x);
    float denominator = pow(1.0 + g_squared - 2.0 * g * x, 1.5);

    return numerator / denominator;
}

// Closed-form approximation of the Chapman grazing-incidence function.
// X:         first scale term (optical_depth_schueler passes radius / H).
// h:         second scale term (optical_depth_schueler passes h / H).
// cosZenith: cosine of the angle between the local vertical and the ray.
//
// Rays at or above the horizon (cosZenith >= 0) use the direct form. Rays
// pointing below the horizon use a second expression built from
// x0 = sqrt(1 - cosZenith^2) * (X + h), minus a term of the same form as the
// upward case with the sign of cosZenith flipped.
float chapman_approximation(float X, float h, float cosZenith){
    float root = sqrt(X + h);
    float root_exp_h = root * exp(-h);

    if (cosZenith >= 0.0){
        return root_exp_h / (root * cosZenith + 1.0);
    }

    float x0 = sqrt(1.0 - cosZenith * cosZenith) * (X + h);
    float root_x0 = sqrt(x0);

    return 2.0 * root_x0 * exp(X - x0) - root_exp_h / (1.0 - root * cosZenith);
}

// Optical depth from the Chapman approximation.
// h:      altitude above the surface.
// H:      scale height; both h and radius are divided by it before the call,
//         and the result is multiplied by it again.
// radius: planet radius.
// cos_z:  cosine of the ray's zenith angle.
float optical_depth_schueler(float h, float H, float radius, float cos_z){
    float scaled_radius = radius / H;
    float scaled_altitude = h / H;

    return H * chapman_approximation(scaled_radius, scaled_altitude, cos_z);
}

// Optical depths seen from `point` (planet-centred coordinates).
// Returns vec2(depth towards the light l, depth along the view ray v * n).
// n flips the view direction; callers in this file pass +1.0 or -1.0.
// Both depths use ac.rayleigh_height as the scale height.
vec2 optical_depth(atmosphere_constant ac, vec3 point, vec3 v, vec3 l, float n){
    float center_distance = length(point);
    float altitude = center_distance - ac.planet_radius;

    // Local vertical at the sample point.
    vec3 up = point / center_distance;

    float cos_sun = dot(up, l);
    float cos_ray = dot(up, v * n);

    return vec2(
        optical_depth_schueler(altitude, ac.rayleigh_height, ac.planet_radius, cos_sun),
        optical_depth_schueler(altitude, ac.rayleigh_height, ac.planet_radius, cos_ray)
    );
}

// Accumulates in-scattered light from one light source along one view ray.
//
// ac:            atmosphere parameters.
// position:      used as the view ray direction (it is passed as `dir` to the
//                sphere tests and scaled by the ray parameter).
// direction:     direction towards the light.
// steps:         number of segments the ray is divided into.
// transmittance: multiplied by the view-ray attenuation of every marched
//                segment; forced to zero when the ray hits the planet.
//
// The ray is marched from its far end back towards the viewer. The returned
// value is not multiplied by any light luminance.
vec3 atmospheric_scattering(atmosphere_constant ac, vec3 position, vec3 direction, const int steps, inout vec3 transmittance){
    const float m = 1.0;

    // The viewer sits on the +Y axis, camera height above the planet surface.
    vec3 eye = vec3(0.0, ac.planet_radius + cameraPosition.y, 0.0);

    vec2 atmosphere_hit = ray_sphere_intersection(eye, position, ac.atmosphere_radius);
    vec2 planet_hit = ray_sphere_intersection(eye, position, ac.planet_radius * 0.99995);

    bool intersection = planet_hit.y >= 0.0;

    // The view direction handed to optical_depth is flipped for rays that hit the planet.
    float ray_sign = intersection ? -m : m;

    // Ray parameter where marching ends (nearest to the viewer).
    float t_begin = max0(atmosphere_hit.x);
    if (intersection && planet_hit.x < 0.0){
        t_begin = planet_hit.y;
    }

    // Ray parameter where marching starts (farthest from the viewer).
    float t_end = atmosphere_hit.y;
    if (intersection && planet_hit.x > 0.0){
        t_end = planet_hit.x;
    }

    vec3 segment_begin = eye + position * t_begin;
    vec3 segment_end = eye + position * t_end;

    vec3 ray_increment = (segment_end - segment_begin) / float(steps);
    float step_length = length(ray_increment);

    vec3 extinction = ac.lambda_rayleigh + ac.lambda_mie + ac.lambda_ozone;

    vec3 sample_point = segment_end - ray_increment;
    vec2 depth_previous = optical_depth(ac, segment_end, position, direction, ray_sign);
    vec3 scattering = vec3(0.0);

    for (int i = 1; i < steps; i++){
        vec2 depth_current = optical_depth(ac, sample_point, position, direction, ray_sign);

        // Stop once the view-ray depth has blown up.
        if (depth_current.y > 1e35) break;

        // Light path uses the summed depths, view path the depth difference across the segment.
        vec3 light_attenuation = exp(-extinction * (depth_current.x + depth_previous.x));
        vec3 view_attenuation = exp(-extinction * (depth_current.y - depth_previous.y));

        // Exponential density falloff with altitude at the sample.
        float density = exp(-(length(sample_point) - ac.planet_radius) / ac.rayleigh_height);

        scattering = scattering * view_attenuation + density * light_attenuation;
        transmittance *= view_attenuation;

        depth_previous = depth_current;
        sample_point -= ray_increment;
    }

    if (intersection){
        transmittance = vec3(0.0);
    }

    float VdotL = dot(position, direction);
    float rayleigh_phase = phase_rayleigh(VdotL);
    float mie_phase = phase_mie(VdotL, ac.mie_g);
    float phase2 = 0.25 * rpi; // only referenced by the disabled term below

    vec3 sun_scattering = scattering * ac.scattering_coefficient[0] * rayleigh_phase * step_length;
         sun_scattering += scattering * ac.scattering_coefficient[1] * mie_phase * step_length;
         //sun_scattering += scattering * ac.scattering_coefficient[0] * phase2 * length(ray_increment);

    // Rays that hit the planet are blended towards their own luminance by `wetness`.
    if (intersection){
        sun_scattering = mix(sun_scattering, vec3(luminance(sun_scattering)), wetness);
    }

    return sun_scattering;
}

// Sky scattering for a view ray, combining sun and moon contributions.
//
// position:       view ray direction, forwarded to the per-light overload.
// sun_direction:  direction towards the sun.
// moon_direction: direction towards the moon.
// steps:          segment count for the view-ray march.
// transmittance:  forwarded as inout to every per-light call.
//
// With ATMOSPHERE_TYPE == 0 the result is the view-ray scattering plus pi
// times the scattering of a straight-up ray marched with ac.steps1 segments.
// Every other ATMOSPHERE_TYPE value returns black.
//
// NOTE(review): all four per-light calls multiply into the same
// `transmittance`, so it accumulates the attenuation of the up ray and of
// the view ray twice each. Confirm this is intended.
vec3 atmospheric_scattering(vec3 position, vec3 sun_direction, vec3 moon_direction, const int steps, inout vec3 transmittance){
    atmosphere_constant ac = atmosphere_s();

    #if ATMOSPHERE_TYPE == 0
        vec3 zenith = vec3(0.0, 1.0, 0.0);

        // Straight-up ray.
        vec3 view_scattering = atmospheric_scattering(ac, zenith, sun_direction, ac.steps1, transmittance) * ac.sun_luminance;
             view_scattering += atmospheric_scattering(ac, zenith, moon_direction, ac.steps1, transmittance) * ac.moon_luminance;

        // Scattering along the requested view ray.
        vec3 sun_scattering = atmospheric_scattering(ac, position, sun_direction, steps, transmittance) * ac.sun_luminance;
             sun_scattering += atmospheric_scattering(ac, position, moon_direction, steps, transmittance) * ac.moon_luminance;

        return sun_scattering + view_scattering * pi;
    #else
        // ATMOSPHERE_TYPE == 1 and any unrecognised value: no atmosphere.
        // The fallback guarantees the function always returns a value.
        return vec3(0.0);
    #endif
}

// Per-channel transmittance from the summed Rayleigh, Mie and ozone
// coefficients and a Chapman optical depth.
// l.y is passed as the altitude and v.y as the zenith cosine.
// NOTE(review): this uses exp2, while the ray march in this file uses exp
// for the same kind of term. Confirm the base is intentional.
vec3 atmospheric_transmittance(atmosphere_constant ac, vec3 v, vec3 l){
    vec3 extinction = ac.lambda_rayleigh + ac.lambda_mie + ac.lambda_ozone;
    float depth = optical_depth_schueler(l.y, ac.rayleigh_height, ac.planet_radius, v.y);

    return exp2(-extinction * depth);
}

// Direct light colour reaching the viewer from the sun and the moon.
//
// OVER_WORLD: each light's luminance is multiplied by the transmittance read
//             from colortex2 at the light's projected direction and by a
//             transition factor derived from the light's Y component. The sun
//             term is additionally tinted by temperature_to_rgb(5778.0).
// END_WORLD:  tinted sun luminance, or black while sunAngle > 0.5.
// otherwise:  black.
vec3 atmospheric_transmittance(vec3 sun_direction, vec3 moon_direction){
    atmosphere_constant ac = atmosphere_s();

    #if defined OVER_WORLD
        vec3 up = vec3(0.0, 1.0, 0.0);

        // Transition factors from each light's elevation; the sun uses the
        // larger multiplier. Presumably saturate clamps to [0, 1]; it is
        // defined outside this file.
        float sun_transition = saturate(dot(up, sun_direction) * 24.0);
        float moon_transition = saturate(dot(up, moon_direction) * 16.0);

        vec3 sun_transmittance = texture2(colortex2, project_sphere(sun_direction)).xyz;
        vec3 moon_transmittance = texture2(colortex2, project_sphere(moon_direction)).xyz;

        // NOTE(review): a colortex1 lookup was previously assigned to each of
        // these and overwritten on the next line. If it was meant to modulate
        // the result, reintroduce it with `*=`.
        vec3 sun_scattering = ac.sun_luminance * sun_transmittance * sun_transition;
             sun_scattering *= temperature_to_rgb(5778.0);

        vec3 moon_scattering = ac.moon_luminance * moon_transmittance * moon_transition;

        return sun_scattering + moon_scattering;
    #elif defined END_WORLD
        return sunAngle > 0.5 ? vec3(0.0) : temperature_to_rgb(5778.0) * ac.sun_luminance;
    #else
        return vec3(0.0);
    #endif
}