#include "td_main.h"
#include <raymath.h>

// Fixed-size projectile pool. A slot whose projectileType is
// PROJECTILE_TYPE_NONE is free. projectileCount is an upper bound on the
// occupied slots (one past the highest index that may be in use), so the
// per-frame loops only scan [0, projectileCount). Static storage starts zeroed.
static Projectile projectiles[PROJECTILE_MAX_COUNT];
static int projectileCount;

// Per-type visual parameters for projectiles, looked up by projectileType.
typedef struct ProjectileConfig ProjectileConfig;
struct ProjectileConfig
{
  // peak height of the faked arc, as a fraction of the flight distance
  float arcFactor;
  // colour at the leading (front-most, largest) end of the drawn trail
  Color color;
  // colour at the trailing end of the drawn trail
  Color trailColor;
};

// Visual parameters per projectile type, indexed by projectileType. Types not
// listed here (e.g. PROJECTILE_TYPE_NONE) get a zero-initialized entry.
// The table is read-only and file-private (the ProjectileConfig type is
// defined in this file), hence static const.
static const ProjectileConfig projectileConfigs[] = {
    [PROJECTILE_TYPE_ARROW] = {
        .arcFactor = 0.15f,
        .color = RED,
        .trailColor = BROWN,
    },
    [PROJECTILE_TYPE_CATAPULT] = {
        .arcFactor = 0.5f,
        .color = RED,
        .trailColor = GRAY,
    },
    [PROJECTILE_TYPE_BALLISTA] = {
        .arcFactor = 0.025f,
        .color = RED,
        .trailColor = BROWN,
    },
};

void ProjectileInit()
{
  for (int i = 0; i < PROJECTILE_MAX_COUNT; i++)
  {
    projectiles[i] = (Projectile){0};
  }
}

// Draws every in-flight projectile as a short trail of cubes.
// The trail runs from the projectile's nominal position forward along the line
// towards its target. A parabolic height offset fakes a ballistic arc.
void ProjectileDraw()
{
  // number of cubes in each trail
  const int trailSegments = 10;
  // how far ahead (in normalized flight time) the trail reaches
  const float trailLength = 0.3f;
  const int configCount = (int)(sizeof projectileConfigs / sizeof projectileConfigs[0]);

  for (int i = 0; i < projectileCount; i++)
  {
    const Projectile *projectile = &projectiles[i];
    if (projectile->projectileType == PROJECTILE_TYPE_NONE)
    {
      continue;
    }
    // a type without a table entry would read past projectileConfigs
    int type = projectile->projectileType;
    if (type < 0 || type >= configCount)
    {
      continue;
    }
    float duration = projectile->arrivalTime - projectile->shootTime;
    if (duration <= 0.0f)
    {
      // zero-length flight: nothing to animate, and 0/0 would give NaN positions
      continue;
    }
    float transition = (gameTime.time - projectile->shootTime) / duration;
    if (transition >= 1.0f)
    {
      continue;
    }

    const ProjectileConfig *config = &projectileConfigs[type];
    for (int segment = 0; segment < trailSegments; segment++)
    {
      // 0 at the trailing end of the trail, approaching 1 at the leading end
      float transitionOffset = segment / (float)trailSegments;
      float t = transition + transitionOffset * trailLength;
      if (t > 1.0f)
      {
        break;
      }
      Vector3 position = Vector3Lerp(projectile->position, projectile->target, t);
      // blend from the trail colour at the back to the head colour at the front
      Color color = ColorLerp(config->trailColor, config->color, transitionOffset * transitionOffset);
      // fake a ballistic flight path for every projectile type:
      // 1 - 4(t - 0.5)^2 == 4t(1 - t), which is 0 at launch/impact and 1 at mid-flight
      float parabolaT = t - 0.5f;
      parabolaT = 1.0f - 4.0f * parabolaT * parabolaT;
      position.y += config->arcFactor * parabolaT * projectile->distance;

      // cubes grow towards the leading end of the trail
      float size = 0.06f * (transitionOffset + 0.25f);
      DrawCube(position, size, size, size, color);
    }
  }
}

void ProjectileUpdate()
{
  for (int i = 0; i < projectileCount; i++)
  {
    Projectile *projectile = &projectiles[i];
    if (projectile->projectileType == PROJECTILE_TYPE_NONE)
    {
      continue;
    }
    float transition = (gameTime.time - projectile->shootTime) / (projectile->arrivalTime - projectile->shootTime);
    if (transition >= 1.0f)
    {
      projectile->projectileType = PROJECTILE_TYPE_NONE;
      Enemy *enemy = EnemyTryResolve(projectile->targetEnemy);
      if (enemy && projectile->hitEffectConfig.pushbackPowerDistance > 0.0f)
      {
          Vector2 direction = Vector2Normalize(Vector2Subtract((Vector2){projectile->target.x, projectile->target.z}, enemy->simPosition));
          enemy->simPosition = Vector2Add(enemy->simPosition, Vector2Scale(direction, projectile->hitEffectConfig.pushbackPowerDistance));
      }
      
      if (projectile->hitEffectConfig.areaDamageRadius > 0.0f)
      {
        EnemyAddDamageRange((Vector2){projectile->target.x, projectile->target.z}, projectile->hitEffectConfig.areaDamageRadius, projectile->hitEffectConfig.damage);
        // pancaked sphere explosion
        float r = projectile->hitEffectConfig.areaDamageRadius;
        ParticleAdd(PARTICLE_TYPE_EXPLOSION, projectile->target, (Vector3){0}, (Vector3){r, r * 0.2f, r}, 0.33f);
      }
      else if (projectile->hitEffectConfig.damage > 0.0f && enemy)
      {
        EnemyAddDamage(enemy, projectile->hitEffectConfig.damage);
      }
      continue;
    }
  }
}

// Claims the first free pool slot and launches a projectile of `projectileType`
// from `position` towards `target`.
// `speed` is distance covered per unit of gameTime.time. A non-positive speed
// makes the projectile arrive immediately.
// `enemy` is stored by id (via EnemyGetId) and re-resolved on impact.
// Returns the new projectile, or 0 when all PROJECTILE_MAX_COUNT slots are in use.
Projectile *ProjectileTryAdd(uint8_t projectileType, Enemy *enemy, Vector3 position, Vector3 target, float speed, HitEffectConfig hitEffectConfig)
{
  for (int i = 0; i < PROJECTILE_MAX_COUNT; i++)
  {
    Projectile *projectile = &projectiles[i];
    if (projectile->projectileType != PROJECTILE_TYPE_NONE)
    {
      continue;
    }

    float distance = Vector3Distance(position, target);
    // distance / 0 (or a negative speed) would put the arrival time at infinity
    // or in the past, and the projectile would never register as arrived
    float flightTime = speed > 0.0f ? distance / speed : 0.0f;

    projectile->projectileType = projectileType;
    projectile->shootTime = gameTime.time;
    projectile->arrivalTime = gameTime.time + flightTime;
    projectile->position = position;
    projectile->target = target;
    // a zero-length shot has no direction; scaling by 1/0 would produce NaNs
    projectile->directionNormal = distance > 0.0f
      ? Vector3Scale(Vector3Subtract(target, position), 1.0f / distance)
      : (Vector3){0};
    projectile->distance = distance;
    projectile->targetEnemy = EnemyGetId(enemy);
    projectile->hitEffectConfig = hitEffectConfig;
    // keep the high-water mark covering this slot
    if (projectileCount <= i)
    {
      projectileCount = i + 1;
    }
    return projectile;
  }
  return 0;
}