import { buildEnclosureShapes } from "../geometry/buildEnclosureShapes";
import { buildIntersectionShapes } from "../geometry/buildIntersectionShapes";
import {
  findShapeEnclosure,
  findShapeIntersections,
} from "../geometry/intersection";
import { Vector2D } from "../math/vector2D";
import {
  EnclosureType,
  type CollisionInfo,
  type CollisionPair,
  type IntersectionPoint,
} from "../models";
import { CollisionShapeBuilder } from "../shapes/builders/collisionShapeBuilder";
import { intersectSegmentTrees } from "../shapes/intersectSegmentTrees";
import type { Body } from "./body";

/**
 * Narrows a list of potential collisions down to confirmed ones.
 *
 * Stages, in order:
 *  1. `findSegmentPairs` – collects candidate segment pairs per collision pair.
 *  2. `intersectSegmentPairs` – splits the pairs into intersecting ones and
 *     ones that may still be an enclosure.
 *  3. `findEnclosure` – keeps the enclosure candidates that really enclose.
 *  4. Collision info is built first for every intersection collision, then
 *     for every enclosure collision, and appended to the returned array.
 *
 * Each stage is bracketed by `startMetric(name)` / `endMetric(name)`. As there
 * is no try/finally, `endMetric` is not called when a stage throws.
 *
 * @param potentialCollisions Candidate pairs; passed to the stage functions,
 *   which annotate the entries in place.
 * @param startMetric Called with the stage name right before a stage runs.
 * @param endMetric Called with the stage name right after a stage returns.
 * @param recordNumericMetric Receives the per-stage counters.
 * @returns Collision info for every confirmed collision.
 */
export function getActualCollisions(
  potentialCollisions: CollisionPair[],
  startMetric: (name: string) => void,
  endMetric: (name: string) => void,
  recordNumericMetric: (name: string, value: number) => void
): CollisionInfo[] {
  // Runs one stage between its start/end metric calls and hands back its result.
  const timed = <T>(metricName: string, stage: () => T): T => {
    startMetric(metricName);
    const outcome = stage();
    endMetric(metricName);
    return outcome;
  };

  const pairStats = timed("findSegmentPairs", () =>
    findSegmentPairs(potentialCollisions)
  );
  recordNumericMetric("numSegmentPairs", pairStats.segmentPairCount);
  recordNumericMetric(
    "numEnclosureImpossibleCollisions",
    pairStats.enclosureImpossibleCount
  );

  const narrowed = timed("intersectSegments", () =>
    intersectSegmentPairs(potentialCollisions)
  );
  recordNumericMetric(
    "numIntersectionCollisions",
    narrowed.intersectionCollisions.length
  );
  recordNumericMetric(
    "numPossibleEnclosureCollisions",
    narrowed.possibleEnclosureCollisions.length
  );
  // The segment-check counters returned by intersectSegmentPairs
  // (unfilteredSegmentCheckCount, segmentCheckCount, segmentIntersectionCount)
  // are intentionally not recorded at the moment.

  const enclosed = timed("checkEnclosure", () =>
    findEnclosure(narrowed.possibleEnclosureCollisions)
  );

  const finalCollisions: CollisionInfo[] = [];

  timed("calculateIntersectionCollisionInfo", () => {
    for (const pair of narrowed.intersectionCollisions) {
      calculateIntersectionCollisionInfo(pair, finalCollisions);
    }
  });

  timed("calculateEnclosureCollisionInfo", () => {
    for (const pair of enclosed) {
      calculateEnclosureCollisionInfo(pair, finalCollisions);
    }
  });

  return finalCollisions;
}

/**
 * Registers every collision with both of its bodies and applies the impulse
 * response to the ones both bodies agree to resolve.
 *
 * For each collision:
 *  - `addCollision` is called on body1, then body2; whenever it returns a
 *    truthy value that body's `callCollisionTriggers` is queued in `triggers`.
 *  - The collision is skipped unless both bodies' `shouldResolveCollision`
 *    return truthy (body2 is not asked when body1 declines).
 *  - `resolveImpulseCollision` is run; the collision is reported as resolved
 *    only when it returns true. Both bodies are added to
 *    `bodiesWithPositionCorrections` regardless of that outcome.
 *
 * @param collisions Confirmed collisions to process, in order.
 * @param bodiesWithPositionCorrections Output set, mutated in place.
 * @param triggers Output list of deferred trigger callbacks, mutated in place.
 * @returns The collisions for which `resolveImpulseCollision` returned true.
 */
export function resolveCollisions(
  collisions: CollisionInfo[],
  bodiesWithPositionCorrections: Set<Body>,
  triggers: (() => void)[]
): CollisionInfo[] {
  const resolved: CollisionInfo[] = [];

  for (const contact of collisions) {
    const { body1, body2 } = contact;

    // NOTE(review): the callbacks are queued as plain property references
    // (no bind/wrapper) — presumably callCollisionTriggers does not rely on
    // `this`, or is pre-bound; confirm against Body.
    if (body1.addCollision(contact)) {
      triggers.push(body1.callCollisionTriggers);
    }
    if (body2.addCollision(contact)) {
      triggers.push(body2.callCollisionTriggers);
    }

    const bothWantResolution =
      body1.shouldResolveCollision(contact) &&
      body2.shouldResolveCollision(contact);
    if (!bothWantResolution) {
      continue;
    }

    const impulseApplied = resolveImpulseCollision(contact);
    if (impulseApplied) {
      resolved.push(contact);
    }

    bodiesWithPositionCorrections.add(body1).add(body2);
  }

  return resolved;
}

/**
 * Runs `intersectSegmentTrees` on the two shapes of every potential collision
 * and annotates the pair in place:
 *  - a `null` result marks the pair as `EnclosureType.Impossible`;
 *  - a non-empty result is stored as the pair's `segmentPairs`;
 *  - an empty result leaves the pair untouched.
 *
 * @param potentialCollisions Pairs to examine; entries are mutated.
 * @returns Total number of segment pairs found, and the number of pairs
 *   marked enclosure-impossible.
 */
function findSegmentPairs(potentialCollisions: CollisionPair[]): {
  segmentPairCount: number;
  enclosureImpossibleCount: number;
} {
  const totals = { segmentPairCount: 0, enclosureImpossibleCount: 0 };

  for (const candidate of potentialCollisions) {
    const hostTree = candidate.body1.shape.getSegmentTree();
    const guestTree = candidate.body2.shape.getSegmentTree();
    const overlaps = intersectSegmentTrees(
      hostTree,
      guestTree,
      candidate.relativeTransformInfo.guestToHost,
      candidate.relativeTransformInfo.hostToGuest
    );

    if (overlaps === null) {
      candidate.enclosureType = EnclosureType.Impossible;
      totals.enclosureImpossibleCount += 1;
    } else if (overlaps.length > 0) {
      totals.segmentPairCount += overlaps.length;
      candidate.segmentPairs = overlaps;
    }
  }

  return totals;
}

/**
 * Splits potential collisions into those whose outlines actually intersect
 * and those that may still be an enclosure (one shape inside the other).
 *
 * Per pair:
 *  - No candidate segment pairs: the pair is an enclosure candidate unless it
 *    was already marked `EnclosureType.Impossible`, in which case it is dropped.
 *  - Otherwise `findShapeIntersections` is run on the candidate segment pairs.
 *    If it reports at least one intersection point, the points are stored on
 *    the pair (`intersectionPoints`) and the pair is an intersection
 *    collision; if it reports none, the pair is an enclosure candidate.
 *  - If `findShapeIntersections` throws, the pair is skipped entirely
 *    (best-effort: one bad pair must not abort the whole pass).
 *
 * @param potentialCollisions Pairs already annotated with `segmentPairs` /
 *   `enclosureType`; intersecting entries are mutated.
 * @returns The two result lists plus counters:
 *   `unfilteredSegmentCheckCount` – product of both shapes' segment counts,
 *   summed over every pair that had candidate segment pairs;
 *   `segmentCheckCount` / `segmentIntersectionCount` – values reported by
 *   `findShapeIntersections`, summed over every pair it completed for.
 */
function intersectSegmentPairs(potentialCollisions: CollisionPair[]): {
  intersectionCollisions: CollisionPair[];
  possibleEnclosureCollisions: CollisionPair[];
  unfilteredSegmentCheckCount: number;
  segmentCheckCount: number;
  segmentIntersectionCount: number;
} {
  let unfilteredSegmentCheckCount = 0;
  let segmentCheckCount = 0;
  let segmentIntersectionCount = 0;

  const intersectionCollisions: CollisionPair[] = [];
  const possibleEnclosureCollisions: CollisionPair[] = [];

  for (const collision of potentialCollisions) {
    if (collision.segmentPairs.length === 0) {
      if (collision.enclosureType !== EnclosureType.Impossible) {
        possibleEnclosureCollisions.push(collision);
      }
      continue;
    }

    unfilteredSegmentCheckCount +=
      collision.body1.shape.segments.length *
      collision.body2.shape.segments.length;

    let result: {
      intersections: [IntersectionPoint[], IntersectionPoint[]] | null;
      segmentCheckCount: number;
      segmentIntersectionCount: number;
    } | null = null;

    try {
      result = findShapeIntersections(
        collision.segmentPairs,
        collision.relativeTransformInfo.guestToHost
      );
    } catch {
      // Deliberate best-effort: a pair whose intersection test fails is
      // dropped from this pass rather than aborting collision detection.
      continue;
    }

    // Count the work that was done whether or not anything intersected;
    // otherwise the counters undercount the checks actually performed.
    segmentCheckCount += result.segmentCheckCount;
    segmentIntersectionCount += result.segmentIntersectionCount;

    // `intersections` is a 2-tuple, so its own length is always 2 — look at
    // the per-shape point lists to decide whether anything was found.
    const intersections = result.intersections;
    const hasIntersections =
      intersections !== null &&
      intersections.some((points) => points.length > 0);

    if (hasIntersections) {
      collision.intersectionPoints = intersections;
      intersectionCollisions.push(collision);
    } else {
      possibleEnclosureCollisions.push(collision);
    }
  }

  return {
    intersectionCollisions,
    possibleEnclosureCollisions,
    unfilteredSegmentCheckCount,
    segmentCheckCount,
    segmentIntersectionCount,
  };
}

/**
 * Keeps the candidate pairs for which `findShapeEnclosure` reports anything
 * other than `EnclosureType.None`, recording the reported type on each kept
 * pair (`enclosureType`). Rejected pairs are left unmodified.
 *
 * @param enclosurePossible Pairs that did not intersect but may be enclosed.
 * @returns A new array with the enclosing pairs, in input order.
 */
function findEnclosure(enclosurePossible: CollisionPair[]): CollisionPair[] {
  return enclosurePossible.filter((candidate) => {
    const detected = findShapeEnclosure(
      candidate.body1.shape,
      candidate.body2.shape,
      candidate.body1Transform,
      candidate.body2Transform
    );

    if (detected === EnclosureType.None) {
      return false;
    }

    candidate.enclosureType = detected;
    return true;
  });
}

/**
 * Builds the collision info for a pair whose outlines intersect and appends
 * it to `finalCollisions`.
 *
 * A `CollisionShapeBuilder` is fed by `buildIntersectionShapes` (with
 * `union: false`); when that call reports failure nothing is appended.
 *
 * @param collision Pair whose `intersectionPoints` have already been set.
 * @param finalCollisions Output list, mutated in place.
 */
function calculateIntersectionCollisionInfo(
  collision: CollisionPair,
  finalCollisions: CollisionInfo[]
): void {
  const { body1, body2 } = collision;
  const shapeBuilder = new CollisionShapeBuilder(
    collision.relativeTransformInfo
  );

  const built = buildIntersectionShapes({
    segments: [body1.shape.segments, body2.shape.segments],
    // Set by the intersection stage before this function is reached.
    intersections: collision.intersectionPoints!,
    shapeBuilder,
    union: false,
  });

  if (!built) {
    return;
  }

  const { centroid, normal, penetrationDistance } =
    shapeBuilder.getFinalResult();

  finalCollisions.push({
    body1,
    body2,
    pointOfContact: centroid,
    normal,
    penetrationDistance,
  });
}

/**
 * Builds the collision info for a pair where one shape encloses the other and
 * appends it to `finalCollisions`.
 *
 * Both bodies' velocities (read as `[vx, vy, angularVelocity]`) are forwarded
 * to `buildEnclosureShapes` together with the shapes, transforms and the
 * pair's enclosure type. Nothing is appended when that call yields a falsy
 * result.
 *
 * @param collision Pair with its `enclosureType` already determined.
 * @param finalCollisions Output list, mutated in place.
 */
function calculateEnclosureCollisionInfo(
  collision: CollisionPair,
  finalCollisions: CollisionInfo[]
): void {
  const { body1, body2 } = collision;
  const [vx1, vy1, w1] = body1.getVelocity();
  const [vx2, vy2, w2] = body2.getVelocity();

  const enclosure = buildEnclosureShapes({
    shapes: [body1.shape, body2.shape],
    transforms: [collision.body1Transform, collision.body2Transform],
    motionInfo: {
      shape1: { vx: vx1, vy: vy1, angularVelocity: w1 },
      shape2: { vx: vx2, vy: vy2, angularVelocity: w2 },
    },
    union: false,
    enclosureType: collision.enclosureType,
    buildShapes: false,
  });

  if (!enclosure) {
    return;
  }

  finalCollisions.push({
    body1,
    body2,
    pointOfContact: enclosure.totalCentroid,
    normal: enclosure.normal,
    penetrationDistance: enclosure.penetrationDistance,
  });
}

/**
 * Applies an impulse response (normal + friction) and a position correction
 * for a single collision.
 *
 * Returns `false` without touching either body when:
 *  - the normal or the point of contact is not valid,
 *  - the penetration distance is non-finite or not positive,
 *  - both bodies have zero inverse mass,
 *  - the relative velocity at the contact is not valid, or
 *  - (only while the bodies approach) an impulse denominator is exactly zero.
 *
 * Otherwise returns `true`. Impulses are applied only when the relative
 * velocity along the normal is negative; in that case `collision.intensity`
 * is also filled in. The position correction is requested whenever the
 * penetration exceeds the slop, independent of the approach direction.
 *
 * @param collision Contact to resolve; `intensity` may be written to it.
 * @returns Whether the collision was processed.
 */
function resolveImpulseCollision(collision: CollisionInfo): boolean {
  const BETA = 0.9; // Share of the excess penetration that gets corrected
  const SLOP = 1; // Penetration depth tolerated without correction

  const { body1, body2, normal, pointOfContact, penetrationDistance } =
    collision;

  // --- Reject unusable contacts -------------------------------------------
  if (!normal.isValid() || !pointOfContact.isValid()) {
    return false;
  }
  if (!isFinite(penetrationDistance) || penetrationDistance <= 0) {
    return false;
  }

  const state1 = body1.getStateForCollision(pointOfContact);
  const state2 = body2.getStateForCollision(pointOfContact);

  // Two immovable bodies: nothing can be changed.
  if (state1.invMass === 0 && state2.invMass === 0) {
    return false;
  }

  const closingVelocity = state1.v.subtract(state2.v);
  if (!closingVelocity.isValid()) {
    return false;
  }

  const inverseMassSum = state1.invMass + state2.invMass;
  const velocityAlongNormal = closingVelocity.dot(normal);

  // --- Impulse response, only while the bodies approach each other ---------
  if (velocityAlongNormal < 0) {
    // z-component of the 2D cross product a × b
    const cross = (
      a: { x: number; y: number },
      b: { x: number; y: number }
    ): number => a.x * b.y - a.y * b.x;

    // Inverse mass sum plus each body's squared lever arm times its inverse inertia
    const impulseDenominator = (arm1: number, arm2: number): number =>
      inverseMassSum +
      arm1 * arm1 * state1.invInertia +
      arm2 * arm2 * state2.invInertia;

    const tangent = new Vector2D(-normal.y, normal.x);

    const normalDenominator = impulseDenominator(
      cross(state1.r, normal),
      cross(state2.r, normal)
    );
    if (normalDenominator === 0) {
      return false; // would divide by zero
    }

    const tangentDenominator = impulseDenominator(
      cross(state1.r, tangent),
      cross(state2.r, tangent)
    );
    if (tangentDenominator === 0) {
      return false; // would divide by zero
    }

    // Least bouncy material wins; friction uses the geometric mean.
    const restitution = Math.min(
      body1.material.restitution,
      body2.material.restitution
    );
    const staticFriction = Math.sqrt(
      body1.material.staticFriction * body2.material.staticFriction
    );
    const dynamicFriction = Math.sqrt(
      body1.material.dynamicFriction * body2.material.dynamicFriction
    );

    const normalMagnitude =
      (-(1 + restitution) * velocityAlongNormal) / normalDenominator;

    const tangentVelocity = closingVelocity.dot(tangent);
    const uncappedFriction = -tangentVelocity / tangentDenominator;
    // Beyond the static limit the friction impulse drops to the dynamic value.
    const frictionMagnitude =
      Math.abs(uncappedFriction) > normalMagnitude * staticFriction
        ? Math.sign(uncappedFriction) * normalMagnitude * dynamicFriction
        : uncappedFriction;

    const totalImpulse = normal
      .scale(normalMagnitude)
      .add(tangent.scale(frictionMagnitude));

    // Intensity figures exposed on the collision for consumers.
    const impulseStrength = totalImpulse.magnitude();
    const impactSpeed = Math.abs(velocityAlongNormal);
    collision.intensity = {
      impulseStrength,
      impactSpeed,
      collisionEnergy: 0.5 * impulseStrength * impactSpeed,
    };

    // Equal and opposite impulses.
    body1.applyImpulse(totalImpulse, state1.r);
    body2.applyImpulse(totalImpulse.scale(-1), state2.r);
  }

  // --- Position correction --------------------------------------------------
  if (penetrationDistance > SLOP) {
    const correctedDepth = (penetrationDistance - SLOP) * BETA;
    // NOTE(review): the same magnitude is handed to both bodies; presumably
    // addPositionCorrection applies any per-body mass weighting — confirm.
    const correction = normal.scale(correctedDepth / (inverseMassSum || 1));

    if (state1.invMass > 0) {
      body1.addPositionCorrection(correction);
    }
    if (state2.invMass > 0) {
      body2.addPositionCorrection(correction.scale(-1));
    }
  }

  return true;
}
