import { GENERATING_AES_ALGORITHM, TRACK_EVENTS } from "core/consts";
import { flagPrivateKey50K } from "core/model/utils/featureFlags";
import {
  Account,
  EncryptedContent,
  RCryptoKey,
  SessionKey,
  ToType,
  TrackEventFn,
} from "core/types";
import { differenceInMilliseconds } from "date-fns";
import "fast-text-encoding";
import {
  AESDecrypt,
  AESEncrypt,
  RSADecrypt,
  RSAEncrypt,
  base64ToBytes,
  bytesToBase64,
  deriveAESKey,
  exportKey,
  generateAESKey,
  generateRSAKeysJWK,
  generateSalt,
  importAESKey,
  importPrivateKey,
  importPublicKey,
} from "./index";

// Shared UTF-8 codecs, created once at module load and reused across this module.
const textEncoder = new TextEncoder();
const textDecoder = new TextDecoder("utf-8");

/**
 * Encrypt a UTF-8 string with the given AES session key, using a new IV from
 * `generateSalt()` for every call.
 *
 * @returns the base64 ciphertext (`encryptedMessage`) and the base64 IV (`iv`)
 *   required to decrypt it.
 */
export async function encryptMessage(message: string, sessionKey: CryptoKey) {
  const ivBytes = await generateSalt();
  const plaintext = textEncoder.encode(message);
  const ciphertext = await AESEncrypt(plaintext, sessionKey, ivBytes);

  return {
    encryptedMessage: bytesToBase64(ciphertext),
    iv: bytesToBase64(ivBytes),
  };
}

/**
 * Decrypt an RSA-encrypted session key with the user's private key and import
 * the resulting bytes as an AES key.
 *
 * @param encryptedSessionKey base64-encoded encrypted session key.
 * @param parsedPrivateKey the user's imported private key.
 * @param algorithm algorithm name forwarded to `importAESKey`.
 * @returns the imported session key, or `null` when decryption or import fails
 *   (the failure is logged).
 */
export async function decryptSessionKey(
  encryptedSessionKey: CryptoKey | string,
  parsedPrivateKey: CryptoKey,
  algorithm: string | null,
): Promise<CryptoKey | null> {
  try {
    const decryptedKey = await RSADecrypt(
      base64ToBytes(encryptedSessionKey),
      parsedPrivateKey,
    );
    // `return await` keeps a rejection from importAESKey inside this try/catch.
    // A bare `return promise` would let it escape, rejecting the caller's
    // promise instead of producing the documented `null`.
    return await importAESKey(decryptedKey, algorithm);
  } catch (error) {
    console.error("Could not decrypt session key", encryptedSessionKey, error);
    return null;
  }
}

// True when `bytesMessage` begins with every byte of `bytesIV`
// (an empty IV trivially matches).
function startWithIv(bytesIV: Uint8Array, bytesMessage: Uint8Array) {
  // An IV longer than the message can never be its prefix.
  if (bytesMessage.length < bytesIV.length) return false;

  return bytesIV.every((ivByte, index) => ivByte === bytesMessage[index]);
}

/**
 * Decrypt a base64 ciphertext with the AES session key and decode it as UTF-8.
 *
 * When the ciphertext begins with the IV bytes, the first 16 bytes of the
 * decrypted output are discarded. Presumably this supports payloads that
 * prepend the IV to the ciphertext (NOTE(review): confirm the payload format).
 *
 * @returns the decrypted text, or `""` if anything fails (the error is logged).
 */
export async function decryptMessage({
  iv,
  message,
  sessionKey,
}: {
  iv: string;
  message: string;
  sessionKey: CryptoKey;
}) {
  try {
    const ciphertext = base64ToBytes(message);
    const ivBytes = base64ToBytes(iv);
    const hasIvPrefix = startWithIv(ivBytes, ciphertext);

    const plaintext = new Uint8Array(
      await AESDecrypt(ciphertext, sessionKey, ivBytes),
    );

    return textDecoder.decode(hasIvPrefix ? plaintext.slice(16) : plaintext);
  } catch (e) {
    console.error("Could not decrypt message", e);
    return "";
  }
}

/**
 * Decrypt an encrypted session key using the user's private key.
 *
 * When an encrypted session key is present, the outcome is reported through
 * `trackEvent` under `TRACK_EVENTS.SESSION_KEY_DECRYPTION` with
 * `system_event: true`. The caller-supplied `context` is spread last into the
 * event.
 *
 * @returns the decrypted session key; `null` if the private key cannot be
 *   imported or the session key cannot be decrypted; `undefined` if there is
 *   no encrypted session key.
 */
export async function getSessionAccess({
  context,
  privateKey,
  sessionKey,
  trackEvent,
}: {
  context: AnyObject;
  privateKey: string;
  sessionKey: SessionKey | null;
  trackEvent: TrackEventFn;
}): Promise<CryptoKey | null | undefined> {
  // Nothing to decrypt.
  if (!sessionKey?.session_key) return undefined;

  const privateCryptoKey = await importPrivateKey(privateKey);
  if (!privateCryptoKey) {
    const failure = "bad private key";
    console.error(TRACK_EVENTS.SESSION_KEY_DECRYPTION, failure, context);
    trackEvent({
      name: TRACK_EVENTS.SESSION_KEY_DECRYPTION,
      status: "error",
      reason: failure,
      system_event: true,
      ...context,
    });
    return null;
  }

  const sessionCryptoKey = await decryptSessionKey(
    sessionKey.session_key,
    privateCryptoKey,
    sessionKey.algorithm,
  );
  if (!sessionCryptoKey) {
    const failure = "decryption error";
    console.error(TRACK_EVENTS.SESSION_KEY_DECRYPTION, failure, context);
    trackEvent({
      name: TRACK_EVENTS.SESSION_KEY_DECRYPTION,
      status: "error",
      system_event: true,
      reason: failure,
      ...context,
    });
    return null;
  }

  trackEvent({
    name: TRACK_EVENTS.SESSION_KEY_DECRYPTION,
    system_event: true,
    status: "success",
    ...context,
  });
  console.log("%c Decrypted session key", "background: green; color: white");
  return sessionCryptoKey;
}

/**
 * Encrypt raw session-key bytes with a recipient's RSA public key.
 *
 * @returns the encrypted bytes as a base64 string.
 */
export async function encryptSessionKey(
  sessionKey: BufferSource,
  publicKey: CryptoKey,
) {
  const wrappedKey = await RSAEncrypt(sessionKey, publicKey);
  return bytesToBase64(wrappedKey);
}

/**
 * Decrypt an encrypted JSON payload and parse it.
 *
 * @returns the parsed value, or `null` when `content`, `iv` or the session key
 *   is missing, or when decryption fails.
 * @throws SyntaxError if the decrypted text is non-empty but not valid JSON.
 */
export async function decryptJson(
  { content, iv }: Pick<EncryptedContent, "content" | "iv">,
  decryptedSessionKey: CryptoKey | null,
) {
  if (!content || !iv || !decryptedSessionKey) return null;

  const json = await decryptMessage({
    message: content,
    sessionKey: decryptedSessionKey,
    iv,
  });

  // decryptMessage reports (and logs) failure by returning "". Passing that to
  // JSON.parse would throw an unrelated "Unexpected end of JSON input", so
  // treat it like the missing-input case.
  if (json === "") return null;

  return JSON.parse(json);
}

/**
 * Encrypt a message with the session key when both are available.
 *
 * @returns `{ encryptedMessage, message_iv }` (both base64). When the message is
 *   empty or no session key is provided, returns `{ encryptedMessage: "" }`
 *   without an IV.
 */
export async function getEncryptedMessage(
  message: string,
  decryptedSessionKey: CryptoKey | null | undefined,
): Promise<{
  encryptedMessage: string;
  message_iv?: string;
}> {
  // An empty string is falsy, so this also covers zero-length messages.
  if (!message || !decryptedSessionKey) return { encryptedMessage: "" };

  const encrypted = await encryptMessage(message, decryptedSessionKey);
  return {
    encryptedMessage: encrypted.encryptedMessage,
    message_iv: encrypted.iv,
  };
}

/**
 * Derive an AES key from the user's password via `deriveAESKey`, using the
 * given salt and iteration count.
 *
 * @throws Error (as a rejected promise) when `salt` is not a Uint8Array.
 */
export async function generatePasswordKey({
  iterations,
  password,
  salt,
  trackEvent,
}: {
  iterations: number;
  password: string;
  salt: Uint8Array;
  trackEvent: TrackEventFn;
}) {
  // Runtime guard: callers may pass base64 strings despite the type.
  if (!(salt instanceof Uint8Array)) {
    throw new Error("salt must be of type Uint8Array");
  }

  const passwordBytes = textEncoder.encode(password);
  return deriveAESKey({ password: passwordBytes, salt, iterations, trackEvent });
}

/**
 * Create a new AES session key and encrypt it for each user.
 *
 * Each `user.public_key` may be an imported CryptoKey or a serialized public
 * key string. Strings are imported with `importPublicKey` first. The parameter
 * type now admits strings, matching the runtime branch that already handled
 * them.
 *
 * @returns one entry per user with the base64 RSA-encrypted session key, the
 *   `GENERATING_AES_ALGORITHM` name, and the user's id.
 */
export async function generateSessionKeys(
  users: Array<{ id: number; public_key: CryptoKey | string }>,
): Promise<
  {
    account_id: number;
    algorithm: string;
    session_key: string;
  }[]
> {
  const sessionKey = await generateAESKey(GENERATING_AES_ALGORITHM);
  const sessionKeyBytes = await exportKey(sessionKey);

  return Promise.all(
    users.map(async (user) => {
      const publicKey = user.public_key;
      const recipientKey =
        typeof publicKey === "string"
          ? await importPublicKey(publicKey)
          : publicKey;

      return {
        session_key: await encryptSessionKey(sessionKeyBytes, recipientKey),
        algorithm: GENERATING_AES_ALGORITHM,
        account_id: user.id,
      };
    }),
  );
}

/**
 * Encrypt an existing AES session key for one additional user.
 *
 * `user.public_key` may be an imported CryptoKey or a serialized public key
 * string, which is imported with `importPublicKey` first. The parameter type
 * now admits strings, matching the runtime branch that already handled them.
 *
 * NOTE(review): unlike generateSessionKeys, the result has no `algorithm`
 * field. Confirm callers do not need it.
 *
 * @returns `{ session_key, account_id }`, where `session_key` is the base64
 *   RSA-encrypted session key.
 */
export async function generateExtraSessionKey(
  sessionKey: CryptoKey,
  user: { id: number; public_key: CryptoKey | string },
) {
  const publicKey = user.public_key;
  const sessionKeyBytes = await exportKey(sessionKey);
  const recipientKey =
    typeof publicKey === "string" ? await importPublicKey(publicKey) : publicKey;

  return {
    session_key: await encryptSessionKey(sessionKeyBytes, recipientKey),
    account_id: user.id,
  };
}

// Returns false when any of the private exponent fields d, dp, dq or qi is a
// string starting with "0", and true otherwise (missing fields pass).
//
// NOTE(review): assuming generateRSAKeysJWK yields standard JWK output, these
// fields are base64url-encoded (RFC 7518 §6.3.2). There, a leading 0x00 octet
// encodes as "A" followed by A–P. A leading "0" character corresponds to a
// first octet of 0xD0–0xD3. Confirm this check detects what is intended.
function checkKey(privateKey: JsonWebKey) {
  const exponentFields = [
    privateKey.d,
    privateKey.dp,
    privateKey.dq,
    privateKey.qi,
  ];
  return !exponentFields.some((value) => value?.startsWith("0"));
}

// There are some inconsistencies between browsers on leading0.
// The fix is to retry generating the keys until they are valid.
// https://bugs.chromium.org/p/chromium/issues/detail?id=383998
/**
 * Generate an RSA JWK key pair, regenerating until `checkKey` accepts the
 * private key or `maxAttempts` pairs have been generated.
 *
 * @param maxAttempts total number of key pairs to try (default 4, minimum 1).
 * @returns the first accepted pair, or the last generated pair (after logging
 *   an error) if none was accepted.
 */
export async function generateRSAKeysJWKWithoutLeading0(maxAttempts = 4) {
  const attempts = Math.max(1, maxAttempts);

  let keys = await generateRSAKeysJWK();
  for (let attempt = 1; attempt < attempts; attempt++) {
    if (checkKey(keys.privateKey)) return keys;
    keys = await generateRSAKeysJWK();
  }
  if (checkKey(keys.privateKey)) return keys;

  console.error(`Private key with leading 0 after ${attempts} tries`);
  return keys;
}

/**
 * Encrypt a serialized private key with an AES key derived from the password.
 *
 * @param iterations iteration count passed to the password-key derivation.
 * @param privateIv IV used for the AES encryption.
 * @returns the base64-encoded ciphertext.
 */
export async function encryptPrivateKeyIterationsString({
  iterations,
  password,
  privateIv,
  privateKey,
  salt,
  trackEvent,
}: {
  iterations: number;
  password: string;
  privateIv: Uint8Array;
  privateKey: string;
  salt: Uint8Array;
  trackEvent: TrackEventFn;
}) {
  const derivedKey = await generatePasswordKey({
    password,
    salt,
    iterations,
    trackEvent,
  });
  const plaintext = textEncoder.encode(privateKey);
  const ciphertext = await AESEncrypt(plaintext, derivedKey, privateIv);
  return bytesToBase64(ciphertext);
}

/**
 * Encrypt a JWK private key with an AES key derived from the password.
 *
 * The key is serialized with `JSON.stringify` and handed to
 * `encryptPrivateKeyIterationsString`, so the derive-and-encrypt pipeline
 * exists in one place.
 *
 * @returns the base64-encoded ciphertext.
 */
export async function encryptPrivateKeyIterations({
  iterations,
  password,
  privateIv,
  privateKey,
  salt,
  trackEvent,
}: {
  iterations: number;
  password: string;
  privateIv: Uint8Array;
  privateKey: JsonWebKey;
  salt: Uint8Array;
  trackEvent: TrackEventFn;
}) {
  return encryptPrivateKeyIterationsString({
    iterations,
    password,
    privateIv,
    privateKey: JSON.stringify(privateKey),
    salt,
    trackEvent,
  });
}

/**
 * Decrypt a base64-encoded encrypted private key with an AES key derived from
 * the password, and decode the result as UTF-8.
 *
 * @param iterations must match the count used when the key was encrypted.
 * @returns the decrypted private-key string.
 */
export async function decryptPrivateKeyIterations({
  iterations,
  password,
  privateIv,
  privateKey,
  salt,
  trackEvent,
}: {
  iterations: number;
  password: string;
  privateIv: Uint8Array;
  privateKey: string;
  salt: Uint8Array;
  trackEvent: TrackEventFn;
}) {
  const derivedKey = await generatePasswordKey({
    password,
    salt,
    iterations,
    trackEvent,
  });
  const ciphertext = base64ToBytes(privateKey);
  const plaintext = await AESDecrypt(ciphertext, derivedKey, privateIv);
  return textDecoder.decode(plaintext);
}

/**
 * Generate a fresh RSA key pair plus salt and IV, and encrypt the private key
 * with the password.
 *
 * The private key is encrypted once: with 50000 iterations when
 * `flagPrivateKey50K` is on (stored in `private_key_fiftyk`), otherwise with
 * 25000 (stored in `private_key_twenty_fivek`). The other field is undefined.
 */
export async function generateBundledKeys({
  password,
  trackEvent,
}: {
  password: string;
  trackEvent: TrackEventFn;
}) {
  const salt = await generateSalt();
  const privateIv = await generateSalt();
  const { privateKey, publicKey } = await generateRSAKeysJWKWithoutLeading0();

  const use50K = Boolean(flagPrivateKey50K);
  const encryptedPrivateKey = await encryptPrivateKeyIterations({
    password,
    salt,
    privateIv,
    privateKey,
    iterations: use50K ? 50000 : 25000,
    trackEvent,
  });

  return {
    salt: bytesToBase64(salt),
    public_key: JSON.stringify(publicKey),
    private_iv: bytesToBase64(privateIv),
    private_key_twenty_fivek: use50K ? undefined : encryptedPrivateKey,
    private_key_fiftyk: use50K ? encryptedPrivateKey : undefined,
  };
}

/**
 * Decrypt the account's private key from whichever encrypted copy is available.
 *
 * Order of preference:
 *   1. 50k iterations, when `flagPrivateKey50K` is on;
 *   2. 25k iterations;
 *   3. 5k iterations;
 *   4. 50k iterations as a last resort when the flag is off. Accounts created
 *      while the flag was on only have a 50k copy, and would otherwise be left
 *      without a usable key.
 *
 * @returns `[decryptedKey, label]`, where `label` is "50k", "25k" or "5k", or
 *   `["", undefined]` when no encrypted copy is present.
 */
export async function decryptPrivateKey({
  cryptoKey: {
    private_iv,
    private_key_fiftyk,
    private_key_fivek,
    private_key_twenty_fivek,
    salt,
  },
  password,
  trackEvent,
}: {
  cryptoKey: RCryptoKey;
  password: string;
  trackEvent: TrackEventFn;
}): Promise<[string, string | undefined]> {
  // Decoding of salt/IV happens only when a copy is actually decrypted.
  const decryptWith = (encrypted: string, iterations: number) =>
    decryptPrivateKeyIterations({
      password,
      salt: base64ToBytes(salt),
      privateIv: base64ToBytes(private_iv),
      privateKey: encrypted,
      iterations,
      trackEvent,
    });

  // Use the highest number of iterations first.
  if (flagPrivateKey50K && private_key_fiftyk) {
    return [await decryptWith(private_key_fiftyk, 50000), "50k"];
  }

  if (private_key_twenty_fivek) {
    return [await decryptWith(private_key_twenty_fivek, 25000), "25k"];
  }

  if (private_key_fivek) {
    return [await decryptWith(private_key_fivek, 5000), "5k"];
  }

  // Flag off and only a 50k copy exists: still usable, so do not lock the user out.
  if (private_key_fiftyk) {
    return [await decryptWith(private_key_fiftyk, 50000), "50k"];
  }

  return ["", undefined];
}

/**
 * Encrypt file content with the AES session key, using a new IV from
 * `generateSalt()`.
 *
 * @returns the base64 ciphertext (`encryptedFile`) and base64 IV (`iv`).
 */
export async function encryptFile(fileContent: ToType, sessionKey: CryptoKey) {
  const ivBytes = await generateSalt();
  const ciphertext = await AESEncrypt(fileContent, sessionKey, ivBytes);
  const encryptedFile = bytesToBase64(ciphertext);

  return { encryptedFile, iv: bytesToBase64(ivBytes) };
}

/**
 * Decrypt a base64-encoded file ciphertext with the AES session key and the
 * base64-encoded IV.
 *
 * @returns whatever `AESDecrypt` produces for the decoded inputs.
 */
export async function decryptFile(
  message: string,
  sessionKey: CryptoKey,
  iv: string,
) {
  const ciphertext = base64ToBytes(message);
  const ivBytes = base64ToBytes(iv);
  return AESDecrypt(ciphertext, sessionKey, ivBytes);
}

// Generate and encrypt a new key bundle, then upload it together with
// `last_password_change` via `api.crypto.postKeys`. Resolves to that call's result.
const generateKeyPair = async (
  accountId: number,
  password: string,
  last_password_change: number | null,
  dependencies: {
    api: ToType;
    token: string;
    trackEvent: TrackEventFn;
  },
) => {
  const bundle = await generateBundledKeys({
    password,
    trackEvent: dependencies.trackEvent,
  });

  const { api, token } = dependencies;
  const payload = { ...bundle, last_password_change };
  return api.crypto.postKeys(accountId, payload, token);
};

/**
 * Re-encrypt an already-decrypted private key with a different iteration count,
 * reusing the account's salt and IV, and upload the result.
 *
 * Only 5000, 25000 and 50000 iterations have a storage field. For any other
 * count the key is still encrypted and tracked, but nothing is uploaded.
 *
 * The upload is intentionally not awaited, so a slow or failing backfill never
 * delays or fails private-key retrieval. Its rejection is caught and logged
 * rather than left as an unhandled promise rejection.
 */
async function backfillEncryptedPrivateKey({
  accountId,
  cryptoKeys,
  dependencies,
  iterations,
  password,
  privateKey,
}: {
  accountId: number;
  cryptoKeys: RCryptoKey;
  dependencies: {
    api: ToType;
    token: string;
    trackEvent: TrackEventFn;
  };
  iterations: number;
  password: string;
  privateKey: string;
}) {
  const { trackEvent } = dependencies;
  const startTime = new Date();
  const salt = base64ToBytes(cryptoKeys.salt);
  const privateIv = base64ToBytes(cryptoKeys.private_iv);

  const encryptedPrivateKey = await encryptPrivateKeyIterationsString({
    password,
    salt,
    privateIv,
    privateKey,
    iterations,
    trackEvent,
  });

  trackEvent({
    name: `Private key ${iterations / 1000}k generation`,
    status: "success",
    account_id: accountId,
    duration: differenceInMilliseconds(new Date(), startTime),
  });

  const field =
    iterations === 5000
      ? "private_key_fivek"
      : iterations === 25000
        ? "private_key_twenty_fivek"
        : iterations === 50000
          ? "private_key_fiftyk"
          : undefined;
  if (!field) return;

  // Fire-and-forget on purpose; see the doc comment above.
  void Promise.resolve(
    dependencies.api.crypto.updateEncryptedPrivateKey(
      accountId,
      { [field]: encryptedPrivateKey },
      dependencies.token,
    ),
  ).catch((err: unknown) => {
    console.error(
      `Could not upload ${iterations / 1000}k encrypted private key`,
      err,
    );
  });
}

/**
 * Get the account's decrypted private key, generating and uploading a new key
 * pair first when the account has none.
 *
 * Also handles the migration to a higher number of iterations. After a
 * successful decryption, if the encrypted copy for the preferred iteration
 * count is missing, the key is re-encrypted and uploaded via
 * `backfillEncryptedPrivateKey`. The preferred count is 50k when
 * `flagPrivateKey50K` is on and 25k otherwise.
 *
 * @returns `[privateKey, { generated, decrypted }]`:
 *   - `privateKey`: the decrypted private-key string; `""` when no encrypted
 *     copy was selected by `decryptPrivateKey`; `null` when generation or
 *     decryption failed (the failure is tracked).
 *   - `generated`: label of the iteration count newly encrypted in this call
 *     (fresh key pair or backfill), if any.
 *   - `decrypted`: label of the iteration count used for decryption, if any.
 */
export async function getPrivateKey(
  account: Account,
  password: string,
  dependencies: {
    api: ToType;
    token: string;
    trackEvent: TrackEventFn;
  },
): Promise<[string | null, { decrypted?: string; generated?: string }]> {
  const { trackEvent } = dependencies;
  let generated: string | undefined;
  let decrypted: string | undefined;
  const accountId = account.id;
  let cryptoKeys = account.crypto_keys;
  if (!cryptoKeys) {
    // generateBundledKeys encrypts the new private key at 50k iterations when
    // the flag is on and at 25k otherwise; report the one actually produced.
    generated = flagPrivateKey50K ? "50k" : "25k";

    try {
      cryptoKeys = await generateKeyPair(
        accountId,
        password,
        null,
        dependencies,
      );

      trackEvent({
        name: TRACK_EVENTS.GENERATING_PUBLIC_PRIVATE_KEYS,
        status: "success",
        account_id: accountId,
      });
    } catch (err) {
      trackEvent({
        name: TRACK_EVENTS.GENERATING_PUBLIC_PRIVATE_KEYS,
        status: "error",
        reason: err,
        account_id: accountId,
      });

      return [null, { generated, decrypted }];
    }
  }

  try {
    const decryptStartTime = new Date();

    const [privateKey, iterations] = await decryptPrivateKey({
      password,
      cryptoKey: cryptoKeys as RCryptoKey,
      trackEvent,
    });

    decrypted = iterations;
    const decryptionDuration = differenceInMilliseconds(
      new Date(),
      decryptStartTime,
    );

    // Migration: make sure a copy exists for the preferred iteration count.
    if (privateKey) {
      if (flagPrivateKey50K) {
        if (cryptoKeys && !cryptoKeys.private_key_fiftyk) {
          generated = "50k";

          await backfillEncryptedPrivateKey({
            cryptoKeys,
            accountId,
            privateKey,
            password,
            iterations: 50000,
            dependencies,
          });
        }
      } else {
        if (cryptoKeys && !cryptoKeys.private_key_twenty_fivek) {
          generated = "25k";

          await backfillEncryptedPrivateKey({
            cryptoKeys,
            accountId,
            privateKey,
            password,
            iterations: 25000,
            dependencies,
          });
        }
      }
    }

    if (privateKey) {
      // A parse failure throws and is tracked as a decryption error below.
      const extract = JSON.parse(privateKey);

      if (!extract?.d) {
        trackEvent({
          name: TRACK_EVENTS.PRIVATE_KEY_COULD_NOT_PARSE,
          account_id: accountId,
        });
      } else if (!checkKey(extract)) {
        trackEvent({
          name: TRACK_EVENTS.PRIVATE_KEY_D_LEADING_0,
          account_id: accountId,
        });
      }

      // test key is valid (e.g. no leading zeros)
      await importPrivateKey(privateKey);

      trackEvent({
        name: TRACK_EVENTS.PRIVATE_KEY_DECRYPTION,
        status: "success",
        account_id: accountId,
        duration: decryptionDuration,
      });
    }

    return [privateKey, { generated, decrypted }];
  } catch (err) {
    trackEvent({
      name: TRACK_EVENTS.PRIVATE_KEY_DECRYPTION,
      status: "error",
      account_id: accountId,
      reason: err,
    });

    console.error("Could not decrypt private key", err);
  }

  return [null, { generated, decrypted }];
}
