import ExpiryMap from 'expiry-map';

/**
 * Contract shared by the caches in this module.
 *
 * @typeParam TCacheItem - Type of a cached item.
 * @typeParam TKey - Key type accepted by `get`, `set` and `delete`.
 * @typeParam TAdditionalArgs - Extra arguments `get` accepts after the key(s)
 *   and before the optional trailing `AbortSignal`.
 */
export interface ICache<
  TCacheItem,
  TKey,
  TAdditionalArgs extends unknown[] = []
> {
  /** Single key: returns one promise for that key's item. */
  get(
    key: TKey,
    ...args: [...TAdditionalArgs, AbortSignal?]
  ): Promise<TCacheItem>;
  /** Array of keys: returns an array of item promises. */
  get(
    keys: TKey[],
    ...args: [...TAdditionalArgs, AbortSignal?]
  ): Promise<TCacheItem>[];
  /** Union form, so callers holding `TKey | TKey[]` can call `get` directly. */
  get(
    keyOrKeys: TKey | TKey[],
    ...args: [...TAdditionalArgs, AbortSignal?]
  ): Promise<TCacheItem> | Promise<TCacheItem>[];
  /** Stores `value` for `key`. */
  set(key: TKey, value: TCacheItem): void;
  /** Removes the entry for `key`, if any. */
  delete(key: TKey): void;
}

/**
 * Shared plumbing for the caches in this module.
 *
 * Holds the configured fetcher and key transformer. It turns the public
 * `get(keyOrKeys, ...additionalArgs, signal?)` call into a call to the
 * subclass's `getOne` / `getMany`, always with a concrete `AbortSignal` last.
 *
 * The fetcher is one of:
 * - `fetchOneFn(key, ...additionalArgs, signal)`, which fetches one item; or
 * - `fetchManyFn(keys, ...additionalArgs, signal)`, which fetches a batch of
 *   results, plus `resultSelector(items, key)`, which picks one key's item out
 *   of that batch.
 *
 * `keyTransformer` maps a key to the value subclasses use as the cache key.
 * When it is omitted, the key itself is used.
 */
abstract class BaseCache<
  TCacheItem,
  TKey,
  TTransformedKey = string,
  TAdditionalArgs extends unknown[] = [],
  TResult = TCacheItem
> {
  protected readonly keyTransformer: (key: TKey) => TTransformedKey;
  protected readonly fetcher:
    | {
        fetchOneFn: (
          key: TKey,
          ...args: [...TAdditionalArgs, AbortSignal]
        ) => Promise<TCacheItem>;
      }
    | {
        fetchManyFn: (
          keys: TKey[],
          ...args: [...TAdditionalArgs, AbortSignal]
        ) => Promise<TResult[]>;
        resultSelector: (items: TResult[], key: TKey) => TCacheItem;
      };

  constructor(
    options: {
      keyTransformer?: (key: TKey) => TTransformedKey;
    } & (
      | {
          fetchOneFn: (
            key: TKey,
            ...args: [...TAdditionalArgs, AbortSignal]
          ) => Promise<TCacheItem>;
        }
      | {
          fetchManyFn: (
            keys: TKey[],
            ...args: [...TAdditionalArgs, AbortSignal]
          ) => Promise<TResult[]>;
          resultSelector: (items: TResult[], key: TKey) => TCacheItem;
        }
    )
  ) {
    this.fetcher = options;
    // Without a transformer the key itself is the cache key. The double cast
    // is unavoidable: TKey and TTransformedKey are independent type params.
    this.keyTransformer =
      options.keyTransformer ?? ((k: TKey) => k as unknown as TTransformedKey);
  }

  public get(
    key: TKey,
    ...args: [...TAdditionalArgs, AbortSignal?]
  ): Promise<TCacheItem>;
  public get(
    keys: TKey[],
    ...args: [...TAdditionalArgs, AbortSignal?]
  ): Promise<TCacheItem>[];
  public get(
    keyOrKeys: TKey | TKey[],
    ...args: [...TAdditionalArgs, AbortSignal?]
  ): Promise<TCacheItem> | Promise<TCacheItem>[] {
    // The signal is optional and comes after the additional args. So the last
    // element is the signal slot only when it is a signal, or `undefined` (an
    // explicitly omitted signal). Anything else is an additional argument and
    // must be forwarded, not dropped.
    // NOTE(review): if the last additional arg can itself be `undefined` and no
    // signal is given, the two cases are indistinguishable at runtime. The
    // value is treated as an omitted signal; such callers should pass the
    // signal slot explicitly.
    const last: unknown = args[args.length - 1];
    const hasSignalSlot =
      args.length > 0 && (last === undefined || BaseCache.isAbortSignal(last));
    const additionalArgs = args.slice(
      0,
      hasSignalSlot ? args.length - 1 : args.length
    ) as TAdditionalArgs;
    const abortSignal: AbortSignal =
      (hasSignalSlot ? (last as AbortSignal | undefined) : undefined) ??
      new AbortController().signal;

    if (keyOrKeys instanceof Array) {
      return this.getMany(keyOrKeys, ...additionalArgs, abortSignal);
    }
    return this.getOne(keyOrKeys, ...additionalArgs, abortSignal);
  }

  /**
   * True for an `AbortSignal` instance. Also true for a signal-like object
   * (boolean `aborted` plus an `addEventListener` function), so signals from
   * another realm or a polyfill are still recognised.
   */
  private static isAbortSignal(value: unknown): value is AbortSignal {
    if (typeof AbortSignal !== 'undefined' && value instanceof AbortSignal) {
      return true;
    }
    return (
      typeof value === 'object' &&
      value !== null &&
      typeof (value as { aborted?: unknown }).aborted === 'boolean' &&
      typeof (value as { addEventListener?: unknown }).addEventListener ===
        'function'
    );
  }

  /** Resolves one key. `args` always ends with a concrete signal. */
  protected abstract getOne(
    key: TKey,
    ...args: [...TAdditionalArgs, AbortSignal]
  ): Promise<TCacheItem>;
  /** Resolves several keys. `args` always ends with a concrete signal. */
  protected abstract getMany(
    keys: TKey[],
    ...args: [...TAdditionalArgs, AbortSignal]
  ): Promise<TCacheItem>[];
  abstract set(key: TKey, value: TCacheItem): void;
  abstract delete(key: TKey): void;
}

/**
 * In-memory cache of fetch promises, keyed by `keyTransformer(key)`.
 *
 * Concurrent requests for the same key share one promise. A fetch that
 * rejects is evicted, so a later request retries it.
 *
 * Options (besides the fetcher and `keyTransformer`):
 * - `cacheExpirationMs`: when provided, entries are kept in an `ExpiryMap`
 *   created with this value. Otherwise a plain `Map` is used, and entries stay
 *   until they are deleted or evicted.
 * - `maxEntries`: when a positive number, the entries first yielded by the
 *   map's `keys()` (insertion order for a plain `Map`) are evicted whenever the
 *   cache grows past this size. Unset, `0` or negative means no limit.
 * - `rollingExpiration`: when `true`, every cache hit deletes and re-inserts
 *   its entry, moving it to the newest position. `maxEntries` eviction then
 *   removes the least recently requested entries first.
 *   NOTE(review): with `cacheExpirationMs` this presumably also restarts the
 *   entry's expiry; confirm against `expiry-map`.
 */
export class MemoryCache<
    TCacheItem,
    TKey,
    TTransformedKey = string,
    TAdditionalArgs extends unknown[] = [],
    TResult = TCacheItem
  >
  extends BaseCache<TCacheItem, TKey, TTransformedKey, TAdditionalArgs, TResult>
  implements ICache<TCacheItem, TKey, TAdditionalArgs>
{
  /** Whether cache hits re-insert their entry. */
  private readonly rollingExpiration: boolean;
  /** Transformed key -> shared fetch (or `set`) promise. */
  private readonly cache:
    | Map<TTransformedKey, Promise<TCacheItem>>
    | ExpiryMap<TTransformedKey, Promise<TCacheItem>>;
  /** Upper bound on the cache size. `undefined` or non-positive disables it. */
  private readonly maxEntries: number | undefined;

  constructor(
    options: {
      cacheExpirationMs?: number;
      maxEntries?: number;
      rollingExpiration?: boolean;
      keyTransformer?: (key: TKey) => TTransformedKey;
    } & (
      | {
          fetchOneFn: (
            key: TKey,
            ...args: [...TAdditionalArgs, AbortSignal]
          ) => Promise<TCacheItem>;
        }
      | {
          fetchManyFn: (
            keys: TKey[],
            ...args: [...TAdditionalArgs, AbortSignal]
          ) => Promise<TResult[]>;
          resultSelector: (items: TResult[], key: TKey) => TCacheItem;
        }
    )
  ) {
    super(options);
    if (options.cacheExpirationMs == null) {
      this.cache = new Map();
    } else {
      this.cache = new ExpiryMap(options.cacheExpirationMs);
    }
    this.maxEntries = options.maxEntries;
    this.rollingExpiration = options.rollingExpiration ?? false;
  }

  /** Returns the cached promise for `key`, fetching and caching it on a miss. */
  protected getOne(
    key: TKey,
    ...args: [...TAdditionalArgs, AbortSignal]
  ): Promise<TCacheItem> {
    const transformedKey = this.keyTransformer(key);
    const cachedItem = this.cache.get(transformedKey);
    if (cachedItem != null) {
      if (this.rollingExpiration) {
        this.refresh(transformedKey, cachedItem);
      }
      return cachedItem;
    }

    const fetchPromise = this.storeFetch(
      transformedKey,
      this.fetchSingle(key, ...args)
    );
    this.trimToMaxEntries();
    return fetchPromise;
  }

  /**
   * Returns one promise per entry of `keys`, in the same order (duplicates
   * included). Misses are fetched together, in one `fetchManyFn` call when a
   * batch fetcher is configured.
   */
  protected getMany(
    keys: TKey[],
    ...args: [...TAdditionalArgs, AbortSignal]
  ): Promise<TCacheItem>[] {
    // One promise per distinct transformed key. Duplicate keys, or keys that
    // transform to the same value, share a single lookup/fetch.
    const keyPromisesMap = new Map<TTransformedKey, Promise<TCacheItem>>();
    // Transformed keys that still need fetching, mapped to the first raw key.
    const misses = new Map<TTransformedKey, TKey>();
    const transformedKeys: TTransformedKey[] = [];

    for (const key of keys) {
      const transformedKey = this.keyTransformer(key);
      transformedKeys.push(transformedKey);
      if (keyPromisesMap.has(transformedKey) || misses.has(transformedKey)) {
        continue;
      }
      const cachedItem = this.cache.get(transformedKey);
      if (cachedItem == null) {
        misses.set(transformedKey, key);
        continue;
      }
      keyPromisesMap.set(transformedKey, cachedItem);
      // Refresh hits before inserting new entries. Trimming below then evicts
      // older, unrequested entries instead of the ones just requested.
      if (this.rollingExpiration) {
        this.refresh(transformedKey, cachedItem);
      }
    }

    if (misses.size > 0) {
      if ('fetchOneFn' in this.fetcher) {
        for (const [transformedKey, key] of misses) {
          keyPromisesMap.set(
            transformedKey,
            this.storeFetch(
              transformedKey,
              this.fetcher.fetchOneFn(key, ...args)
            )
          );
        }
      } else {
        // A single batched request serves every missing key.
        const selector = this.fetcher.resultSelector;
        const batch = this.fetcher.fetchManyFn(
          Array.from(misses.values()),
          ...args
        );
        for (const [transformedKey, key] of misses) {
          keyPromisesMap.set(
            transformedKey,
            this.storeFetch(
              transformedKey,
              batch.then((items) => selector(items, key))
            )
          );
        }
      }
    }

    this.trimToMaxEntries();

    return transformedKeys.map((transformedKey) => {
      const promise = keyPromisesMap.get(transformedKey);
      if (promise == null) {
        throw new Error('Unexpected null value in keyPromisesMap');
      }
      return promise;
    });
  }

  /** Stores `value` as an already-resolved entry, then enforces `maxEntries`. */
  public set(key: TKey, value: TCacheItem): void {
    const transformedKey = this.keyTransformer(key);
    this.cache.set(transformedKey, Promise.resolve(value));
    this.trimToMaxEntries();
  }

  /** Removes the entry for `key`, if any. */
  public delete(key: TKey): void {
    const transformedKey = this.keyTransformer(key);
    this.cache.delete(transformedKey);
  }

  /** Fetches a single key with whichever fetcher was configured. */
  private fetchSingle(
    key: TKey,
    ...args: [...TAdditionalArgs, AbortSignal]
  ): Promise<TCacheItem> {
    if ('fetchOneFn' in this.fetcher) {
      return this.fetcher.fetchOneFn(key, ...args);
    }
    const selector = this.fetcher.resultSelector;
    return this.fetcher
      .fetchManyFn([key], ...args)
      .then((items) => selector(items, key));
  }

  /**
   * Caches `pending` under `transformedKey` and returns the cached promise.
   * If it rejects, the entry is evicted so a later request can retry.
   */
  private storeFetch(
    transformedKey: TTransformedKey,
    pending: Promise<TCacheItem>
  ): Promise<TCacheItem> {
    const stored: Promise<TCacheItem> = pending.catch((err: unknown) => {
      // Only evict our own entry. It may already have been replaced
      // (e.g. via `set`) or deleted while the fetch was in flight.
      if (this.cache.get(transformedKey) === stored) {
        this.cache.delete(transformedKey);
      }
      throw err;
    });
    this.cache.set(transformedKey, stored);
    return stored;
  }

  /** Deletes and re-inserts an entry so it becomes the newest one. */
  private refresh(
    transformedKey: TTransformedKey,
    item: Promise<TCacheItem>
  ): void {
    this.cache.delete(transformedKey);
    this.cache.set(transformedKey, item);
  }

  /**
   * Evicts the entries first yielded by `keys()` until the cache is within
   * `maxEntries`. This is a no-op when the limit is unset or not positive.
   */
  private trimToMaxEntries(): void {
    const limit = this.maxEntries;
    if (limit == null || !(limit > 0)) {
      return;
    }
    while (this.cache.size > limit) {
      const oldest = this.cache.keys().next();
      if (oldest.done) {
        // Defensive: never spin if `size` and `keys()` disagree.
        break;
      }
      this.cache.delete(oldest.value);
    }
  }
}

/**
 * Request-scoped de-duplication without lasting storage.
 *
 * While a lookup is in flight, concurrent `get` calls for the same key receive
 * the shared cached promise. The entry is deleted as soon as that promise
 * settles, whether it is fulfilled or rejected. This suits results that are
 * persisted elsewhere, e.g. in an on-disk cache.
 */
export class NoCache<
    TCacheItem,
    TKey,
    TTransformedKey = string,
    TAdditionalArgs extends unknown[] = [],
    TResult = TCacheItem
  >
  extends MemoryCache<
    TCacheItem,
    TKey,
    TTransformedKey,
    TAdditionalArgs,
    TResult
  >
  implements ICache<TCacheItem, TKey, TAdditionalArgs>
{
  // Only `maxEntries`, `keyTransformer` and the fetcher can be configured here:
  // there are no expiry or rolling-expiration options.
  constructor(
    opts: {
      maxEntries?: number;
      keyTransformer?: (key: TKey) => TTransformedKey;
    } & (
      | {
          fetchOneFn: (
            key: TKey,
            ...args: [...TAdditionalArgs, AbortSignal]
          ) => Promise<TCacheItem>;
        }
      | {
          fetchManyFn: (
            keys: TKey[],
            ...args: [...TAdditionalArgs, AbortSignal]
          ) => Promise<TResult[]>;
          resultSelector: (items: TResult[], key: TKey) => TCacheItem;
        }
    )
  ) {
    super(opts);
  }

  /** Delegates to the memory cache, then drops the entry once it settles. */
  protected override getOne(
    key: TKey,
    ...args: [...TAdditionalArgs, AbortSignal]
  ): Promise<TCacheItem> {
    const pending = super.getOne(key, ...args);
    const forget = () => {
      this.delete(key);
    };
    return pending.finally(forget);
  }

  /** Like `getOne`, pairing each returned promise with the key at its index. */
  protected override getMany(
    keys: TKey[],
    ...args: [...TAdditionalArgs, AbortSignal]
  ): Promise<TCacheItem>[] {
    const pending = super.getMany(keys, ...args);
    const released: Promise<TCacheItem>[] = [];
    pending.forEach((promise, index) => {
      released.push(
        promise.finally(() => {
          this.delete(keys[index]);
        })
      );
    });
    return released;
  }
}
