#version 460

#extension GL_EXT_shader_16bit_storage : enable
#extension GL_EXT_shader_explicit_arithmetic_types : enable
#extension GL_EXT_debug_printf : enable
//#define HAS_16BIT_TYPES

#include <shaders/materials/commons.glsl>
#include <shaders/commons_hlsl.glsl>
#include <shaders/geometry_partitioning/raytrace_buffers.glsl>
#include <shaders/geometry_partitioning/raytrace_commons.glsl>

#include <shaders/geometry_partitioning/voxel_texture_build_mips_support.glsl>

// Per-draw transform parameters. The code in this file reads mModel, mModelView,
// mModelViewInvTrans, mProjection and vCameraPosition from it. The
// EntityTransformParams type itself is not declared in this file.
layout(std140, row_major) uniform TransformParamsBuffer{
	EntityTransformParams transform_params;
};

#ifdef SPIRV_VULKAN
#if !defined(PARTITION_VOXELIZE) && !defined(PARTITION_RAYTRACE)

// Vertex outputs for the regular path, i.e. when neither PARTITION_VOXELIZE
// nor PARTITION_RAYTRACE is defined.
layout(location = 1) out Vertex
{
	vec3    vCoords;                  // position after the mModelView transform (written only in proxy/raytrace passes)
	vec3    vNorm;                    // normal transformed by mModelViewInvTrans, normalized
	vec3    vWorldNorm;               // normal transformed by mModel (not normalized in main())
	vec3    vLocalPos;                // untransformed vertex position
	vec3    vCameraRelativeWorldPos;  // position transformed by mModel, minus vCameraPosition
	f16vec4 vColor;                   // written only in the fullscreen-quad and voxel-grid visualization variants
	f16vec2 vUV0;                     // written only in the fullscreen-quad and voxel-grid visualization variants
} vtx_output;

// main() always writes 0 here.
layout(location = 0) out uint instanceID;

#else

// Reduced vertex outputs used when PARTITION_VOXELIZE or PARTITION_RAYTRACE is defined.
layout(location = 1) out Vertex
{
	vec3    vLocalPos;   // untransformed vertex position
#ifdef PARTITION_VOXELIZE
	// Color and UV are only carried through when voxelizing.
	f16vec4 vColor;
	f16vec2 vUV0;
#endif
	uint    vIdx;        // vertex index after rt_mask_vtx_idx()
} vtx_output;

// main() always writes 0 here.
layout(location = 0) out uint instanceID;

#endif
#endif

#ifdef VISUALIZE_RAYTRACE_GRID

// Fetched per grid cell in get_vertex(): a red channel > 0 makes the cell's cube visible.
uniform sampler3D s_grid_marker;
// Not referenced by the code visible in this file.
uniform sampler3D s_voxel_colors;

// Returns the position of triangle-strip vertex `idx` of the cube that
// visualizes one raytrace-grid cell. The cell is selected by SYS_InstanceIndex
// (one instance per cell, drawn as a 14-vertex triangle strip).
vec3 get_vertex(uint idx)
{
	uint cell = SYS_InstanceIndex;

	// Unflatten the linear cell index into grid coordinates (x varies fastest).
	vec3 voxel_p = vec3(
		float(cell % GRID_RES),
		float((cell / GRID_RES) % GRID_RES),
		float((cell / (GRID_RES * GRID_RES)) % GRID_RES));

	// Bit `idx` of each per-axis mask is that axis' 0/1 cube-corner coordinate.
	vec3 corner = vec3((ivec3(0x287a, 0x2af, 0x31e3) >> idx) & 1);

	// Cells whose marker is not positive collapse to zero size; marked cells
	// are drawn at 90% of the cell size.
	float marker = texelFetch(s_grid_marker, ivec3(voxel_p), 0).r;
	float scale = 0.0;
	if (marker > 0)
	{
		scale = 0.9;
	}

	vec3 cell_size = in_bbox_data.grid_size_raytrace.xyz;
	return vec3(voxel_p) * in_bbox_data.grid_size_raytrace.xyz + in_bbox_data.bbox_raytrace_min.xyz + corner * cell_size * scale;
}

// Returns a fixed placeholder normal for every vertex; `idx` is ignored.
// The real normal is expected to be rebuilt in the pixel shader.
vec3 get_vertex_normal(uint idx)
{
	const vec3 placeholder_normal = vec3(1.0, 1.0, 0.0);
	return placeholder_normal;
}

// Grid visualization has no texture coordinates: always returns (0, 0); `idx` is ignored.
vec2 get_vertex_uv0(uint idx)
{
	return vec2(0.0, 0.0);
}

// Returns the same constant color (0.1 in every channel, alpha included) for every vertex; `idx` is ignored.
vec4 get_vertex_color(uint idx)
{
	return vec4(0.1, 0.1, 0.1, 0.1);
}

#elif defined(VISUALIZE_VOXELIZATION_GRID)

// Parameters for the voxelization-grid visualization mode.
struct VisualizeVoxelizationGridParams
{
	int   lod;               // mip level sampled from the voxel textures; cubes are scaled by (1 << lod)
	int   grid_res_for_lod;  // cells per axis, used to unflatten the instance index into x/y/z
	uint  _pad1;             // unused here; presumably pads the struct to 16 bytes - TODO confirm against the CPU-side layout
	uint  _pad2;             // unused here; see _pad1
};

// Uniform block holding the visualization parameters read by get_vertex() and get_vertex_color().
layout(std140, row_major) uniform VisualizeVoxelizationGridParamsBuffer {
	VisualizeVoxelizationGridParams visualize_voxelization_grid;
};


// Not referenced by the code visible in this file.
uniform sampler3D s_voxel_colors;
// Source of the per-voxel color in get_vertex_color().
uniform sampler3D s_voxel_colors_filtered;
// Red channel is used directly as the cube scale in get_vertex().
uniform sampler3D s_voxel_occupancy;
// Not referenced by the code visible in this file.
uniform sampler3D s_voxel_occupancy_filtered;

// Returns the position of triangle-strip vertex `idx` of the cube that
// visualizes one voxel of the selected voxelization LOD. The voxel is selected
// by SYS_InstanceIndex (one instance per voxel, drawn as a 14-vertex triangle strip).
vec3 get_vertex(uint idx)
{
	uint cell = SYS_InstanceIndex;
	int  res  = visualize_voxelization_grid.grid_res_for_lod;
	int  lod  = visualize_voxelization_grid.lod;

	// Unflatten the linear voxel index into grid coordinates (x varies fastest).
	vec3 voxel_p;
	voxel_p.x = float(cell % res);
	voxel_p.y = float((cell / res) % res);
	voxel_p.z = float((cell / (res * res)) % res);

	// Bit `idx` of each per-axis mask is that axis' 0/1 cube-corner coordinate.
	vec3 corner = vec3((ivec3(0x287a, 0x2af, 0x31e3) >> idx) & 1);

	// The occupancy value is used directly as the cube scale.
	float scale = texelFetch(s_voxel_occupancy, ivec3(voxel_p), lod).r;

	// One voxel at this LOD spans (1 << lod) base-level cells along each axis.
	float lod_scale = float(1 << lod);

	vec3 cell_size = in_bbox_data.grid_size_voxelize.xyz;
	vec3 origin    = vec3(voxel_p * lod_scale) * in_bbox_data.grid_size_voxelize.xyz + in_bbox_data.bbox_voxelize_min.xyz;
	vec3 extent    = corner * cell_size * scale * lod_scale;
	return origin + extent;
}

// Returns a fixed placeholder normal for every vertex; `idx` is ignored.
// The real normal is expected to be rebuilt in the pixel shader.
vec3 get_vertex_normal(uint idx)
{
	const vec3 placeholder_normal = vec3(1.0, 1.0, 0.0);
	return placeholder_normal;
}

// Grid visualization has no texture coordinates: always returns (0, 0); `idx` is ignored.
vec2 get_vertex_uv0(uint idx)
{
	return vec2(0.0, 0.0);
}


// Returns the display color of the voxel selected by SYS_InstanceIndex at the
// configured LOD; `idx` is ignored. The result is the filtered voxel color with
// alpha 1.0, then 0.1 added to all four channels.
vec4 get_vertex_color(uint idx)
{
	int  res  = visualize_voxelization_grid.grid_res_for_lod;
	int  lod  = visualize_voxelization_grid.lod;
	uint cell = SYS_InstanceIndex;

	// Unflatten the linear voxel index into grid coordinates (x varies fastest).
	vec3 voxel_p = vec3(
		float(cell % res),
		float((cell / res) % res),
		float((cell / (res * res)) % res));

#ifdef VX_USE_RGB9E5
	// Color is fetched as-is from the texel of the selected mip.
	vec3 rgb = texelFetch(s_voxel_colors_filtered, ivec3(voxel_p), lod).rgb;
#else
	// Sample at the voxel center of the selected mip and convert with color_convert_rgbm_rgb().
	vec3 rgb = color_convert_rgbm_rgb(textureLod(s_voxel_colors_filtered, (voxel_p + 0.5) / float(res), float(lod)));
#endif

	return vec4(rgb, 1.0) + 0.1;
}

#else

// Default mode: forwards to rt_get_vertex() to obtain the position for vertex `idx`.
vec3 get_vertex(uint idx)
{
	vec3 position = rt_get_vertex(idx);
	return position;
}

// Default mode: forwards to rt_get_vertex_normal() to obtain the normal for vertex `idx`.
vec3 get_vertex_normal(uint idx)
{
	vec3 normal = rt_get_vertex_normal(idx);
	return normal;
}

// Default mode: forwards to rt_get_vertex_uv0() to obtain the first UV set for vertex `idx`.
vec2 get_vertex_uv0(uint idx)
{
	vec2 uv0 = rt_get_vertex_uv0(idx);
	return uv0;
}

// Default mode: forwards to rt_get_vertex_color() to obtain the color for vertex `idx`.
vec4 get_vertex_color(uint idx)
{
	vec4 color = rt_get_vertex_color(idx);
	return color;
}

#endif

// Vertex shader entry point. Two compile-time variants:
//
//  * PROXY_PASS_STENCIL_PRIMID && RAYTRACE_PASS: emits a quad covering clip
//    space, selected by SYS_VertexIndex 0..3, with a fixed normal of (0, 0, -1).
//    The position is written to gl_Position without any transform.
//
//  * otherwise: fetches position/normal through get_vertex()/get_vertex_normal()
//    and transforms them with transform_params. When PARTITION_VOXELIZE or
//    PARTITION_RAYTRACE is defined, the model-view transform is skipped and
//    only the reduced output set is written.
//
// instanceID is always written as 0.
void main() {

	// TODO: Do we actually need to do anything here? Don't think so, this is not called for instanced draws
	instanceID = 0;

#if defined(PROXY_PASS_STENCIL_PRIMID) && defined(RAYTRACE_PASS)

	// Any vertex index outside the handled range stays at the origin.
	vec3 pos = vec3(0.0);
#if 1
	// Quad corners in clip space.
	if (SYS_VertexIndex == 0)
	{
		pos = vec3(-1.0, 1.0, 0.0);
	}
	if (SYS_VertexIndex == 1)
	{
		pos = vec3(1.0, 1.0, 0.0);
	}
	if (SYS_VertexIndex == 2)
	{
		pos = vec3(1.0, -1.0, 0.0);
	}
	if (SYS_VertexIndex == 3)
	{
		pos = vec3(-1.0, -1.0, 0.0);
	}
#else
	// emit single fullscreen triangle
	if (SYS_VertexIndex == 0)
	{
		pos = vec3(-3.0, 1.0, 0.0);
	}
	if (SYS_VertexIndex == 1)
	{
		pos = vec3(1.0, 1.0, 0.0);
	}
	if (SYS_VertexIndex == 2)
	{
		pos = vec3(1.0, -3.0, 0.0);
	}
#endif
	vec3 norm = vec3(0.0, 0.0, -1.0);

	vtx_output.vLocalPos = pos;

	vec4 vPos1 = vec4(pos, 1.0);
	vtx_output.vCameraRelativeWorldPos  = vPos1.xyz - transform_params.vCameraPosition;
	vtx_output.vWorldNorm = norm;

	// Transform the normal with the first three rows of mModelViewInvTrans, then renormalize.
	vtx_output.vNorm.x = dot(transform_params.mModelViewInvTrans[0].xyz, norm);
	vtx_output.vNorm.y = dot(transform_params.mModelViewInvTrans[1].xyz, norm);
	vtx_output.vNorm.z = dot(transform_params.mModelViewInvTrans[2].xyz, norm);
	vtx_output.vNorm   = normalize(vtx_output.vNorm);

	// The quad is already in clip space: no projection is applied.
	vtx_output.vCoords = vPos1.xyz;
	gl_Position = vPos1;
	vtx_output.vColor = f16vec4(1.0);
	vtx_output.vUV0   = f16vec2(0.0);

#else

	uint vtx_idx = rt_mask_vtx_idx(SYS_VertexIndex);
	vec3 pos = get_vertex(vtx_idx);
	vec3 norm = get_vertex_normal(vtx_idx);

	vtx_output.vLocalPos = pos;

#if defined(PARTITION_VOXELIZE) || defined(PARTITION_RAYTRACE)
	vtx_output.vIdx = vtx_idx;
#endif

#if defined(PARTITION_VOXELIZE) || defined(PARTITION_RAYTRACE)
	// Partition passes use the untransformed position; only the projection below is applied.
	vec3 vPos = pos.xyz;
#else
	vec3 vPos = vector_transform_by_mat43(pos, transform_params.mModelView);
#endif

#if !defined(PARTITION_VOXELIZE) && !defined(PARTITION_RAYTRACE)

	vtx_output.vCameraRelativeWorldPos = vector_transform_by_mat43(pos, transform_params.mModel) - transform_params.vCameraPosition;
	vtx_output.vWorldNorm = (transform_params.mModel * vec4(norm, 0.0)).xyz;

	// Transform the normal with the first three rows of mModelViewInvTrans, then renormalize.
	vtx_output.vNorm.x = dot(transform_params.mModelViewInvTrans[0].xyz, norm);
	vtx_output.vNorm.y = dot(transform_params.mModelViewInvTrans[1].xyz, norm);
	vtx_output.vNorm.z = dot(transform_params.mModelViewInvTrans[2].xyz, norm);
	vtx_output.vNorm   = normalize(vtx_output.vNorm);
#endif

#if defined(PROXY_PASS) || defined(RAYTRACE_PASS)
	// in proxy pass we can do culling for parts which are not raytraced but only reflected
	//int face_material = rt_get_triangle_material(gl_VertexID / 3);
	//if ((materials.material_properties[face_material].flags & MaterialFlag_Raytrace) == 0)
	//	vPos = vec3(0.0);

	vtx_output.vCoords = vPos.xyz;
#endif

	gl_Position       = transform_params.mProjection * vec4(vPos, 1.0);
#ifdef PARTITION_VOXELIZE
	vtx_output.vColor = f16vec4(get_vertex_color(vtx_idx));
	vtx_output.vUV0   = f16vec2(get_vertex_uv0(vtx_idx));
#endif

#endif

#ifdef VISUALIZE_VOXELIZATION_GRID
	// NOTE: uses vtx_idx, which is only declared in the generic branch above.
	// Explicit f16 constructors are required: a float -> float16 conversion is
	// narrowing and is not performed implicitly when f16vec2 is a real 16-bit type.
	vtx_output.vColor = f16vec4(get_vertex_color(vtx_idx));
	vtx_output.vUV0   = f16vec2(0.0);
#endif
}

