//
#include <Shaders/VisionCommon.inc>
#include <Shaders/DeferredShadingHelpers.inc>

#define EPSILON 0.00001f

// Pixel shader input: interpolated ray/shape data from the vertex stage plus a platform-specific
// screen position used to sample the depth buffer (see GetScreenTextureCoords()).
struct PS_IN
{
  float4   ProjPos    : SV_Position;
  
  float4   LocalRay   : TEXCOORD0;   // object-space ray; ps_main negates and normalizes .xyz
  float4   ViewRay    : TEXCOORD1;
  float4   ShapePos   : TEXCOORD2;   // homogeneous shape position; ps_main uses .xyz / .w
  float4   ClipPlanes : TEXCOORD3;   // .x/.y remap the converted depth as d*(y-x)+x (presumably near/far)

  // Screen position; the semantic (and hence its meaning) differs per platform.
  // NOTE(review): on DX11 this shares the SV_Position semantic with ProjPos -- confirm this is intended.
  #if defined(_VISION_PS3) || defined(_VISION_PSP2)
    float2 ScreenPos : WPOS;
  #elif defined(_VISION_DX11)
    float4 ScreenPos : SV_Position;
  #elif defined(_VISION_XENON)
    float4 ScreenPos : TEXCOORD4;
  #else
    float2 ScreenPos : VPOS;
  #endif
};

// Shader constants. The same set is declared once per platform family because register binding
// differs: packed cbuffers on DX11, explicit c-registers on PS3/PSP2/WiiU, and compiler-assigned
// registers elsewhere. Members inside $ifdef blocks only exist in the matching shader permutation.
#ifdef _VISION_DX11
cbuffer g_GlobalConstantBufferObject : register (b1)
{
  float4x4 MatModelView : packoffset(c0);   // takes object-space vectors into camera space (see ps_main)
}
cbuffer g_GlobalConstantBufferUser : register (b2)
{
  float  Density   : packoffset(c0);        // base fog density before falloff/noise scaling
  float3 BaseColor : packoffset(c1);        // fog colour; ColorOffset is added to it in ps_main
  
$ifdef PLANAR_FALLOFF
  float4 PlanarFalloffStart  : packoffset(c2);   // .xyz: offset from the shape position where falloff starts
  float4 PlanarFalloffEnd    : packoffset(c3);   // .xyz: offset where falloff reaches 1
$endif

$ifdef RADIAL_FALLOFF
  float4 RadialFalloffStart  : packoffset(c4);   // .xyz: falloff centre (negated before use in ps_main)
  float2 RadialFalloffRadius : packoffset(c5);   // x: inner radius, y: outer radius (falloff is 1 beyond it)
$endif

$ifdef AXIAL_FALLOFF
  float4 AxialFalloffStart   : packoffset(c6);   // .xyz: axis start point (negated before use in ps_main)
  float4 AxialFalloffEnd     : packoffset(c7);   // .xyz: axis end point (negated before use in ps_main)
  float4 AxialFalloffRadius  : packoffset(c8);   // x/y: radii at start/end, z: mean radius of the inner cone
$endif
    
  float2 InvScreenSize       : packoffset(c9);   // scales screen positions into texture coordinates

  float3 NoiseOffset         : packoffset(c10);
  float3 NoiseScale          : packoffset(c11);
  
  float UseNoise             : packoffset(c12);  // noise modulation is applied when > 0
  float3 ColorOffset         : packoffset(c13);
}

#elif defined(_VISION_PS3) || defined(_VISION_PSP2) || defined(_VISION_WIIU)

  // Same constants bound to fixed c-registers (MatModelView occupies c41-c44).
  float  Density               : register(c32);
  float3 BaseColor             : register(c33);

  $ifdef PLANAR_FALLOFF
    float4 PlanarFalloffStart  : register(c34);
    float4 PlanarFalloffEnd    : register(c35);
  $endif

  $ifdef RADIAL_FALLOFF
    float4 RadialFalloffStart  : register(c36);
    float2 RadialFalloffRadius : register(c37);
  $endif

  $ifdef AXIAL_FALLOFF
    float4 AxialFalloffStart   : register(c38);
    float4 AxialFalloffEnd     : register(c39);
    float4 AxialFalloffRadius  : register(c40);
  $endif
  
  float4x4 MatModelView        : register(c41);
  float2 InvScreenSize         : register(c45);

  // PSP2 has no noise constants; ps_main skips the noise path on that platform.
  #ifndef _VISION_PSP2
    float3 NoiseOffset         : register(c46);
    float3 NoiseScale          : register(c47);
    float UseNoise             : register(c48);
  #endif
  
  float3 ColorOffset           : register(c49);
#else

  // Remaining platforms: let the compiler assign registers.
  float  Density;
  float3 BaseColor;
  
  $ifdef PLANAR_FALLOFF
    float4 PlanarFalloffStart;
    float4 PlanarFalloffEnd;
  $endif

  $ifdef RADIAL_FALLOFF
    float4 RadialFalloffStart;
    float2 RadialFalloffRadius;
  $endif

  $ifdef AXIAL_FALLOFF
    float4 AxialFalloffStart;
    float4 AxialFalloffEnd;
    float4 AxialFalloffRadius;
  $endif
  
  float4x4 MatModelView;
  float2 InvScreenSize;

  float3 NoiseOffset;
  float3 NoiseScale;
  
  float UseNoise;
  float3 ColorOffset;
#endif


// Scene depth buffer (read with READ_CONVERTED_DEPTH in ps_main) and the 3D noise volume used to
// modulate density. DX11 declares separate texture and sampler objects; other platforms use
// combined samplers.
#ifdef _VISION_DX11
  Texture2D DepthTexture        : register(t0);
  sampler   DepthTextureSampler : register(s0);
  Texture3D <float4> NoiseTexture        : register(t1);
  sampler            NoiseTextureSampler : register(s1);
#else
  sampler2D DepthTexture        : register(s0);
  sampler3D NoiseTexture        : register(s1);
#endif

// Result of Quadratic(): up to two real roots in ascending order.
struct QuadraticRoots
{
  float2 vals;  // roots with vals.x <= vals.y; set to (1,-1) when there are no real roots
  int count;    // number of real roots found: 0, 1 (double root, both components equal) or 2
};

// Returns the real roots of a*x^2 + b*x + c, sorted in ascending order (vals.x <= vals.y).
//
//   count == 0: no real roots; vals is set to (1, -1) so that callers treating vals as an
//               [entry, exit] interval see an empty interval (entry > exit).
//   count == 1: (near-)double root (discriminant below EPSILON); both components hold -b / (2a).
//   count == 2: two distinct roots.
//
// The two-root case uses the cancellation-free form q = -(b + sgn(b) sqrt(D)) / 2 with roots
// q / a and c / q. sgn(b) must not be 0 there: HLSL's sign(0) returns 0, which for b == 0 would
// make q == 0 and turn c / q into a division by zero. b == 0 is therefore treated as positive,
// which keeps q != 0 whenever D > 0.
QuadraticRoots Quadratic(float a, float b, float c)
{
  QuadraticRoots ret;
  
  float discr = b*b - 4*a*c;
  if (discr < 0)
  {
    ret.count = 0;    
    ret.vals.xy = float2(1,-1);
  }
  else if (discr < EPSILON)
  {
    ret.count = 1;
    ret.vals.xy = -0.5f * b / a;
  }
  else
  {
    ret.count = 2;
    
    // Never let the sign factor be 0 (see header comment).
    float sgnB = (b >= 0) ? 1.0f : -1.0f;
    float q = -0.5f * (b + sgnB * sqrt(discr));
    float2 roots = float2(q/a, c/q);
    
    ret.vals.x = min(roots.x, roots.y);
    ret.vals.y = max(roots.x, roots.y);
  }
  return ret;
}


// Beer-Lambert law: fraction of light transmitted through a non-scattering medium with
// absorption coefficient `alpha` along a path of length `dist`, i.e. exp(-alpha * dist).
float BeersLaw(float alpha, float dist)
{
  float opticalDepth = alpha * dist;
  return exp(-opticalDepth);
}

// [2/3] Pade approximant of the Beer-Lambert transmissivity exp(-alpha * dist).
// With x = -alpha * dist:
//
//   exp(x) ~= (60 + 24 x + 3 x^2) / (60 - 36 x + 9 x^2 - x^3)
//
// For x <= 0 the numerator and denominator are both positive, so the result stays in (0, 1]
// and tends to 0 as x -> -inf. Accuracy degrades as |x| grows.
float BeersLaw_Pade(float alpha, float dist)
{
  float x = -alpha * dist;
  return (60.0f + x * (24.0f + 3.0f * x)) / (60.0f + x * (-36.0f + x * (9.0f - x)));
}

// Fraction of light absorbed over `distance` in a medium of the given `density`; ps_main uses
// it as the fog opacity. The model is selected per shader permutation:
//   EXPONENTIAL: Beer-Lambert law, 1 - exp(-density * distance).
//   LINEAR:      saturate(1 - e^-1 / (density * distance)), which agrees with the exponential
//                model exactly at distance == 1 / density.
//   otherwise:   no absorption (0).
float Absorption(float density, float distance)
{
$ifdef EXPONENTIAL
  return 1 - BeersLaw(density, distance);
//$elif EXPONENTIAL_APPROX
//  return 1 - BeersLaw_Pade(density, distance);
$elif LINEAR
  // e^-1; dividing by density puts the exponential/linear match point at distance = 1/density.
  float matchScale = 0.36787944117144232159552377016146f / density;
  return saturate(1 - matchScale / distance);
$else
  return 0;
$endif
}


// ----------------------------------------------------------------------------------------------
// Ray-object intersection functions
// ----------------------------------------------------------------------------------------------
//
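// The functions have names of the form TestRay[Primitive]() and TestRay[Primitive]_Fast(). All of them intersect
// the ray x(t) = t * dir, which starts at the origin, with a primitive placed relative to that origin, and return
// the ray parameters t of the entry/exit points.
// 
// The former are general intersection routines.  The latter assume canonical sizes and orientations
// (e.g. unit-radius sphere, box with unit half-extents, z-axis-aligned cylinders and cones, etc).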
// ----------------------------------------------------------------------------------------------

// Result of a ray-primitive test: the ray parameters t of the entry (pts[0]) and exit (pts[1]) points.
struct IntersectionInfo
{
  float2 pts;   // [entry, exit]; several tests leave pts[0] > pts[1] to mark an empty interval
  int count;    // number of intersections reported; ps_main discards the pixel unless count >= 2
};

// ----------------------------------------------------------------------------------------------
// Sphere
// ----------------------------------------------------------------------------------------------

// General ray/sphere test for the ray x(t) = t * dir against a sphere centred at pos with the given
// radius. dir must be unit length: the quadratic's leading coefficient |dir|^2 is taken as 1.
// Returns the sorted roots and their count as produced by Quadratic().
IntersectionInfo TestRaySphere(float3 dir, float3 pos, float radius)
{
  // |t*dir - pos|^2 = radius^2  <=>  t^2 - 2 t (dir . pos) + (|pos|^2 - radius^2) = 0
  float linearTerm   = -2.0f * dot(dir, pos);
  float constantTerm = dot(pos, pos) - radius * radius;
  QuadraticRoots roots = Quadratic(1, linearTerm, constantTerm);

  IntersectionInfo result;
  result.count  = roots.count;
  result.pts.xy = roots.vals;
  return result;
}

// TestRaySphere() specialised to a unit-radius sphere centred at pos (dir must be unit length).
IntersectionInfo TestRaySphere_Fast(float3 dir, float3 pos)
{
  return TestRaySphere(dir, pos, 1.0f);
}


// ----------------------------------------------------------------------------------------------
// Box
// ----------------------------------------------------------------------------------------------

// Branchless, vectorized slab test for the ray x(t) = t * dir against the axis-aligned box centred
// at pos with half-extents halfExt. pts holds the entry/exit parameters; count is 2 when the exit
// lies ahead of the origin (t >= 0) and not before the entry, otherwise 0.
IntersectionInfo TestRayBox(float3 dir, float3 pos, float3 halfExt)
{
  float3 invDir = 1.0f / dir;

  // Parameters at which the ray crosses the lower and upper bounding plane of each slab.
  float3 tLower = (pos - halfExt) * invDir;
  float3 tUpper = (pos + halfExt) * invDir;

  float3 tNear = min(tLower, tUpper);
  float3 tFar  = max(tLower, tUpper);

  float tEnter = max(max(tNear.x, tNear.y), tNear.z);
  float tExit  = min(min(tFar.x, tFar.y), tFar.z);

  IntersectionInfo info;
  info.pts   = float2(tEnter, tExit);
  info.count = (tExit >= 0.f && tExit >= tEnter) ? 2 : 0;
  return info;
}

// Box test specialised to unit half-extents (the box spans pos - 1 .. pos + 1 on every axis).
// Unlike TestRayBox() it always reports two intersections; misses are culled later through the
// empty interval they produce (pts[1] < pts[0]).
IntersectionInfo TestRayBox_Fast(float3 dir, float3 pos)
{
  IntersectionInfo info = TestRayBox(dir, pos, float3(1, 1, 1));
  info.count = 2;
  return info;
}


// ----------------------------------------------------------------------------------------------
// Capsule
// ----------------------------------------------------------------------------------------------
/*
IntersectionInfo TestRayCapsule(float3 dir, float3 A, float3 B, float radius)
{
  IntersectionInfo info;
  float  segmentLength = length(B - A);
  float3 lineDir = (B - A) / segmentLength;
  
  // Infinite cylinder
  // 
  // Given a line with direction u and point p and a ray through the origin x(t) = t v,
  // simply compute the squared distance in terms of t and set it to r^2.
  //
  //   |u (x(t) - p) . u + p - x(t)|^2
  // = |u (t v - p) . u + p - t v|^2
  // = |t (u (v . u) - v) + (p - u (u . p))|^2
  // = t^2 |u (v . u) - v|^2 + 2 t (u (v . u) - v) . (p - u (u . p)) + |p - u (u . p)|^2
  //
  float  cosT = dot(dir,lineDir);
  float3 d1 = lineDir * cosT - dir;
  float3 d2 = A - lineDir * dot(lineDir, A);
    
  QuadraticRoots q = Quadratic(dot(d1,d1), 2 * dot(d1, d2), dot(d2, d2) - radius * radius);
  info.count = q.count;
  info.pts.xy = q.vals;
  
  if (info.count == 0)
  {
    // ray doesn't intersect the infinite cylinder, nothing to do
  }
  else
  {
    // Eliminate roots outside the planar caps
    float2 planeTest = float2(dot(dir * info.pts[0] - A, lineDir), dot(dir * info.pts[1] - A, lineDir));
    
    bool2 check = planeTest >= 0 && planeTest <= segmentLength;
    if (all(check))
    {
      // ray intersects the capsule wall twice
      return info;
    }

    // Check the lower spherical cap
    IntersectionInfo cap1 = TestRaySphere(dir, A, radius);
    info.pts = check ? info.pts : (cap1.count ? cap1.pts : info.pts.yx);
    
    // Check the upper spherical cap
    IntersectionInfo cap2 = TestRaySphere(dir, B, radius);
    if (cap2.count)
    {
      info.pts[0] = min(info.pts[0], cap2.pts[0]);
      info.pts[1] = max(info.pts[1], cap2.pts[1]);
    }
  }
  
  return info;
}
*/

// Widens the interval pts so that it contains both real roots of the monic quadratic
// t^2 + b t + c = 0, and leaves pts unchanged when there are no real roots.
// Uses the cancellation-free roots q and c / q with q = -(b + sgn(b) sqrt(D)) / 2. b == 0 is
// treated as positive: HLSL's sign(0) == 0 would make q == 0 and c / q a division by zero.
// q can then only be 0 for the double root t = 0, which is handled explicitly.
void ExpandIntervalByMonicQuadraticRoots(inout float2 pts, float b, float c)
{
  float discr = b*b - 4*c;
  if (discr >= 0)
  {
    float sgnB = (b >= 0) ? 1.0f : -1.0f;
    float q = -0.5f * (b + sgnB * sqrt(discr));
    float2 roots = float2(q, (q != 0) ? c / q : q);

    pts[0] = min(pts[0], min(roots.x, roots.y));
    pts[1] = max(pts[1], max(roots.x, roots.y));
  }
}

// Intersects the ray x(t) = t * dir with a capsule of radius 0.5 whose axis segment runs from
// pos + (0,0,-0.5) to pos + (0,0,0.5). dir must be unit length: the spherical-cap tests take the
// quadratic's leading coefficient |dir|^2 as 1.
// Returns count == 0 when the ray misses the infinite cylinder around the axis. Otherwise it returns
// the [entry, exit] interval in pts, which may be empty (pts[0] > pts[1]) when the ray passes
// beyond both caps.
IntersectionInfo TestRayCapsule_Fast(float3 dir, float3 pos)
{
  // z of the lower cap centre; the upper cap centre lies 1 unit above it.
  float lowerZ = pos.z - 0.5;

  IntersectionInfo info;

  // Ray vs. the infinite cylinder of radius 0.5: ray-circle intersection in the XY plane.
  // The unscaled coefficients are reused for the spherical caps below.
  float a = dot(dir.xy, dir.xy);
  float b = -2 * dot(dir.xy, pos.xy);
  float c = dot(pos.xy, pos.xy) - 0.25;

  // scale the coefficients in order to avoid precision problems at very small values of a/c
  QuadraticRoots qroots = Quadratic(a*100, b*100, c*100);
  info.count = qroots.count;
  info.pts.xy = qroots.vals;

  if (info.count == 0)
  {
    // Both caps lie inside the infinite cylinder, so missing it means missing the capsule.
    return info;
  }

  // Keep only wall hits that lie between the cap planes z = pos.z -/+ 0.5.
  float2 planeTest = dir.z * info.pts - pos.z;
  bool2 check = abs(planeTest) <= 0.5;
  if (all(check))
  {
    // ray intersects the capsule wall twice
    return info;
  }

  // Replace each rejected wall hit with the other one (if both are rejected they simply swap);
  // the cap intersections below then widen the interval.
  info.pts = check ? info.pts : info.pts.yx;

  // Lower spherical cap, centred at (pos.xy, lowerZ). With |dir| = 1 its quadratic is
  // t^2 + b' t + c' = 0, where b' and c' differ from the cylinder's b and c only by z terms.
  b -= 2 * dir.z * lowerZ;
  c += lowerZ * lowerZ;
  ExpandIntervalByMonicQuadraticRoots(info.pts, b, c);

  // Upper spherical cap, centred 1 unit higher at z = lowerZ + 1:
  //   b'' = b' - 2 dir.z,   c'' = c' + (lowerZ + 1)^2 - lowerZ^2 = c' + 2 pos.z
  b -= dir.z * 2;
  c += pos.z * 2;
  ExpandIntervalByMonicQuadraticRoots(info.pts, b, c);

  return info;
}

// ----------------------------------------------------------------------------------------------
// Cylinder
// ----------------------------------------------------------------------------------------------

// General ray/cylinder test for the ray x(t) = t * dir against the solid cylinder of the given
// radius whose axis runs from A to B. pts is the [entry, exit] interval clipped to the two cap
// planes; count is forced to 0 when that interval is empty or lies entirely behind the origin.
IntersectionInfo TestRayCylinder(float3 dir, float3 A, float3 B, float radius)
{
  float3 axis = B - A;
  float3 u    = axis / length(axis);   // unit axis direction

  // Infinite cylinder: for a unit axis u through p = A, the squared distance of x(t) = t v from
  // the axis is
  //   |t (u (v . u) - v) + (p - u (u . p))|^2
  // = t^2 |u (v . u) - v|^2 + 2 t (u (v . u) - v) . (p - u (u . p)) + |p - u (u . p)|^2,
  // which is set equal to radius^2.
  float  vDotU    = dot(dir, u);
  float3 dirPerp  = u * vDotU - dir;
  float3 basePerp = A - u * dot(u, A);

  QuadraticRoots wall = Quadratic(dot(dirPerp, dirPerp),
                                  2 * dot(dirPerp, basePerp),
                                  dot(basePerp, basePerp) - radius * radius);

  // Cap planes through A and B, perpendicular to the axis.
  float2 capT = float2(dot(A, u), dot(B, u)) / vDotU;

  IntersectionInfo info;
  info.count = wall.count;
  info.pts.x = max(wall.vals.x, min(capT.x, capT.y));
  info.pts.y = min(wall.vals.y, max(capT.x, capT.y));

  // Reject intervals behind the ray origin or emptied by the cap clipping.
  if (info.pts.y < 0.f || info.pts.x > info.pts.y)
  {
    info.count = 0;
  }
  return info;
}


// Intersects the ray x(t) = t * dir with a cylinder of radius 1 whose axis is parallel to z,
// centred at pos and spanning pos.z - 1 .. pos.z + 1. The wall hits are clipped against the two
// cap planes. count comes straight from the wall test and is not reset on a cap-clipping miss;
// such misses show up as pts[0] > pts[1].
IntersectionInfo TestRayCylinder_Fast(float3 dir, float3 pos)
{
  // Wall: line-circle intersection in the XY plane. The coefficients are scaled by 100 to avoid
  // precision problems when |dir.xy|^2 or |pos.xy|^2 is very small.
  float dirLenSq  = dot(dir.xy, dir.xy);
  float posLenSq  = dot(pos.xy, pos.xy);
  float crossTerm = -2 * dot(dir.xy, pos.xy);
  QuadraticRoots wall = Quadratic(dirLenSq * 100, crossTerm * 100, (posLenSq - 1) * 100);

  // Caps: the planes z = pos.z - 1 and z = pos.z + 1.
  float2 capT = float2(pos.z - 1, pos.z + 1) / dir.z;

  IntersectionInfo info;
  info.count = wall.count;
  info.pts.x = max(wall.vals.x, min(capT.x, capT.y));
  info.pts.y = min(wall.vals.y, max(capT.x, capT.y));
  return info;
}

// ----------------------------------------------------------------------------------------------
// Cone
// ----------------------------------------------------------------------------------------------

// General ray/cone test for the ray x(t) = t * dir. The cone's apex is at A, and it widens linearly
// to `radius` at B. Roots on the reflected half of the double cone (behind the apex) are
// discarded, and the interval is clipped against the planes through A and B perpendicular to the axis.
// NOTE(review): when Quadratic() finds no roots, the reflected-cone test below still runs on the
// (1,-1) sentinel values -- confirm callers treat count/pts accordingly.
IntersectionInfo TestRayCone(float3 dir, float3 A, float3 B, float radius)
{
  IntersectionInfo info;
  float  segmentLength = length(B - A);
  float3 lineDir = (B - A) / segmentLength;
  
  // Cone
  // 
  // Given a line L with direction u and point p and a ray through the origin f(t) = t v,
  // the squared distance between f(t) and L can be given as
  //
  //   |(f(t) - p) x u|^2
  // = |t v x u + u x p|^2
  // = t^2 |v x u|^2 + 2 t (v x u) . (u x p) + |u x p|^2
  //
  // Let r(t) the squared radius for a given t:
  //
  //   r^2(t)
  // = r^2 ((f(t) - p) . u / |L|)^2
  // = r^2 / |L|^2 ((t v - p) . u)^2
  // = r^2 / |L|^2 (t^2 (u . v)^2 - 2 t (u . v) (u . p) + (u . p)^2)
  //
  // Letting the squared distance equal the squared radius yields the quadratic
  // 
  //   a t^2 + b t + c = 0,
  //
  // where
  //
  //   a = |v x u|^2 - r^2 / |L|^2 (u . v)^2
  //   b = 2 [(v x u) . (u x p) + r^2 / |L|^2 (u . v) (u . p)]
  //   c = |u x p|^2 - r^2 / |L|^2 (u . p)^2
  //
  float3 n = cross(dir, lineDir);
  float3 up = cross(lineDir, A);
  
  float  cosT = dot(dir,lineDir);
  float  da   = dot(lineDir, A);
  
  // r^2 / |L|^2 and its products with (u . v)^2 and (u . p)^2
  float rsq = radius * radius / (segmentLength * segmentLength);
  float rsq_uv = rsq * cosT * cosT;
  float rsq_up = rsq * da * da;
  
  QuadraticRoots q = Quadratic(dot(n,n) - rsq_uv, 2 * (dot(n, up) + rsq * cosT * da), dot(up,up) - rsq_up);
  info.count = q.count;
  info.pts.xy = q.vals;
  
  // Check the planar caps (ds = sorted ray parameters of the planes through A and B)
  float2 caps = float2(dot(A,lineDir), dot(B,lineDir)) / cosT;
  float2 ds = float2(min(caps.x, caps.y), max(caps.x,caps.y));  
  
  // Eliminate roots in the reflected cone (negative axial coordinate relative to the apex A)
  float2 coneTest = float2(dot(dir * info.pts[0] - A, lineDir), dot(dir * info.pts[1] - A, lineDir));
  
  if (coneTest.x < 0)
  {
    if (coneTest.y < 0)
    {
      // both solutions are in the reflected cone -- ignore
      info.count = 0;
    }
    else
    {
      // near point in the reflected cone -- shift forward, replace far solution with end cap
      info.pts[0] = info.pts[1];
      info.pts[1] = ds[1];
    }
  }
  else
  {
    if (coneTest.y < 0)
    {
      // far point in the reflected cone -- shift backward, replace near solution with end cap
      info.pts[1] = info.pts[0];
      info.pts[0] = ds[0];
    }
    else
    {
      // both solutions in the non-reflected cone
      info.pts[0] = max(info.pts[0], ds[0]);
      info.pts[1] = min(info.pts[1], ds[1]);
    }
  }

  
  return info;
}

// Cone test specialised to a z-aligned cone of height 2 centred at pos: the apex is at
// pos + (0,0,-1) and the radius grows as half the height above the apex, reaching 1 at pos + (0,0,1).
// Roots in the reflected cone (below the apex) are shifted/replaced by cap parameters. count is not
// reset when both roots are rejected; the interval is then left with pts[0] >= pts[1] for the
// caller's depth test to reject.
IntersectionInfo TestRayCone_Fast(float3 dir, float3 pos)
{
  IntersectionInfo info;
  
  // Compute the intersection between a 2d ray and a circle with a time-varying radius
  // r(t) = (t dir.z - (pos.z - 1)) / 2. See TestRayCone() for details.
  float a = dot(dir.xy,dir.xy) - dir.z * dir.z / 4;
  float b = -2 * dot(dir.xy, pos.xy) + dir.z * (pos.z - 1) / 2;
  float c = dot(pos.xy,pos.xy) - (pos.z - 1) * (pos.z - 1) / 4;
  
  // scale the coefficients in order to avoid precision problems at very small values of dsq/psq
  a *= 100;
  b *= 100;
  c *= 100;
  QuadraticRoots q = Quadratic(a, b, c);
  
  info.count = q.count;
  info.pts.xy = q.vals;
  
  // Check the planar caps (planes z = pos.z - 1 and z = pos.z + 1), sorted along the ray
  float2 caps = float2(pos.z - 1, pos.z + 1) / dir.z;
  float2 ds = float2(min(caps.x, caps.y), max(caps.x,caps.y));  
  
  // Eliminate roots in the reflected cone (height above the apex z = pos.z - 1 must be >= 0)
  float2 coneTest = dir.z * info.pts - pos.z + 1;
  
  bool2 check = (coneTest >= 0);
  
  if (all(check))
  {
    // both solutions in the non-reflected cone
    info.pts[0] = max(info.pts[0], ds[0]);
    info.pts[1] = min(info.pts[1], ds[1]);  
  }
  else
  {
    // shift/replace invalid solutions: a rejected root is replaced by the other one
    // (both swap if both are rejected)
    info.pts = check ? info.pts.xy : info.pts.yx;
    
    if (check[0])
    {
      // only the far root was rejected -- the near cap becomes the entry point
      info.pts[0] = ds[0];
    }
    else
    {
      // near root rejected: if the far root is valid, the far cap becomes the exit point;
      // if both were rejected the swapped interval is left empty
      info.pts[1] = check[1] ? ds[1] : info.pts[1];
    }
  }

  return info;
}



// Some notes on the math in what follows.
//
// 1. Uniform density
// -------------------
// 
// Given a light ray in a medium, let f(t) be the ratio of outgoing power I_r to incoming power I, after traveling a path of
// length t:
// 
//   I_r = I f(t).
// 
// Since a path of length zero should result in no power loss, we can assume that f(0)=1. Suppose f(t) is known for some t, and
// consider f(t+dt) for some small dt.
// 
// The idea is that the photon is absorbed with probability p per unit of path length, i.e. with probability ~p dt over a short step dt. So, we have
// 
//   f(0) = 1
//   f(t+dt) ~= (1 - p dt) f(t).
// 
// Rearranging,
// 
//   [f(t+dt) - f(t)] / dt ~= -p f(t),
// 
// Letting dt go to zero, the left-hand side is just the derivative of f, so we have the following ODE:
// 
//   f'(t) = -p f(t).
// 
// With the initial condition f(0)=1, the solution is
// 
//   f(t) = e^{-p t},
// 
// which is the Beer-Lambert law for transmissivity in uniform media.
// 
// 
// 2. Computing density
// ---------------------
//
// We now prove that the law holds for non-uniform media, as long as p is the average density over a given path. Indeed, suppose
// the density is an arbitrary integrable function q(t), so the differential equation is now
// 
//   g'(t) = -q(t) g(t),
// 
// and the solution, again using the condition g(0)=1, is
// 
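//   g(t) = e^{-int_0^t q(s) ds}.
// 
// Requiring f = g, i.e. that the uniform law with density p yields the same transmissivity, gives
// 
//   f(t) = g(t)
//   e^{-p t} = e^{-int_0^t q(s) ds}
//   p t = int_0^t q(s) ds
//   p = 1/t int_0^t q(s) ds,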
//
// where the right-hand side is just the average density over a path of length t.
//
//
// 3. Combining density functions
// -------------------------------
// 
// Suppose we have two separate environmental factors that affect the density of the medium along a path, and the photon is absorbed
// with probability p1 due to the first one, and p2 due to the second.
// 
// The probability that it's absorbed due to *either one* is then
// 
//   p_any = 1 - (1 - p1) (1 - p2)
//         = 1 - (1 - p1 - p2 + p1 p2)
//         = p1 + p2 - p1 p2.
// 
// Plugging this into the equation for p yields
// 
//   p = 1/t int p1(t) + p2(t) - p1(t) p2(t) dt
//     = 1/t (int p1(t) dt + int p2(t) dt) + R,
// 
// where
// 
//   R = -1/t int p1(t) p2(t) dt.
// 
// Given upper bounds p1_max and p2_max for p1 and p2 respectively, the upper bound on R over the interval [t1,t2] is simply
// p1_max * p2_max, which is generally negligible. We therefore ignore all integrals of products of density functions, and for
// any number of factors p_k, write
// 
//   p_any = p1 + p2 + ... + pn.
//  
// Note that while we're treating densities as additive, the fact that e^{a+b} = e^a e^b means that the resulting transmissivity
// combines multiplicatively.
// 
// 
// 4. Falloff
// -----------
// 
// We define a falloff function r(t) as the fraction by which the density is reduced at time t, so that with a uniform
// density c the local density is c (1 - r(t)):
// 
//   f'(t) = -c (1 - r(t)) f(t)
//   f(t) = e^{-c int_0^t (1 - r(s)) ds},
// 
// If we combine falloff in the same way as we combine density, then given a uniform density c and multiple falloff functions 
// r1, r2, ..., rn, the combined density can be written as
// 
//   p = c * (1 - 1/t (int r1(t) dt + int r2(t) dt + ... + int rn(t) dt)).


// Definite integral of a t^2 + b t + c over [t.x, t.y], evaluated with the antiderivative
//   F(t) = a t^3 / 3 + b t^2 / 2 + c t = t (t (2 a t + 3 b) + 6 c) / 6   (Horner form).
float IntegrateQuadratic(float a, float b, float c, float2 t)
{
  float2 antiderivative = t * (t * (t * 2 * a + 3 * b) + 6 * c) / 6;
  return antiderivative.y - antiderivative.x;
}

// Average planar falloff along the segment t = [t.x, t.y] of the ray x(t) = t * dir.
// The falloff slab lies between the plane through p1 (falloff 0) and the parallel plane through p2
// (falloff 1), both perpendicular to p2 - p1. Inside the slab the falloff grows quadratically as
// ((x . n - d0) / (d1 - d0))^2, and it is 1 beyond the p2 plane. Returns the average over t:
// 0 if the whole segment lies before the p1 plane, 1 if it lies entirely beyond the p2 plane.
float ComputePlanarFalloff(float3 dir, float3 p1, float3 p2, float2 t)
{
  float2 s;
  
  // compute the planes that define the falloff slab
  float3 n = normalize(p2 - p1);
  float2 d = float2(dot(p1,n), dot(p2,n));

  // discard the segment if it's outside the slab
  float proj = dot(dir, n);
  if (all(t*proj <= d.xx))
  {
    return 0;
  }
  else if (all(t*proj >= d.yy))
  {
    return 1;
  }
  
  // intersect the ray with the slab (rays nearly parallel to the planes keep the whole segment)
  float2 ext = (abs(proj) > 0.0001f) ? d / proj : t;
  
  // clip the traced segment to the ray-slab intersection.
  // the previous two tests ensure that the result is nonempty.
  s[0] = max(t[0], min(ext[0], ext[1]));
  s[1] = min(t[1], max(ext[0], ext[1]));

  // compute falloff along the clipped segment: integral of ((t proj - d0) / (d1 - d0))^2
  float denom = (d[1]-d[0])*(d[1]-d[0]);
  float falloff = IntegrateQuadratic(proj*proj, -2*d[0]*proj, d[0]*d[0], s) / denom;
  
  // length of the traced segment lying beyond the p2 plane, where the falloff is 1.
  // When the ray runs along +n that part follows the slab; otherwise it precedes it.
  float back;
  if (ext[0] < ext[1])
  {
    back = max(0.0f, t[1] - s[1]);
  }
  else
  {
    back = max(0.0f, s[0] - t[0]);
  }
  return (falloff + back) / (t[1] - t[0]);
}

// Average radial falloff along the segment t = [t.x, t.y] of the ray x(t) = t * dir. dir must be
// unit length, because the quadratics below use a leading coefficient of 1.
// The falloff is 1 outside the outer sphere (centre pos, radius r.y). Inside it the falloff follows
//   (|x - pos|^2 - r.x^2) / (r.y^2 - r.x^2),
// i.e. 0 on the inner sphere of radius r.x and 1 on the outer one. Returns the average over the
// segment, saturated to [0,1].
float ComputeRadialFalloff(float3 dir, float3 pos, float2 r, float2 t)
{
  float proj = dot(dir, pos);
  float psq  = dot(pos, pos);
  float2 rsq = r * r;

  // Intersect the ray with the outer sphere.
  QuadraticRoots outer = Quadratic(1, -2.0f * proj, psq - rsq[1]);
  if (outer.count < 1)
  {
    // Ray doesn't intersect outer sphere -- full falloff.
    return 1;
  }
  outer.vals = max(float2(0.0f, 0.0f), outer.vals);

  // Part of the traced segment inside the outer sphere.
  float segStart = max(t[0], outer.vals[0]);
  float segEnd   = min(t[1], outer.vals[1]);
  if (segEnd < segStart)
  {
    // The intersection of the sphere, ray and primitive is empty -- full falloff.
    return 1;
  }

  // The inner sphere is deliberately not clipped out: the integrand below goes *negative* inside
  // it, which works better visually than the correct version. Doing the right thing (i.e.
  // assuming that falloff is exactly zero inside the inner sphere) causes too much falloff.
  float falloff = IntegrateQuadratic(1, -2.0f * proj, psq - rsq[0], float2(segStart, segEnd)) / (rsq[1] - rsq[0])
                  + (segStart - t[0]) + (t[1] - segEnd);  // falloff is 1 outside the outer sphere
  return saturate(falloff / (t[1] - t[0]));
}

// Post-processes cone roots against the reflected half of a double cone and the cap planes.
// coneTest holds each root's axial coordinate relative to the apex; a negative value means the
// root lies in the reflected cone. ds holds the sorted cap-plane parameters.
//   both roots reflected -> count = 0
//   near root reflected  -> interval becomes [far root, far cap]
//   far root reflected   -> interval becomes [near cap, near root]
// The result is then clamped to [ds.x, ds.y].
QuadraticRoots ConeHelper(float2 coneTest, QuadraticRoots roots, float2 ds)
{
  bool nearReflected = coneTest.x < 0;
  bool farReflected  = coneTest.y < 0;

  if (nearReflected && farReflected)
  {
    roots.count = 0;
  }
  else if (nearReflected)
  {
    roots.vals = float2(roots.vals[1], ds[1]);
  }
  else if (farReflected)
  {
    roots.vals = float2(ds[0], roots.vals[0]);
  }

  roots.vals[0] = max(roots.vals[0], ds[0]);
  roots.vals[1] = min(roots.vals[1], ds[1]);
  return roots;
}

// Average axial (truncated-cone) falloff along the segment t = [t.x, t.y] of the ray x(t) = t * dir.
// The falloff volume is a cone section along the axis from A (radius r.x) to B (radius r.y); it is
// 1 outside that outer cone. r.z is the mean radius of an inner cone with the same apex: both end
// radii are scaled by 2 r.z / (r.x + r.y). Inside the outer cone the integrand is the inner cone's
// quadratic (squared axis distance minus squared inner radius) divided by
// r.y^2 - (scaled r.x)^2. Returns the average over the segment, saturated to [0,1].
float ComputeAxialFalloff(float3 dir, float3 A, float3 B, float3 r, float2 t)
{
  // rearrange the vertices so that the smaller radius is at A
  if (r[0] > r[1])
  {
    float3 tmp = A; A = B; B = tmp;
    r.xy = r.yx;
  }

  // compute the vertex (potentially infinite -- that's okay)
  // NOTE(review): for r.x == r.y this divides by zero; components of B - A that are 0 then give
  // 0 * inf = NaN, which makes the coneTest comparisons below false ("not reflected").
  float3 vtx = A - r[0] / (r[1] - r[0]) * (B - A);  
    
  // compute the intersection of the traced segment with an infinite double cone (see TestRayCone for details)
  float  rsegmentLength = 1 / length(B - A);
  float3 lineDir = (B - A) * rsegmentLength;
  
  float3 n = cross(dir, lineDir);
  float3 up = cross(lineDir, A);
  
  float  cosT = dot(dir,lineDir);
  float  da   = dot(lineDir, A);
  
  // radius growth per unit of axis length; radius at axial offset h from A is r.x + dr * h
  float  dr   = (r[1] - r[0]) * rsegmentLength;
  
  float drsq = dr * dr;
  float drsq_uv = drsq * cosT * cosT;
  float drsq_up = drsq * da * da;

  float4 s; // cone intersections
  
  
  // Planar caps
  float2 caps = float2(dot(A,lineDir), dot(B,lineDir)) / cosT;
  float2 ds = float2(min(caps.x, caps.y), max(caps.x,caps.y));  
  
  // Outer cone: |dist to axis|^2 - (r.x + dr ((x - A) . u))^2 = a t^2 + b t + c
  float a = dot(n,n) - drsq_uv;
  float b = 2 * (dot(n, up) + drsq * cosT * da - r[0] * dr * cosT);
  float c = dot(up,up) - drsq_up + 2 * r[0] * dr * da - r[0]*r[0];
  QuadraticRoots outer = Quadratic(a, b, c);
    
  // Eliminate roots in the reflected cone
  float2 coneTest = float2(dot(dir * outer.vals[0] - vtx, lineDir), dot(dir * outer.vals[1] - vtx, lineDir));
  outer = ConeHelper(coneTest, outer, ds);

  if (outer.count < 1)
  {
    // Intersection outside the outer cone -- full falloff
    return 1;
  }
  
  outer.vals = max(float2(0.0f, 0.0f), outer.vals);

  s[0] = max(t[0], outer.vals[0]);
  s[3] = min(t[1], outer.vals[1]);

  if (s[3] < s[0])
  {
    // The intersection of the cone, ray and primitive is empty -- full falloff.
    return 1;
  }
  
  // Inner cone: same apex, radii scaled so that their mean equals r.z
  float innerRatio = 2 * r[2] / (r[0] + r[1]);
  float ir0 = r[0] * innerRatio;
  
  dr *= innerRatio;
  drsq = dr * dr;
  drsq_uv = drsq * cosT * cosT;
  drsq_up = drsq * da * da;  

  a = dot(n,n) - drsq_uv;
  b = 2 * (dot(n, up) + drsq * cosT * da - ir0 * dr * cosT);
  c = dot(up,up) - drsq_up + 2 * ir0 * dr * da - ir0*ir0;
  
  QuadraticRoots inner = Quadratic(a, b, c);
  
  // Eliminate roots in the reflected cone
  coneTest = float2(dot(dir * inner.vals[0] - vtx, lineDir), dot(dir * inner.vals[1] - vtx, lineDir));
  inner = ConeHelper(coneTest, inner, ds);
  
  // NOTE(review): s[1] and s[2] are not read below; only a, b, c of the inner cone are used.
  if (inner.count < 1)
  {
    s[1] = s[2] = s[0];
  }
  else
  {
    s[1] = max(s[0], inner.vals[0]);
    s[2] = min(s[3], min(t[1], inner.vals[1]));
  }

  // compute falloff along the clipped segment
  float falloff;
 
  // this results in *negative* falloff inside the inner cone, but works better visually than the correct version.
  // doing the right thing (i.e. assuming that falloff is exactly zero inside the inner cone) causes too much falloff.
  falloff = IntegrateQuadratic(a,b,c,s.xw) / (r[1]*r[1] - (r[0]*innerRatio)*(r[0]*innerRatio))
            + (s[0]-t[0]) + (t[1]-s[3]);  // falloff is 1 outside the outer cone
  return saturate(falloff / (t[1] - t[0]));
}

// Converts the pixel's screen position into texture coordinates for sampling the depth buffer.
// PS_IN::ScreenPos uses a different semantic per platform, hence the per-platform conversions.
float2 GetScreenTextureCoords(PS_IN In)
{
  float2 uv;
  #ifdef _VISION_DX11
    // ScreenPos.xy (SV_Position) is scaled directly by the reciprocal screen size
    uv = In.ScreenPos.xy * InvScreenSize.xy;
  #elif defined(_VISION_XENON)
    // ScreenPos is homogeneous: divide by w, remap [-1,1] -> [0,1], then shift by half a texel
    uv = (In.ScreenPos.xy / In.ScreenPos.w) * 0.5 + 0.5;    
    uv += 0.5f * InvScreenSize.xy;
  #elif defined(_VISION_PS3)
    // scale WPOS and flip v
    uv = In.ScreenPos.xy * InvScreenSize.xy;
    uv.y = 1.0f - uv.y;
  #elif defined(_VISION_PSP2)
    uv = In.ScreenPos.xy * InvScreenSize.xy;
  #else
    // VPOS: scale and shift by half a texel (presumably the D3D9 pixel-centre offset)
    uv = In.ScreenPos.xy * InvScreenSize.xy + 0.5f * InvScreenSize.xy;
  #endif

  return uv;
}

// Fog-volume pixel shader. It proceeds in four steps:
//   1. Intersect the pixel's ray with the fog shape selected by the TYPE_* permutation.
//   2. Clip the resulting interval against the scene depth buffer.
//   3. Scale Density by the optional falloff and noise terms.
//   4. Return BaseColor + ColorOffset, with the absorbed fraction as alpha.
float4 ps_main( PS_IN In ) : SV_Target
{
  // Normalized object-space ray direction (LocalRay negated)
  float  rRayLength = 1.0f / length(In.LocalRay.xyz);
  float3 ray = -In.LocalRay.xyz * rRayLength;
  
  // Position and depth extents in camera space.
  // DepthScale is the camera-space length of one object-space unit along the ray.
  float3 viewRay = mul(MatModelView, float4(ray,0)).xyz;
  float DepthScale = length(viewRay);
  
  // Position in offset object space  
  float3 pos = In.ShapePos.xyz / In.ShapePos.w;
  IntersectionInfo info;

$ifdef TYPE_SPHERE
  info = TestRaySphere_Fast(ray, pos);
$elif  TYPE_BOX
  info = TestRayBox_Fast(ray, pos);
$elif  TYPE_CYLINDER
  info = TestRayCylinder_Fast(ray, pos);
$elif  TYPE_CONE
  info = TestRayCone_Fast(ray, pos);
$elif  TYPE_CAPSULE
  info = TestRayCapsule_Fast(ray, pos);
$else
  info.pts.xy = 0; info.count = 0;
$endif

  // discard invalid collisions (pixels whose test reported fewer than two intersections)
  clip(info.count - 1.99);

  // clip the near plane to 0 (ignore the part of the shape behind the ray origin)
  info.pts[0] = max(0.0f, info.pts[0]);
  
  // scale to world space
  info.pts.xy *= DepthScale;
  
  // clip the far plane to the depth buffer
  float2 scrUV = GetScreenTextureCoords(In);

  float scrDepth = READ_CONVERTED_DEPTH(DepthTexture, DepthTextureSampler, scrUV);
  // remap with ClipPlanes.x/.y (presumably near/far) to a camera-space z
  scrDepth = scrDepth * (In.ClipPlanes.y - In.ClipPlanes.x) + In.ClipPlanes.x;
  
  // Turn the scene z into a distance along the ray by scaling viewRay until its z equals scrDepth.
  // Since length(viewRay * k) == |k| * DepthScale, this equals DepthScale * abs(scrDepth / viewRay.z).
  scrDepth = length(viewRay * scrDepth / viewRay.z);
  
  info.pts[1] = min(info.pts[1], scrDepth);
  
  // we now have the actual depth interval inside the fog shape; discard the pixel if it is empty
  float depth = info.pts[1] - info.pts[0];
  clip(depth);
  
  // Scale the z-clipped extents back into object space
  float2 extents = info.pts / DepthScale;
  
  float averageDensity = Density;
  float falloffDensity = 1;
  
  // Each enabled falloff scales the density by (1 - average falloff over the interval).
  // Planar falloff is evaluated in object space (ray, extents).
  // Radial and axial falloff are evaluated in camera space (unit direction -viewRay / DepthScale,
  // interval info.pts).
$ifdef PLANAR_FALLOFF
  // NOTE(review): pos (float3) + PlanarFalloffStart (float4) relies on implicit truncation to float3.
  falloffDensity *= 1 - saturate(ComputePlanarFalloff(ray, pos + PlanarFalloffStart, pos + PlanarFalloffEnd, extents));
$endif

$ifdef RADIAL_FALLOFF
  {
    float4 p = mul(MatModelView, float4(-RadialFalloffStart.xyz,1));
    falloffDensity *= 1 - saturate(ComputeRadialFalloff(-viewRay / DepthScale, p, RadialFalloffRadius.xy, info.pts));
  }  
$endif

$ifdef AXIAL_FALLOFF
  {
    float4 p1 = mul(MatModelView, float4(-AxialFalloffStart.xyz,1));
    float4 p2 = mul(MatModelView, float4(-AxialFalloffEnd.xyz,1));
    falloffDensity *= 1 - saturate(ComputeAxialFalloff(-viewRay / DepthScale, p1, p2, AxialFalloffRadius.xyz, info.pts));
  }  
$endif
  
  averageDensity *= falloffDensity;
  
#ifndef _VISION_PSP2
  // Optional density modulation by a 3D noise volume sampled at the pixel's screen UV
  if (UseNoise.x > 0.0f)
  {
    float3 noiseUV = ( float3( scrUV, 0 ) + NoiseOffset ) * NoiseScale;
    float noise = vTex3D( NoiseTexture, NoiseTextureSampler, noiseUV ).r;
	#ifdef _VISION_DX11
		// DX11 only: second sample at 8x frequency, reducing the noise by up to 20% for noise2 in [0,1]
		float3 noiseUV2 = ( float3( scrUV, 0 ) + NoiseOffset * float3( 0.75f, 1.33f, 0.83f ) ) * NoiseScale * 8;
		float noise2 = vTex3D( NoiseTexture, NoiseTextureSampler, noiseUV2 ).r;
		noise *= ( 1 - noise2 * 0.2f );
	#endif
    averageDensity*= noise;
  }
#endif
  
  // Compute the absorbed fraction (1 - transmissivity), used as alpha.
  // NORMALIZE_DISTANCE measures the interval in object-space units instead of camera-space units.
$ifndef NORMALIZE_DISTANCE
  float Opacity = saturate(Absorption(averageDensity, max(0.0f, depth)));
$else
  float Opacity = Absorption(averageDensity, depth / DepthScale);
$endif

  return float4(BaseColor + ColorOffset, Opacity);
}
