import { useInjection } from 'inversify-react';
import { useCallback, useMemo, useState } from 'react';
import { Button, Form, Modal, Stack } from 'react-bootstrap';
import { useMutation, useQueryClient } from 'react-query';

import { useAddCredentials, RunButton } from '../components';
import {
   ConnectionAccessType,
   QueryRunOptions,
   QueryVersion,
   QueryLog,
   StepType,
   SystemQueryRun,
   QueryStep,
   QueryLogContext,
   DataConnection,
   logsFromQueryReturn,
   SchemaOverride,
} from '../entities';

import { QueryKey, QueryKeyType, DBMS } from '../enums';
import { QueryReturn } from '../interfaces';
import { QueryService, QueryLogService, isDesktop, DesktopQueryService } from '../services';
import { TYPES } from '../types';
import { handleError, queryIsDangerous } from '../utilities';
import { getQueryLogQueryKey, useFetchListWorkspaceConnections } from './';
import { useCurrentQuery } from '.';

import type { LocalCredentialService } from '../services/LocalCredentialService';
import type { DataConnectionService } from '../services/DataConnectionService';
import type { DataCredentialsModalProps } from '../components/DataCredentialsModal';

/**
 * Run state tracked per query: whether a run is in flight and, once a run has
 * finished, the results it produced. Every field is optional; `{}` is used to
 * represent "idle, no results".
 */
export type RunStatus = Partial<{
   isRunning: boolean;
   results: QueryReturn[][];
}>;

/**
 * Builds a checker that reports whether the Snowflake OAuth token stored locally for a
 * data connection is still valid.
 *
 * @param localCredentialService - source of the locally stored credentials
 * @returns an async `(dataConnectionId) => boolean`. It resolves `false` when nothing is
 *    stored, when no `snowflakeOAuth.expiresAt` is present, or when that expiry is not
 *    in the future.
 */
export const checkSnowflakeOAuthTokenFreshnessFactory = ({
   localCredentialService,
}: {
   localCredentialService: LocalCredentialService;
}) => {
   const isFresh = async (dataConnectionId: number) => {
      const stored = await localCredentialService.get(dataConnectionId);
      const expiry = stored?.snowflakeOAuth?.expiresAt;

      if (expiry === undefined) {
         return false;
      }

      // An unparseable expiry yields NaN, which never compares greater, so it counts as expired.
      return new Date(expiry).getTime() > Date.now();
   };

   return isFresh;
};

/**
 * Builds a function that opens the add-credentials flow for a data connection and then
 * retries the original operation.
 *
 * The promise returned by that function:
 *  - rejects with "Please add credentials to run this query" if the flow closes without
 *    success;
 *  - otherwise settles with the outcome of `runThunk`.
 *
 * @param addCredentials - opens the credentials modal (modal props minus `show`)
 * @param dataConnectionId - connection that needs credentials
 * @param runThunk - operation to retry once credentials were added successfully
 * @param workspaceId - workspace passed through to the modal
 */
const handleMissingCredentialsFactory = <T,>({
   addCredentials,
   dataConnectionId,
   runThunk,
   workspaceId,
}: {
   addCredentials: (params: Omit<DataCredentialsModalProps, 'show'>) => void;
   dataConnectionId: number;
   runThunk: () => Promise<T>;
   workspaceId: number;
}) => {
   const promptThenRetry = () =>
      new Promise<T>((resolve, reject) => {
         // The retry is awaited, so the promise returned by this close handler settles only
         // after the retried operation has finished.
         const handleClose = async (success: boolean) => {
            if (success) {
               try {
                  resolve(await runThunk());
               } catch (err) {
                  reject(err);
               }
               return;
            }

            reject(new Error('Please add credentials to run this query'));
         };

         addCredentials({ connectionId: dataConnectionId, handleClose, workspaceId });
      });

   return promptThenRetry;
};

/**
 * Selects the steps a run should execute and applies any query-text override.
 *
 * - `options.step`: only the step whose `order` equals it is kept.
 * - `options.stopAfterStep`: only steps whose `order` is `<=` it are kept.
 * - `options.queryTextOverride`: replaces `queryText` on the step matching `options.step`.
 *   It is ignored unless `options.step` is a number.
 *
 * Each returned step is a shallow copy, so top-level fields can be reassigned without
 * touching the input `steps`.
 */
const getStepsToRun = (steps: QueryStep[], options: QueryRunOptions) => {
   const { step: onlyStep, stopAfterStep, queryTextOverride } = options;

   const isSelected = (candidate: QueryStep) => {
      const matchesOnlyStep = onlyStep === undefined || candidate.order === onlyStep;
      const withinLimit = stopAfterStep === undefined || candidate.order <= stopAfterStep;
      return matchesOnlyStep && withinLimit;
   };

   return steps.filter(isSelected).map((candidate) => ({
      ...candidate,
      queryText:
         typeof queryTextOverride === 'string' &&
         typeof onlyStep === 'number' &&
         onlyStep === candidate.order
            ? queryTextOverride
            : candidate.queryText,
   }));
};

/**
 * Options accepted by the `run` function returned from `useRunQueries`.
 */
export type RunOptions = Partial<{
   /** Saved query version to run; `run` throws if it (or its id) is missing. */
   queryVersion: QueryVersion;
   /** Run only the step whose `order` equals this value. */
   step: number;
   /** Run only the steps whose `order` is less than or equal to this value. */
   stopAfterStep: number;
   /** Replacement query text for the step selected by `step` (requires `step`). */
   queryTextOverride: string;
   /** Data connection / schema to use in place of the ones configured on data-connection steps. */
   overrideSchema: SchemaOverride;
   /** Skip the dangerous-query confirmation prompt for this run. */
   suppressDangerousWarning: boolean;
}>;

/**
 * Key under which a query version's run status is stored. This is the parent query's id
 * when known, otherwise the version's own id. The key is nullish when neither is set.
 */
const getRunStatusKey = (queryVersion?: QueryVersion) => {
   const parentQueryId = queryVersion?.query?.id;
   return parentQueryId ?? queryVersion?.id;
};

/**
 * Hook for running saved query versions and tracking per-query run status.
 *
 * A call to `run` goes through these stages:
 *  1. Validation. The version must be saved and belong to a workspace, and a text override
 *     requires a step.
 *  2. Dangerous-query check. Steps flagged by `queryIsDangerous` (Python steps excluded)
 *     open a confirmation modal instead of running. This is skipped when "don't warn again"
 *     is checked or the run sets `suppressDangerousWarning`.
 *  3. Credential check. Individual-access connections without stored credentials, or with
 *     an expired Snowflake OAuth token, open the add-credentials flow; the run is retried
 *     once that flow closes successfully.
 *  4. Execution. Off desktop, or when any selected step is not eligible to run locally, the
 *     API runs the query. Otherwise the desktop query service runs it and the results are
 *     posted to `queryLogService`.
 *
 * @param exploreTabId - added to each run's options (and so to locally produced query logs)
 * @param onStatusChange - called on every run-status update of a keyed query version
 * @returns an object with:
 *    - `run`: starts a run.
 *    - `runButton`: renders a Run button for a version.
 *    - `runStatus`: looks up a version's status.
 *    - `modals`: the dangerous-query dialog, which the caller is expected to render.
 *    - `paramOverrides`.
 */
export const useRunQueries = ({
   exploreTabId,
   onStatusChange,
}: {
   exploreTabId?: number;
   onStatusChange?: (queryVersion: QueryVersion, status: RunStatus) => void;
} = {}) => {
   // Keyed by `getRunStatusKey`: the parent query id when present, else the version id.
   const [_runStatus, _setRunStatus] = useState<Record<number, RunStatus>>({});
   const { paramOverrides } = useCurrentQuery();

   const runStatus = useCallback(
      (queryVersion: QueryVersion) => _runStatus[getRunStatusKey(queryVersion) ?? -1],
      [_runStatus]
   );
   const setRunStatus = useCallback(
      (queryVersion: QueryVersion, status: RunStatus) => {
         const key = getRunStatusKey(queryVersion);
         // Versions without a key are neither tracked nor reported.
         if (!key) return;
         _setRunStatus((runStatus) => ({
            ...runStatus,
            [key]: status,
         }));
         onStatusChange?.(queryVersion, status);
      },
      [onStatusChange]
   );
   const [ignoreDangerous, setIgnoreDangerous] = useState(false);
   // Options of a run that is waiting for dangerous-query confirmation; drives the modal.
   const [promptDangerousQuery, setPromptDangerousQuery] = useState<RunOptions>();

   const queryService = useInjection<QueryService>(TYPES.queryService);
   const queryLogService = useInjection<QueryLogService>(TYPES.querylogService);
   const desktopQueryService = useInjection<DesktopQueryService>(TYPES.desktopQueryService);
   const queryClient = useQueryClient();
   // Invalidates query-log queries `count` times, one second apart (presumably so logs that
   // land shortly after the run completes are picked up; confirm).
   const refresh = (count: number) => {
      queryClient.invalidateQueries([QueryKey.QueryLog]);
      setTimeout(() => {
         count--;
         if (count > 0) {
            refresh(count);
         }
      }, 1000);
   };
   const runQuery = useMutation({
      onSuccess: () => {
         refresh(3);
      },
      mutationFn: async ({
         queryVersion,
         options,
      }: {
         options: QueryRunOptions;
         queryVersion: QueryVersion;
      }) => {
         // Parameter values always come from the current query's overrides.
         options = {
            ...options,
            params: paramOverrides,
         };

         if (!isDesktop()) {
            return queryService.runQuery(queryVersion, options);
         }

         // `getStepsToRun` returns copies, so the override below does not touch `queryVersion`.
         const stepsToRun = getStepsToRun(queryVersion.steps, options);

         if (options.overrideSchema) {
            stepsToRun.forEach((step) => {
               if (step.type === StepType.DATA_CONNECTION) {
                  step.dataConnection = options.overrideSchema?.dataConnection;
                  step.dataConnectionId = options.overrideSchema?.dataConnection.id;
                  step.schemaName = options.overrideSchema?.schemaName;
               }
            });
         }

         const [localSteps] = stepsToRun.reduce(
            (acc, step) => {
               // Even if the other conditions for local are met, if a Snowflake connection is using
               // OAuth, then we force it to run through the API.
               const isLocal =
                  step.type === StepType.DATA_CONNECTION &&
                  step.dataConnection?.connectionAccessType === ConnectionAccessType.INDIVIDUAL &&
                  (step.dataConnection?.dbms !== DBMS.Snowflake ||
                     step.dataConnection?.metadataPublic?.oauth === undefined);

               acc[isLocal ? 0 : 1].push(step);

               return acc;
            },
            [[], []] as [QueryStep[], QueryStep[]]
         );

         // Local-only?
         if (localSteps.length === stepsToRun.length) {
            const results = await desktopQueryService.runQuery(
               stepsToRun as Array<{
                  dataConnection: DataConnection;
                  queryText: string;
                  schemaName?: string | null;
               }>,
               // Parameter values: the override when present, else the parameter's default.
               Object.fromEntries(
                  (queryVersion.parameters ?? []).map((p) => [
                     p.name,
                     (options.params ?? {})[p.name] ?? p.defaultValue,
                  ])
               )
            );

            try {
               await queryLogService.post(
                  logsFromQueryReturn({
                     steps: stepsToRun,
                     queryVersion,
                     results,
                     exploreTabId: options.exploreTabId,
                  })
               );
            } catch (err) {
               // Don't fail the query if we fail to log (e.g. due to network issues)
               console.error('Error logging query', err);
            }

            return results;
         }

         // Cloud-only or Mixed
         // TODO: We used to only include credential headers if local steps are present. But the
         // headers also need to be included if there are Snowflake OAuth steps, so we've switched
         // to always including them.
         return queryService.runQuery(queryVersion, options, true);
      },
   });
   // The object returned by `useMutation` is recreated every render, whereas (in react-query
   // v3) `mutateAsync` is the observer's bound `mutate` and keeps its identity. Depending on
   // it rather than on `runQuery` lets `run` below stay memoized.
   const runQueryAsync = runQuery.mutateAsync;

   const fetchWorkspaceConnections = useFetchListWorkspaceConnections();
   const addCredentials = useAddCredentials();
   const localCredentialService = useInjection<LocalCredentialService>(
      TYPES.localCredentialService
   );
   // Memoized: a fresh closure on every render would invalidate `run`'s useCallback.
   const checkSnowflakeOAuthTokenFreshness = useMemo(
      () => checkSnowflakeOAuthTokenFreshnessFactory({ localCredentialService }),
      [localCredentialService]
   );

   const run = useCallback(
      async (runOptions: RunOptions = {}): Promise<void> => {
         const queryVersion = runOptions.queryVersion;

         if (!queryVersion?.id) {
            throw new Error('Query version must be saved before running');
         }

         if (!queryVersion.query?.workspaceId) {
            throw new Error('Query must have a workspace');
         }

         if (runOptions.queryTextOverride && !runOptions.step) {
            throw new Error('Must specify a step when overriding the query');
         }

         // Ignore re-entrant runs of a query that is already running.
         if (runStatus(queryVersion)?.isRunning) {
            return;
         }

         try {
            if (!ignoreDangerous && !runOptions.suppressDangerousWarning) {
               const hasDangerous = getStepsToRun(queryVersion.steps, runOptions).find(
                  (step) => step.type !== StepType.PYTHON && queryIsDangerous(step.queryText ?? '')
               );
               if (hasDangerous) {
                  // The modal's Run button calls `run` again with the warning suppressed.
                  setPromptDangerousQuery(runOptions);
                  return;
               }
            }

            const workspaceSchemaConnections = await fetchWorkspaceConnections({
               workspaceId: queryVersion.query.workspaceId,
               includeConnectionDetails: true,
            });

            const stepsWithMissingCredentials = (
               await Promise.all(
                  getStepsToRun(queryVersion.steps, runOptions).map(async (step) => {
                     if (step.type !== StepType.DATA_CONNECTION || !step.dataConnectionId) {
                        return [];
                     }

                     const dataConnectionId = step.dataConnectionId;

                     const connection = workspaceSchemaConnections.find(
                        (wdc) => wdc.dataConnection?.id === dataConnectionId
                     )?.dataConnection;

                     // Only individual-access connections use locally stored credentials.
                     if (
                        !connection?.id ||
                        connection.connectionAccessType !== ConnectionAccessType.INDIVIDUAL
                     ) {
                        return [];
                     }

                     if (
                        connection.dbms !== DBMS.Snowflake ||
                        connection.metadataPublic?.oauth === undefined
                     ) {
                        const hasCredentials = await localCredentialService.has(connection.id);

                        return hasCredentials ? [] : [{ ...step, dataConnectionId }];
                     }

                     // For Snowflake connections that are using OAuth, treat expired refresh tokens
                     // as missing credentials

                     const isOAuthTokenFresh = await checkSnowflakeOAuthTokenFreshness(
                        dataConnectionId
                     );

                     return isOAuthTokenFresh ? [] : [{ ...step, dataConnectionId }];
                  })
               )
            ).flat();

            if (stepsWithMissingCredentials.length > 0) {
               // Prompt for the first connection only; the retried run re-checks the rest.
               return handleMissingCredentialsFactory({
                  addCredentials,
                  dataConnectionId: stepsWithMissingCredentials[0].dataConnectionId,
                  runThunk: () => run(runOptions),
                  workspaceId: queryVersion.query.workspaceId,
               })();
            }

            setRunStatus(queryVersion, {
               isRunning: true,
            });

            const results = await runQueryAsync({
               queryVersion,
               options: {
                  exploreTabId,
                  ...runOptions,
               },
            });

            if (results !== undefined) {
               setRunStatus(queryVersion, {
                  // `null` results are stored as `undefined`.
                  results: results ?? undefined,
               });
               queryClient.invalidateQueries(getQueryLogQueryKey({ type: QueryKeyType.LIST }));
            } else {
               setRunStatus(queryVersion, {});
            }
         } catch (err) {
            handleError(err);
            setRunStatus(queryVersion, {});
         }
      },
      [
         addCredentials,
         queryClient,
         fetchWorkspaceConnections,
         runQueryAsync,
         exploreTabId,
         ignoreDangerous,
         runStatus,
         setRunStatus,
         localCredentialService,
         checkSnowflakeOAuthTokenFreshness,
      ]
   );

   const runButton = (queryVersion: QueryVersion) => (
      <RunButton
         disabled={!queryVersion?.steps?.[0]?.queryText}
         key="run"
         onClick={() => run({ queryVersion })}
         running={!!queryVersion.id && runStatus(queryVersion)?.isRunning}
      >
         Run
      </RunButton>
   );

   const modals = (
      <>
         <Modal show={!!promptDangerousQuery}>
            <Modal.Header>
               <Modal.Title className="fs-14p">Dangerous Query</Modal.Title>
            </Modal.Header>
            <Modal.Body>
               <div>
                  This query will modify the database.
                  <br />
                  Do you want to continue?
               </div>
            </Modal.Body>
            <Modal.Footer>
               <Form>
                  <Form.Check
                     checked={ignoreDangerous}
                     label="Don't warn again in this tab"
                     onChange={(event) => setIgnoreDangerous(event.target.checked)}
                     type="checkbox"
                  />
                  <Stack className="justify-content-end" direction="horizontal" gap={2}>
                     <Button
                        className={'py-1 btn-secondary'}
                        onClick={() => setPromptDangerousQuery(undefined)}
                     >
                        Cancel
                     </Button>
                     <RunButton
                        onClick={() => {
                           if (!promptDangerousQuery) return;
                           run({
                              ...promptDangerousQuery,
                              suppressDangerousWarning: true,
                           });
                           setPromptDangerousQuery(undefined);
                        }}
                        running={
                           !!promptDangerousQuery?.queryVersion?.id &&
                           runStatus(promptDangerousQuery.queryVersion)?.isRunning
                        }
                     />
                  </Stack>
               </Form>
            </Modal.Footer>
         </Modal>
      </>
   );

   return { run, runButton, runStatus, modals, paramOverrides };
};

// The `queryVersion` passed to `useRunQuery` is only used to look up the run state;
// `runOptions.queryVersion` controls which queries are actually run.
/**
 * Single-query convenience wrapper around `useRunQueries`.
 *
 * @param queryVersion - version whose run status is spread into the result. When it is
 *    undefined, the spread values are `{ isRunning: false, results: [] }`.
 * @param exploreTabId - forwarded to `useRunQueries`
 * @returns `run`, `modals` and the run-status fields of `queryVersion`. `run` throws
 *    'Bad implementation' if asked to run a version whose status key differs from that of
 *    `queryVersion`.
 */
export const useRunQuery = (
   queryVersion: QueryVersion | undefined,
   { exploreTabId }: { exploreTabId?: number } = {}
) => {
   const { run, runStatus, modals } = useRunQueries({ exploreTabId });

   const runThis = useCallback(
      (runOptions: RunOptions & { queryVersion: QueryVersion }) => {
         const requestedKey = getRunStatusKey(runOptions.queryVersion);
         const boundKey = getRunStatusKey(queryVersion);

         // Guard against using this hook to run a different query than the one it tracks.
         if (requestedKey !== boundKey) {
            throw new Error('Bad implementation');
         }

         return run(runOptions);
      },
      [queryVersion, run]
   );

   const status = queryVersion ? runStatus(queryVersion) : { isRunning: false, results: [] };

   return { run: runThis, modals, ...status };
};

/**
 * Hook for running a single "system" query against a data connection.
 *
 * Individual-access connections first need locally stored credentials; for Snowflake
 * connections configured for OAuth, they need an unexpired token instead. If these are
 * missing, the add-credentials flow is opened and the query is retried once it closes
 * successfully.
 *
 * Routing:
 *  - On desktop, individual-access connections run locally through the desktop query
 *    service. The result is logged with `QueryLogContext.SYSTEM`, and the schema is
 *    optionally refreshed.
 *  - Everything else runs through the API's `runSystemQuery`.
 *
 * @returns an object with:
 *    - `run`.
 *    - `isRunning`: set while a run is in flight. It is not reference-counted, so with
 *      overlapping runs the first one to finish clears it.
 */
export const useRunSystemQuery = () => {
   const dataConnectionService = useInjection<DataConnectionService>(TYPES.dataConnectionService);
   const queryService = useInjection<QueryService>(TYPES.queryService);
   const queryLogService = useInjection<QueryLogService>(TYPES.querylogService);
   const localCredentialService = useInjection<LocalCredentialService>(
      TYPES.localCredentialService
   );
   const desktopQueryService = useInjection<DesktopQueryService>(TYPES.desktopQueryService);
   const [isRunning, setIsRunning] = useState(false);
   const addCredentials = useAddCredentials();

   // Memoized: a fresh closure on every render would invalidate `run`'s useCallback.
   const checkSnowflakeOAuthTokenFreshness = useMemo(
      () => checkSnowflakeOAuthTokenFreshnessFactory({ localCredentialService }),
      [localCredentialService]
   );

   /**
    * Runs `query` against `dataConnection`, given either as an object or as an id.
    *
    * @throws Error if `workspaceId` is missing or the connection cannot be resolved or has
    *    no id. On desktop, it also throws the query's own error, if any. Credential-flow
    *    cancellation rejects with the message from `handleMissingCredentialsFactory`.
    */
   const run = useCallback(
      async ({
         dataConnection,
         query,
         exploreTabId,
         workspaceId,
         updateSchema,
         overrideSchema,
      }: Omit<SystemQueryRun, 'dataConnectionId'> & {
         dataConnection: DataConnection | number;
      }): Promise<QueryReturn> => {
         if (!workspaceId) throw new Error('Workspace ID is required');

         // You can pass in a DataConnection object or just the id.
         if (typeof dataConnection === 'number') {
            const result = await dataConnectionService.get(dataConnection);

            if (!result) {
               throw new Error('DataConnection ID is required');
            }

            dataConnection = result;
         }

         const dataConnectionId = dataConnection.id;
         if (!dataConnectionId) {
            throw new Error('DataConnection ID is required');
         }

         const handleMissingCredentials = handleMissingCredentialsFactory({
            addCredentials,
            dataConnectionId,
            runThunk: () =>
               run({
                  dataConnection,
                  query,
                  exploreTabId,
                  workspaceId,
                  updateSchema,
                  overrideSchema,
               }),
            workspaceId,
         });

         const isIndividual =
            dataConnection.connectionAccessType === ConnectionAccessType.INDIVIDUAL;

         // Individual connections need usable local credentials whether the query then runs
         // through the API or locally on desktop. For Snowflake connections that are using
         // OAuth, expired refresh tokens are treated as missing credentials.
         if (isIndividual) {
            const usesSnowflakeOAuth =
               dataConnection.dbms === DBMS.Snowflake &&
               dataConnection.metadataPublic?.oauth !== undefined;

            const hasUsableCredentials = usesSnowflakeOAuth
               ? await checkSnowflakeOAuthTokenFreshness(dataConnectionId)
               : await localCredentialService.has(dataConnectionId);

            if (!hasUsableCredentials) {
               return handleMissingCredentials();
            }
         }

         if (!isDesktop() || !isIndividual) {
            try {
               setIsRunning(true);
               // The second argument is true only for individual connections run off desktop.
               return await queryService.runSystemQuery(
                  {
                     dataConnectionId,
                     query,
                     exploreTabId,
                     workspaceId,
                     updateSchema,
                     overrideSchema,
                  },
                  !isDesktop() && isIndividual
               );
            } finally {
               setIsRunning(false);
            }
         }

         // -- On desktop + ConnectionAccessType.INDIVIDUAL

         try {
            setIsRunning(true);

            // If `query` has multiple statements, then we will only return the result of the first
            // statement. This behaviour is consistent with the API's runSystemQuery implementation.
            const [[result]] = await desktopQueryService.runQuery({
               dataConnection: overrideSchema?.dataConnection ?? dataConnection,
               queryText: query,
               schemaName: overrideSchema?.schemaName,
            });

            try {
               const log: QueryLog = {
                  context: QueryLogContext.SYSTEM,
                  dataConnectionId,
                  queryText: query,
                  runtime: result.runtime,
                  step: 1,
                  stepType: StepType.DATA_CONNECTION,
                  workspaceId,
                  exploreTabId,
               };
               await queryLogService.post([log]);
            } catch (err) {
               // Don't fail the query if we fail to log (e.g. due to network issues)
               console.error('Error logging system query', err);
            }

            if (result.error) {
               throw result.error instanceof Error ? result.error : new Error(result.error);
            }

            if (updateSchema?.value) {
               try {
                  await dataConnectionService.updateSchemaDesktop(
                     dataConnection,
                     updateSchema.target
                  );
               } catch (err) {
                  // Don't fail the query if we fail to update the schema
                  console.error('Error updating schema', err);
               }
            }

            return result;
         } finally {
            setIsRunning(false);
         }
      },
      [
         addCredentials,
         queryService,
         queryLogService,
         dataConnectionService,
         desktopQueryService,
         localCredentialService,
         checkSnowflakeOAuthTokenFreshness,
      ]
   );

   return { run, isRunning };
};
