import { BN_DIV_NUMERATOR_MULTIPLIER_DECIMALS, BN_ZERO, fromBN } from '@mangata-finance/sdk';
import { ApiPromise } from '@polkadot/api';
import { BN } from '@polkadot/util';
import { QueryFunctionContext, UseQueryResult } from '@tanstack/react-query';
import { InjectedWallet, QueryOptional } from '../../../../services';
import { TransactionStore, ExtrinsicTx, TxType } from '../../../transaction';
import { Asset } from '../../../token';

/**
 * React Query key for `getRewardsAmount`: a query name, the LP token id, the
 * account address and a finalized-transaction counter.
 *
 * `getRewardsAmount` only reads the LP id and the address. NOTE(review):
 * `finalizedTxCount` is presumably part of the key so the cached amount is
 * refetched when new transactions finalize — confirm with the caller.
 */
export type GetRewardsAmountQueryKey = readonly [
  queryKey: string,
  lpId: string | undefined,
  address: string | undefined,
  finalizedTxCount: number,
];

/**
 * React Query key consumed by `claimRewardsFee`, which destructures it as
 * `[, id, address]`: element 0 is skipped, element 1 is the LP token id and
 * element 2 is the account address. The labels follow that reading; the
 * element types are all `QueryOptional<string>`.
 *
 * NOTE(review): element 0 is assumed to be a query-name string like the other
 * keys in this file — confirm against the caller that builds this key.
 */
type ClaimRewardsFeeQueryKey = Readonly<
  [queryKey: QueryOptional<string>, id: QueryOptional<string>, address: QueryOptional<string>]
>;

/**
 * React Query key for `claimAllRewardsFee`: a query name, the account address,
 * a finalized-transaction counter and a flag named `rewardsFetched`.
 *
 * `claimAllRewardsFee` only reads the address. NOTE(review): the other two
 * entries presumably exist to re-key the query when they change — confirm with
 * the caller.
 */
type ClaimAllRewardsFeeQueryKey = readonly [
  queryKey: string,
  address: string | undefined,
  finalizedTxCount: number,
  rewardsFetched: boolean,
];

/**
 * Creates a React Query function that fetches the claimable rewards for an
 * account / LP token pair via the `xyk.calculate_rewards_amount` RPC.
 *
 * @param api - connected API instance, or `null` when unavailable.
 * @returns a query function resolving to `null` (without calling the RPC) when
 *   the API, the LP id or the address is missing; otherwise it resolves to the
 *   RPC result.
 */
export const getRewardsAmount =
  (api: ApiPromise | null) =>
  async ({ queryKey }: QueryFunctionContext<GetRewardsAmountQueryKey>) => {
    const [, lpId, accountAddress] = queryKey;

    // Nothing to ask the node until every input is known.
    if (!api || !accountAddress || !lpId) {
      return null;
    }

    return api.rpc.xyk.calculate_rewards_amount(accountAddress, lpId);
  };

/**
 * Creates a mutation function that claims the native proof-of-stake rewards for
 * a single liquidity token and tracks the transaction in `transactionStore`.
 *
 * @param api - connected API instance, or `null` when unavailable.
 * @param rewardsByLpId - reward queries keyed by liquidity token id; used only to
 *   put the claimed amount into the transaction metadata.
 * @param address - signing account address.
 * @param wallet - injected wallet used to sign.
 * @param transactionStore - store the transaction is registered in.
 * @param asset - reward asset shown in the transaction metadata.
 * @returns a function that takes the liquidity token id and resolves to the result
 *   of `ExtrinsicTx#send()`, or to `null` when any prerequisite is missing.
 */
export const claimRewards =
  (
    api: ApiPromise | null,
    rewardsByLpId: Record<string, UseQueryResult<QueryOptional<BN>>>,
    address: string | undefined,
    wallet: QueryOptional<InjectedWallet>,
    transactionStore: TransactionStore,
    asset: QueryOptional<Asset>,
  ) =>
  async (liquidityTokenId: string) => {
    if (!api || !address || !wallet || !asset) {
      return null;
    }

    // The LP id may have no entry in the map; fall back to zero instead of
    // throwing on `undefined.data`. The amount only feeds the metadata.
    const rewards = rewardsByLpId[liquidityTokenId]?.data ?? BN_ZERO;
    const amount = fromBN(rewards, BN_DIV_NUMERATOR_MULTIPLIER_DECIMALS);

    const tx = api.tx.proofOfStake.claimNativeRewards(liquidityTokenId);
    return new ExtrinsicTx(api, transactionStore, wallet, address)
      .create(TxType.Claim)
      .setMetadata({ tokens: [{ ...asset, amount }] })
      .setTx(tx)
      .build()
      .send();
  };

/**
 * Creates a mutation function that claims the native proof-of-stake rewards for
 * every liquidity token whose cached reward amount is greater than zero. All
 * claims go into one extrinsic built by `getClaimAllRewardsExtrinsic`.
 *
 * @param api - connected API instance, or `null` when unavailable.
 * @param rewardsByLpId - reward queries keyed by liquidity token id.
 * @param address - signing account address.
 * @param wallet - injected wallet used to sign.
 * @param transactionStore - store the transaction is registered in.
 * @param asset - reward asset; shown in the transaction metadata with the summed
 *   amount, as `claimRewards` does for a single LP.
 * @returns a function resolving to the result of `ExtrinsicTx#send()`, or to
 *   `null` when a prerequisite is missing or no LP has rewards above zero.
 */
export const claimAllRewards =
  (
    api: ApiPromise | null,
    rewardsByLpId: Record<string, UseQueryResult<QueryOptional<BN>>>,
    address: string | undefined,
    wallet: QueryOptional<InjectedWallet>,
    transactionStore: TransactionStore,
    asset: QueryOptional<Asset>,
  ) =>
  async () => {
    if (!api || !address || !wallet || !asset) {
      return null;
    }

    // Use the same "> 0" filter as the batch builder so the metadata total
    // matches exactly what is claimed.
    let totalRewards = BN_ZERO;
    let claimableCount = 0;
    for (const { data } of Object.values(rewardsByLpId)) {
      if (data?.gt(BN_ZERO)) {
        totalRewards = totalRewards.add(data);
        claimableCount += 1;
      }
    }

    // Submitting an empty batch would cost the user a fee and claim nothing.
    if (claimableCount === 0) {
      return null;
    }

    const amount = fromBN(totalRewards, BN_DIV_NUMERATOR_MULTIPLIER_DECIMALS);
    const tx = getClaimAllRewardsExtrinsic(api, rewardsByLpId);
    return new ExtrinsicTx(api, transactionStore, wallet, address)
      .create(TxType.Claim)
      .setMetadata({ tokens: [{ ...asset, amount }] })
      .setTx(tx)
      .build()
      .send();
  };

/**
 * Creates a React Query function that estimates the fee of the claim-all
 * transaction for the address in the query key.
 *
 * It asks the node for `paymentInfo` of the exact extrinsic `claimAllRewards`
 * submits (built by `getClaimAllRewardsExtrinsic`). The alternative, summing
 * per-LP claim fees plus an empty batch, takes N+1 RPC round-trips. It also
 * charges per-extrinsic base/length fees once per summed extrinsic instead of
 * once for the batch.
 *
 * @param api - connected API instance, or `null` when unavailable.
 * @param rewardsByLpId - reward queries keyed by liquidity token id.
 * @returns a query function resolving to the estimated fee as a plain BN (in the
 *   chain's base units), or `null` when the API or address is missing.
 */
export const claimAllRewardsFee =
  (api: ApiPromise | null, rewardsByLpId: Record<string, UseQueryResult<QueryOptional<BN>>>) =>
  async ({ queryKey: [, address] }: QueryFunctionContext<ClaimAllRewardsFeeQueryKey>) => {
    if (!address || !api || !rewardsByLpId) {
      return null;
    }

    const { partialFee } = await getClaimAllRewardsExtrinsic(api, rewardsByLpId).paymentInfo(
      address,
    );

    // Normalise the codec value to a plain BN, as the previous summing code returned.
    return BN_ZERO.add(partialFee);
  };

/**
 * Creates a React Query function that estimates the fee for claiming the native
 * rewards of one liquidity token.
 *
 * Reads the LP id from key slot 1 and the address from key slot 2.
 *
 * @param api - connected API instance, or `null` when unavailable.
 * @returns a query function resolving to `fromBN(partialFee,
 *   BN_DIV_NUMERATOR_MULTIPLIER_DECIMALS)`, or `null` when the API, address or
 *   LP id is missing.
 */
export const claimRewardsFee =
  (api: ApiPromise | null) =>
  async ({ queryKey }: QueryFunctionContext<ClaimRewardsFeeQueryKey>) => {
    const [, lpId, account] = queryKey;
    if (!api || !account || !lpId) {
      return null;
    }

    const claimTx = api.tx.proofOfStake.claimNativeRewards(lpId);
    const { partialFee } = await claimTx.paymentInfo(account);

    return fromBN(partialFee, BN_DIV_NUMERATOR_MULTIPLIER_DECIMALS);
  };

/**
 * Builds one `utility.batchAll` extrinsic containing a
 * `proofOfStake.claimNativeRewards` call for every liquidity token whose cached
 * reward amount is greater than zero. Entries with no data are skipped.
 *
 * @param api - connected API instance.
 * @param rewardsByLpId - reward queries keyed by liquidity token id.
 * @returns the batch extrinsic (possibly wrapping an empty call list).
 */
function getClaimAllRewardsExtrinsic(
  api: ApiPromise,
  rewardsByLpId: Record<string, UseQueryResult<QueryOptional<BN>>>,
) {
  const claimCalls = [];
  for (const [lpId, rewardsQuery] of Object.entries(rewardsByLpId)) {
    if (rewardsQuery.data?.gt(BN_ZERO)) {
      claimCalls.push(api.tx.proofOfStake.claimNativeRewards(lpId));
    }
  }

  return api.tx.utility.batchAll(claimCalls);
}
