import type { Body } from "./body.ts";
import type { SurfacePoint } from "./surfacePoint.ts";

/**
 * Construction parameters for a {@link Pin}: its target length plus the
 * tuning knobs used by its velocity and position passes.
 */
export type PinConfig = {
  /**
   * Target distance between the two pinned points. The position pass only
   * corrects when the current distance exceeds this value plus `slop`.
   */
  desiredLength: number;
  /** Excess distance tolerated before any position correction is applied. */
  slop: number;
  /** Factor applied to the excess distance (beyond `slop`) in each position correction. */
  baumgarteScale: number;
  /** Restitution coefficient; the velocity impulse is scaled by `(1 + restitution)`. */
  restitution: number;
};

/**
 * Distance constraint between two surface points.
 *
 * Each {@link Pin.apply} call first applies an equal-and-opposite impulse
 * along the line joining the points, then nudges the bodies' positions when
 * the points are farther apart than `desiredLength + slop`.
 */
export class Pin {
  private readonly surfacePoint1: SurfacePoint;
  private readonly surfacePoint2: SurfacePoint;
  private readonly desiredLength: number;
  private readonly restitution: number;
  private readonly baumgarteScale: number;
  private readonly slop: number;

  /**
   * @param surfacePoint1 - First attachment point.
   * @param surfacePoint2 - Second attachment point.
   * @param config - Target length and solver tuning parameters.
   */
  constructor(
    surfacePoint1: SurfacePoint,
    surfacePoint2: SurfacePoint,
    config: PinConfig
  ) {
    this.surfacePoint1 = surfacePoint1;
    this.surfacePoint2 = surfacePoint2;
    this.desiredLength = config.desiredLength;
    this.restitution = config.restitution;
    this.baumgarteScale = config.baumgarteScale;
    this.slop = config.slop;
  }

  /**
   * Runs one velocity-impulse step and one position-correction step.
   *
   * @returns The bodies that were passed a position correction (possibly a
   * zero vector). Empty when either surface point has no impulse info, both
   * ends have zero inverse mass, the points coincide (no defined direction),
   * or the effective-mass denominator is zero.
   */
  apply(): Body[] {
    const impulseInfo1 = this.surfacePoint1.getImpulseInfo();
    const impulseInfo2 = this.surfacePoint2.getImpulseInfo();

    if (impulseInfo1 === null || impulseInfo2 === null) {
      return [];
    }

    const {
      position: pos1,
      velocity: vel1,
      r: r1,
      inverseMass: invMass1,
      inverseMomentOfInertia: invInertia1,
    } = impulseInfo1;
    const {
      position: pos2,
      velocity: vel2,
      r: r2,
      inverseMass: invMass2,
      inverseMomentOfInertia: invInertia2,
    } = impulseInfo2;

    // Skip if both objects have infinite mass: nothing can move.
    if (invMass1 === 0 && invMass2 === 0) return [];

    // Current length and unit direction from point 1 to point 2.
    const displacement = pos2.subtract(pos1);
    const currentLength = displacement.magnitude();

    // Coincident points have no defined direction; normalizing would divide
    // by zero and feed a NaN direction into both impulses and corrections.
    if (currentLength === 0) return [];

    const direction = displacement.scale(1 / currentLength);

    // Relative velocity projected onto the constraint direction
    // (positive = points separating).
    const relativeVelocity = vel2.subtract(vel1);
    const normalVelocity = relativeVelocity.dot(direction);

    // 2D cross products r x n for the angular part of the effective mass.
    const r1CrossN = r1.x * direction.y - r1.y * direction.x;
    const r2CrossN = r2.x * direction.y - r2.y * direction.x;

    // Effective-mass denominator including angular terms.
    const denominator =
      invMass1 +
      invMass2 +
      r1CrossN * r1CrossN * invInertia1 +
      r2CrossN * r2CrossN * invInertia2;

    if (denominator === 0) return [];

    // Velocity pass: impulse scalar along the constraint direction.
    const j = (-(1 + this.restitution) * normalVelocity) / denominator;
    const impulse = direction.multiply(j);

    // Equal and opposite: point 2 receives +impulse, point 1 receives -impulse.
    if (invMass1 > 0) {
      this.surfacePoint1.applyImpulse(impulse.multiply(-1));
    }
    if (invMass2 > 0) {
      this.surfacePoint2.applyImpulse(impulse);
    }

    // Position pass. Only stretching beyond desiredLength + slop is corrected;
    // compression is left alone.
    // NOTE(review): confirm one-sided correction is intended for a "pin".
    const penetration = currentLength - this.desiredLength;
    const correctionMagnitude =
      Math.max(penetration - this.slop, 0) * this.baumgarteScale;
    const correction = direction.multiply(correctionMagnitude / denominator);

    const correctedBodies: Body[] = [];

    // Point 1 moves toward point 2 (+direction), point 2 toward point 1,
    // each weighted by its inverse mass.
    if (invMass1 > 0) {
      this.surfacePoint1.addPositionCorrection(correction.multiply(invMass1));
      const body1 = this.surfacePoint1.getBody();
      if (body1) correctedBodies.push(body1);
    }
    if (invMass2 > 0) {
      this.surfacePoint2.addPositionCorrection(correction.multiply(-invMass2));
      const body2 = this.surfacePoint2.getBody();
      if (body2) correctedBodies.push(body2);
    }

    return correctedBodies;
  }

  /**
   * Draws the pin as a red line of width 10 between the two surface points.
   * Sets `strokeStyle` and `lineWidth` on `ctx` without restoring them.
   * Does nothing if either surface point has no position.
   */
  render(
    ctx: CanvasRenderingContext2D | OffscreenCanvasRenderingContext2D
  ): void {
    const pos1 = this.surfacePoint1.getPosition();
    const pos2 = this.surfacePoint2.getPosition();

    if (pos1 === null || pos2 === null) {
      return;
    }

    ctx.strokeStyle = "red";
    ctx.lineWidth = 10;
    ctx.beginPath();
    ctx.moveTo(pos1.x, pos1.y);
    ctx.lineTo(pos2.x, pos2.y);
    ctx.stroke();
  }
}
