import { buildEnclosureShapes } from "../geometry/buildEnclosureShapes";
import { buildIntersectionShapes } from "../geometry/buildIntersectionShapes";
import {
  findShapeEnclosure,
  findShapeIntersections,
} from "../geometry/intersection";
import { Vector2D } from "../math/vector2D";
import {
  EnclosureType,
  type CollisionInfo,
  type CollisionPair,
  type IntersectionPoint,
} from "../models";
import { CollisionShapeBuilder } from "../shapes/builders/collisionShapeBuilder";
import { intersectSegmentTrees } from "../shapes/intersectSegmentTrees";
import type { Body } from "./body";

/**
 * Narrows a list of candidate pairs down to confirmed collisions and builds
 * the contact data for each of them.
 *
 * Pipeline, in order:
 * 1. `findSegmentPairs` attaches candidate segment pairs to each pair.
 * 2. `intersectSegmentPairs` splits the pairs into those with intersection
 *    points and those that may still be an enclosure.
 * 3. `findEnclosure` keeps the latter pairs whose enclosure type is not `None`.
 * 4. Contact info is built for both groups; intersection results are appended
 *    first, enclosure results second.
 *
 * @param potentialCollisions Candidate pairs. The helper steps write
 *   `segmentPairs`, `intersectionPoints` and `enclosureType` onto these
 *   objects in place.
 * @returns Contact info for every pair for which contact data could be built.
 */
export function getActualCollisions(
  potentialCollisions: CollisionPair[]
): CollisionInfo[] {
  // Called for its side effect only: it stores `segmentPairs` on the pairs,
  // which `intersectSegmentPairs` reads. The returned count is a diagnostic
  // that is not needed here.
  findSegmentPairs(potentialCollisions);

  // The helper also returns diagnostic counters; only the two pair lists are
  // consumed by this function.
  const { intersectionCollisions, possibleEnclosureCollisions } =
    intersectSegmentPairs(potentialCollisions);

  const enclosureCollisions = findEnclosure(possibleEnclosureCollisions);

  const finalCollisions: CollisionInfo[] = [];

  for (const collision of intersectionCollisions) {
    calculateIntersectionCollisionInfo(collision, finalCollisions);
  }

  for (const collision of enclosureCollisions) {
    calculateEnclosureCollisionInfo(collision, finalCollisions);
  }

  return finalCollisions;
}

/**
 * Registers every collision on both of its bodies, collects the trigger
 * callbacks that need to run, and resolves the collision when both bodies
 * agree to it.
 *
 * @param collisions Collisions to process, in order.
 * @param bodiesWithPositionCorrections Output set. Receives both bodies of
 *   every collision that was handed to `resolveImpulseCollision`.
 * @param triggers Output list. Receives `callCollisionTriggers` of each body
 *   whose `addCollision` call returned a truthy value.
 */
export function resolveCollisions(
  collisions: CollisionInfo[],
  bodiesWithPositionCorrections: Set<Body>,
  triggers: (() => void)[]
): void {
  for (const collision of collisions) {
    const { body1, body2 } = collision;

    // NOTE(review): the callback is stored as a bare member reference. That is
    // only safe if `Body` declares `callCollisionTriggers` as a bound/arrow
    // property — confirm against the Body class.
    if (body1.addCollision(collision)) {
      triggers.push(body1.callCollisionTriggers);
    }
    if (body2.addCollision(collision)) {
      triggers.push(body2.callCollisionTriggers);
    }

    // Short-circuits exactly like the negated form: body2 is not asked when
    // body1 already declined.
    const bothAcceptResolution =
      body1.shouldResolveCollision(collision) &&
      body2.shouldResolveCollision(collision);

    if (bothAcceptResolution) {
      resolveImpulseCollision(collision);

      bodiesWithPositionCorrections.add(body1);
      bodiesWithPositionCorrections.add(body2);
    }
  }
}

/**
 * Runs `intersectSegmentTrees` on the segment trees of both bodies of every
 * pair and stores any non-empty result on the pair as `segmentPairs`.
 *
 * Pairs that produce no candidates are left untouched; their `segmentPairs`
 * field is not reset here.
 *
 * @param potentialCollisions Pairs to examine; mutated in place.
 * @returns Total number of candidate segment pairs found across all pairs.
 */
function findSegmentPairs(potentialCollisions: CollisionPair[]): number {
  let totalCandidates = 0;

  for (const pair of potentialCollisions) {
    const candidates = intersectSegmentTrees(
      pair.body1.shape.getSegmentTree(),
      pair.body2.shape.getSegmentTree(),
      pair.relativeTransformInfo.guestToHost,
      pair.relativeTransformInfo.hostToGuest
    );

    if (candidates.length === 0) {
      continue;
    }

    totalCandidates += candidates.length;
    pair.segmentPairs = candidates;
  }

  return totalCandidates;
}

/**
 * Computes intersection points for every pair that has candidate segment
 * pairs and sorts the pairs into two buckets.
 *
 * - Pairs with no candidate segment pairs, or with no intersections reported,
 *   go to `possibleEnclosureCollisions`.
 * - Pairs with intersections get `intersectionPoints` assigned and go to
 *   `intersectionCollisions`.
 * - Pairs for which `findShapeIntersections` throws are skipped entirely
 *   (neither bucket); the failure is logged so the drop is visible.
 *
 * @param potentialCollisions Pairs whose `segmentPairs` has been populated;
 *   mutated in place (`intersectionPoints`).
 * @returns The two buckets plus diagnostic counters:
 *   `unfilteredSegmentCheckCount` is the product of both shapes' segment
 *   counts summed over pairs that had candidates; `segmentCheckCount` and
 *   `segmentIntersectionCount` are summed only over pairs that ended up in
 *   `intersectionCollisions`.
 */
function intersectSegmentPairs(potentialCollisions: CollisionPair[]): {
  intersectionCollisions: CollisionPair[];
  possibleEnclosureCollisions: CollisionPair[];
  unfilteredSegmentCheckCount: number;
  segmentCheckCount: number;
  segmentIntersectionCount: number;
} {
  let unfilteredSegmentCheckCount = 0;
  let segmentCheckCount = 0;
  let segmentIntersectionCount = 0;

  const intersectionCollisions: CollisionPair[] = [];
  const possibleEnclosureCollisions: CollisionPair[] = [];

  for (const collision of potentialCollisions) {
    // No overlapping segment candidates: outlines cannot cross, but one shape
    // may still lie inside the other.
    if (collision.segmentPairs.length === 0) {
      possibleEnclosureCollisions.push(collision);
      continue;
    }

    // What a brute-force all-segments-vs-all-segments test would have cost.
    unfilteredSegmentCheckCount +=
      collision.body1.shape.segments.length *
      collision.body2.shape.segments.length;

    let result: {
      intersections: [IntersectionPoint[], IntersectionPoint[]] | null;
      segmentCheckCount: number;
      segmentIntersectionCount: number;
    } | null = null;

    try {
      // Get intersection points
      result = findShapeIntersections(
        collision.segmentPairs,
        collision.relativeTransformInfo.guestToHost
      );
    } catch (e: unknown) {
      // Best-effort: a failure for one pair must not abort the whole pass, so
      // the pair is still skipped — but report it instead of hiding it, since
      // a dropped pair means the two bodies get no collision response.
      console.warn(
        "findShapeIntersections failed; skipping pair",
        collision.body1.id,
        collision.body2.id,
        e
      );
      continue;
    }

    // NOTE(review): `intersections` is typed as a 2-tuple, so `.length > 0`
    // holds for any non-null value. If the intent is "at least one
    // intersection point", the inner arrays would need to be checked —
    // confirm against what findShapeIntersections returns when nothing
    // intersects.
    if (result.intersections !== null && result.intersections.length > 0) {
      segmentCheckCount += result.segmentCheckCount;
      segmentIntersectionCount += result.segmentIntersectionCount;
      collision.intersectionPoints = result.intersections;
      intersectionCollisions.push(collision);
    } else {
      possibleEnclosureCollisions.push(collision);
    }
  }

  return {
    intersectionCollisions,
    possibleEnclosureCollisions,
    unfilteredSegmentCheckCount,
    segmentCheckCount,
    segmentIntersectionCount,
  };
}

/**
 * Keeps the pairs for which `findShapeEnclosure` reports an enclosure type
 * other than `EnclosureType.None`, and records that type on the pair.
 *
 * @param enclosurePossible Pairs to test; matching pairs get `enclosureType`
 *   assigned in place.
 * @returns The matching pairs, in input order.
 */
function findEnclosure(enclosurePossible: CollisionPair[]): CollisionPair[] {
  const confirmed: CollisionPair[] = [];

  for (const candidate of enclosurePossible) {
    const {
      body1,
      body2,
      body1Transform,
      body2Transform,
      body1Mec,
      body2Mec,
    } = candidate;

    const kind = findShapeEnclosure(
      body1.shape,
      body2.shape,
      body1Transform,
      body2Transform,
      body1Mec,
      body2Mec
    );

    if (kind === EnclosureType.None) {
      continue;
    }

    candidate.enclosureType = kind;
    confirmed.push(candidate);
  }

  return confirmed;
}

/**
 * Builds the contact data for a pair whose outlines intersect and appends it
 * to `finalCollisions`.
 *
 * Nothing is appended when `buildIntersectionShapes` reports failure.
 *
 * @param collision Pair with `intersectionPoints` already populated.
 * @param finalCollisions Output list that receives the contact info.
 */
function calculateIntersectionCollisionInfo(
  collision: CollisionPair,
  finalCollisions: CollisionInfo[]
): void {
  const { body1, body2 } = collision;
  const shapeBuilder = new CollisionShapeBuilder(
    collision.relativeTransformInfo
  );

  const built = buildIntersectionShapes({
    segments: [body1.shape.segments, body2.shape.segments],
    // The caller only passes pairs on which intersection points were stored.
    intersections: collision.intersectionPoints!,
    shapeBuilder,
    union: false,
  });

  if (!built) {
    return;
  }

  // The builder's centroid is used as the point of contact.
  const { centroid, normal, penetrationDistance } =
    shapeBuilder.getFinalResult();

  finalCollisions.push({
    body1,
    body2,
    pointOfContact: centroid,
    normal,
    penetrationDistance,
  });
}

/**
 * Builds the contact data for a pair where one shape encloses the other and
 * appends it to `finalCollisions`.
 *
 * Nothing is appended when `buildEnclosureShapes` yields no result.
 *
 * @param collision Pair with `enclosureType` already populated.
 * @param finalCollisions Output list that receives the contact info.
 */
function calculateEnclosureCollisionInfo(
  collision: CollisionPair,
  finalCollisions: CollisionInfo[]
): void {
  const { body1, body2 } = collision;

  // Each velocity is unpacked as [vx, vy, angularVelocity].
  const [vx1, vy1, w1] = body1.getVelocity();
  const [vx2, vy2, w2] = body2.getVelocity();

  const enclosure = buildEnclosureShapes({
    shapes: [body1.shape, body2.shape],
    transforms: [collision.body1Transform, collision.body2Transform],
    motionInfo: {
      shape1: { vx: vx1, vy: vy1, angularVelocity: w1 },
      shape2: { vx: vx2, vy: vy2, angularVelocity: w2 },
    },
    union: false,
    enclosureType: collision.enclosureType,
    buildShapes: false,
  });

  // Covers both a null and an undefined result.
  if (!enclosure) {
    return;
  }

  finalCollisions.push({
    body1,
    body2,
    pointOfContact: enclosure.totalCentroid,
    normal: enclosure.normal,
    penetrationDistance: enclosure.penetrationDistance,
  });
}

/**
 * Applies an impulse-based response (normal impulse plus friction) and a
 * positional correction to the two bodies of a collision.
 *
 * Returns without applying any impulse or correction when:
 * - the normal or the point of contact fails `isValid()`,
 * - `penetrationDistance` is non-positive or not finite,
 * - both bodies have infinite mass,
 * - the relative velocity at the contact point fails `isValid()`.
 *
 * Impulses are applied only when the relative velocity along the normal is
 * negative. The position correction is evaluated after that branch, but it is
 * also skipped if one of the zero-denominator `return`s inside the branch
 * fires.
 *
 * NOTE(review): the `true` flag passed to `getMass`, `getMomentOfInertia`,
 * `getCenter` and `getPointVelocity` is defined by `Body` and is not visible
 * from this file.
 *
 * @param collision Contact data: both bodies, contact normal, point of
 *   contact and penetration depth.
 */
function resolveImpulseCollision(collision: CollisionInfo): void {
  const { body1, body2, normal, pointOfContact, penetrationDistance } =
    collision;

  if (!normal.isValid()) {
    console.warn("normal is invalid", body1.id, body2.id);
    return;
  }

  // Unlike the invalid-normal case, this one is skipped without logging.
  if (!pointOfContact.isValid()) {
    return;
  }

  if (penetrationDistance <= 0 || !isFinite(penetrationDistance)) {
    // console.warn("penetrationDistance is invalid", penetrationDistance);
    return;
  }

  // The less bouncy of the two materials decides the restitution.
  const restitution = Math.min(
    body1.material.restitution,
    body2.material.restitution
  );

  // Find geometric mean of friction coefficients
  const staticFriction = Math.sqrt(
    body1.material.staticFriction * body2.material.staticFriction
  );
  const dynamicFriction = Math.sqrt(
    body1.material.dynamicFriction * body2.material.dynamicFriction
  );

  const mass1 = body1.getMass(true);
  const mass2 = body2.getMass(true);

  // Calculate inverse masses (0 for infinite mass)
  const invMass1 = mass1 === Infinity ? 0 : 1 / mass1;
  const invMass2 = mass2 === Infinity ? 0 : 1 / mass2;

  // Skip if both objects have infinite mass
  if (invMass1 === 0 && invMass2 === 0) {
    return;
  }

  const inertia1 = body1.getMomentOfInertia(true);
  const inertia2 = body2.getMomentOfInertia(true);

  // Calculate centers and moment arms
  const center1 = body1.getCenter(true);
  const center2 = body2.getCenter(true);
  // r1 / r2: vectors from each body's center to the point of contact.
  const r1 = pointOfContact.subtract(center1);
  const r2 = pointOfContact.subtract(center2);

  // Get point velocities
  const v1Point = body1.getPointVelocity(r1, true);
  const v2Point = body2.getPointVelocity(r2, true);
  // Velocity of body1's contact point relative to body2's.
  const relativeVelocity = v1Point.subtract(v2Point);

  if (!relativeVelocity.isValid()) {
    // Dump every input of the computation to make the bad value traceable.
    console.error(
      `relativeVelocity is invalid for pair ${body1.id} and ${body2.id}: ${relativeVelocity.x}, ${relativeVelocity.y}\nr1: ${r1.x}, ${r1.y}\nr2: ${r2.x}, ${r2.y}\nv1Point: ${v1Point.x}, ${v1Point.y}\nv2Point: ${v2Point.x}, ${v2Point.y}\nmass1: ${mass1}\nmass2: ${mass2}\ninertia1: ${inertia1}\ninertia2: ${inertia2}\nnormal: ${normal.x}, ${normal.y}\ncenter1: ${center1.x}, ${center1.y}\ncenter2: ${center2.x}, ${center2.y}\npointOfContact: ${pointOfContact.x}, ${pointOfContact.y}`
    );
    return;
  }

  const normalVelocity = relativeVelocity.dot(normal);

  // Only proceed if objects are moving toward each other
  // NOTE(review): with relativeVelocity = v1 - v2 and the positive impulse
  // going to body1, this presumably expects `normal` to point from body2
  // toward body1 — confirm against the code that produces the normal.
  if (normalVelocity < 0) {
    // Tangent is the normal rotated by 90 degrees.
    const tangent = new Vector2D(-normal.y, normal.x);
    const tangentVelocity = relativeVelocity.dot(tangent);

    // Scalar 2D cross products of the moment arms with normal and tangent.
    const r1CrossN = r1.x * normal.y - r1.y * normal.x;
    const r2CrossN = r2.x * normal.y - r2.y * normal.x;
    const r1CrossT = r1.x * tangent.y - r1.y * tangent.x;
    const r2CrossT = r2.x * tangent.y - r2.y * tangent.x;

    // Calculate inverse inertias (0 for locked rotation)
    // NOTE(review): a moment of inertia of 0 on a body without rotation lock
    // would make this Infinity — confirm Body never reports 0.
    const invInertia1 = body1.rotationLocked ? 0 : 1 / inertia1;
    const invInertia2 = body2.rotationLocked ? 0 : 1 / inertia2;

    // Combined linear and rotational inverse mass along the normal.
    const denominator =
      invMass1 +
      invMass2 +
      r1CrossN * r1CrossN * invInertia1 +
      r2CrossN * r2CrossN * invInertia2;

    // Avoid division by zero
    // (this return also skips the position correction further down)
    if (denominator === 0) {
      return;
    }

    // Normal impulse magnitude; restitution scales the rebound.
    const j = (-(1 + restitution) * normalVelocity) / denominator;
    const normalImpulse = normal.scale(j);

    // Calculate friction
    const tangentDenominator =
      invMass1 +
      invMass2 +
      r1CrossT * r1CrossT * invInertia1 +
      r2CrossT * r2CrossT * invInertia2;

    // Avoid division by zero
    // (this return also skips the position correction further down)
    if (tangentDenominator === 0) {
      return;
    }

    // Impulse that would cancel the tangential velocity entirely.
    let jt = -tangentVelocity / tangentDenominator;
    const maxFriction = j * staticFriction;

    // Beyond the static limit, use dynamic friction instead, keeping jt's
    // direction.
    if (Math.abs(jt) > maxFriction) {
      jt = Math.sign(jt) * j * dynamicFriction;
    }

    const tangentImpulse = tangent.scale(jt);
    const totalImpulse = normalImpulse.add(tangentImpulse);

    // Apply impulses using inverse mass
    // body1 receives +totalImpulse, body2 the opposite.
    if (invMass1 > 0) {
      body1.applyLinearImpulse(totalImpulse);
    }
    if (invMass2 > 0) {
      body2.applyLinearImpulse(totalImpulse.scale(-1));
    }

    // Apply angular impulses only if rotation isn't locked
    if (invInertia1 > 0) {
      body1.applyAngularImpulse(r1, totalImpulse);
    }
    if (invInertia2 > 0) {
      body2.applyAngularImpulse(r2, totalImpulse.scale(-1));
    }
  }

  const BETA = 0.9; // Position correction factor
  const SLOP = 1; // Penetration tolerance

  // Only the penetration beyond SLOP is corrected, scaled by BETA.
  if (penetrationDistance > SLOP) {
    // `+` binds tighter than `||`, so the fallback of 1 applies only when the
    // sum is 0 (or NaN); a sum of 0 cannot occur here because the
    // both-infinite-mass case returned earlier.
    // NOTE(review): both bodies are handed the same correction magnitude; it
    // is not weighted by each body's own inverse mass in this function —
    // confirm whether `addPositionCorrection` accounts for mass.
    const correction = normal.scale(
      ((penetrationDistance - SLOP) * BETA) / (invMass1 + invMass2 || 1)
    );

    // Bodies with infinite mass are never moved.
    if (invMass1 > 0) {
      body1.addPositionCorrection(correction);
    }
    if (invMass2 > 0) {
      body2.addPositionCorrection(correction.scale(-1));
    }
  }
}
