import { DBMS_ANALYZE_SUPPORTED, DBMS_EXPLAIN_SUPPORTED, RunaMode } from '@runql/util';
import { useInjection } from 'inversify-react';
import React, { createContext, useCallback, useState } from 'react';
import { useMutation } from 'react-query';
import { DataConnection, StepType } from '../entities';
import { AskRunaRequest, AskRunaResponse, AskService, isDesktop, RunaStep } from '../services';
import { TYPES } from '../types';
import { buildExplainQuery, extractExplainResult, formatQuery } from '../utilities';
import { useRunSystemQuery } from './run';

/** Resolves the `AskService` instance registered under `TYPES.askService` in the injection container. */
function useAskService(): AskService {
   const service = useInjection<AskService>(TYPES.askService);
   return service;
}

/** Arguments accepted by the `askRuna` function exposed through the context. */
type AskRunaParams = {
   /**
    * `true` resets the stored conversation before the request is made; an explicit
    * `false` keeps the existing history even outside `RunaMode.Runa`.
    */
   clearConversation?: boolean;
   /** Connections keyed by id; looked up via a step's `dataConnectionId` when execution plans are gathered. */
   dataConnections?: Record<number, DataConnection>;
   /** When set, the steps of the response are discarded before it is stored. */
   excludeQuery?: boolean;
   /** Mode of the request; also decides whether execution plans are collected on desktop. */
   mode: RunaMode;
   /** Prompt forwarded to the service as part of the request. */
   prompt?: string;
   /** Steps forwarded to the service as part of the request. */
   steps: RunaStep[];
   /** Workspace id handed to the system query runner when execution plans are fetched. */
   workspaceId?: number;
};

/** Shape of the value shared through the Ask Runa context. */
type AskRunaContextType = {
   /** Sends a request to the ask service; resolves to `undefined` when the service yields no response. */
   askRuna: (params: AskRunaParams) => Promise<AskRunaResponse[] | undefined>;
   /** Empties the stored conversation. */
   clearConversation: () => void;
   /** Responses accumulated so far. */
   conversation: AskRunaResponse[];
   /** Whether a request to the ask service is currently in flight. */
   isLoading: boolean;
};

/** Inert fallback used when a consumer is rendered outside of `AskRunaProvider`. */
const defaultAskRunaContext: AskRunaContextType = {
   conversation: [],
   isLoading: false,
   askRuna: async () => undefined,
   clearConversation: () => undefined,
};

/** Context carrying the Ask Runa conversation state and its actions. */
const AskRunaContext = createContext<AskRunaContextType>(defaultAskRunaContext);

/** Arguments of `askRuna`, derived from the context type so the two declarations cannot drift apart. */
type AskRunaArgs = Parameters<AskRunaContextType['askRuna']>[0];

/**
 * Provides the Ask Runa conversation state and the `askRuna` action to its descendants.
 *
 * `askRuna` optionally gathers execution plans (desktop only), sends the request to the
 * ask service, formats the queries of the returned steps and stores the response in the
 * conversation.
 *
 * @param children - React subtree that gets access to the context.
 */
export function AskRunaProvider({ children }: { children: React.ReactNode }) {
   const askService = useAskService();
   const askRunaRequest = useMutation({
      mutationFn: async (req: AskRunaRequest) => {
         return askService.askRuna(req);
      },
   });
   const [conversation, setConversation] = useState<AskRunaResponse[]>([]);
   const { run } = useRunSystemQuery();

   const askRuna = useCallback(
      /**
       * Sends a request to the ask service.
       *
       * @returns The conversation including the new response, or `undefined` when the
       *          service yields no response.
       */
      async ({
         clearConversation,
         dataConnections,
         excludeQuery,
         mode,
         prompt,
         steps,
         workspaceId,
      }: AskRunaArgs) => {
         // An explicit clear always wins. Otherwise history is kept in Runa mode or when the
         // caller explicitly passed `clearConversation: false`. The history is resolved once,
         // up front, because `conversation` is a snapshot of this render and does not reflect
         // the `setConversation([])` below.
         const keepHistory =
            !clearConversation && (mode === RunaMode.Runa || clearConversation === false);
         const history = keepHistory ? conversation : [];
         if (clearConversation) {
            setConversation([]);
         }

         // If user is on desktop and the mode requires explain plans, fetch them before the LLM call.
         const needsExecPlans =
            mode === RunaMode.Optimize ||
            mode === RunaMode.CompareExecPlans ||
            mode === RunaMode.AnalyzeExecPlans;
         if (isDesktop() && needsExecPlans) {
            const analyze = mode === RunaMode.AnalyzeExecPlans;

            for (const step of steps) {
               if (step.stepType !== StepType.DATA_CONNECTION) continue;

               const connection =
                  step.dataConnectionId == null
                     ? undefined
                     : dataConnections?.[step.dataConnectionId];
               if (!connection) {
                  console.warn('Missing connection');
                  continue;
               }
               const dbms = connection.dbms;
               if (!dbms) continue;

               // Runs the explain query for `query` on this step's connection and extracts the plan.
               const explain = async (query: string) => {
                  const result = await run({
                     dataConnection: connection,
                     query: buildExplainQuery(query, dbms, analyze),
                     workspaceId,
                     overrideSchema: {
                        dataConnection: connection,
                        schemaName: step.defaultSchema,
                     },
                  });
                  return extractExplainResult(dbms, result);
               };

               const { query, suggestedQuery } = step;
               const tasks: Promise<void>[] = [];

               // NOTE(review): the current query is gated on DBMS_EXPLAIN_SUPPORTED even in
               // AnalyzeExecPlans mode, while the suggested query below is gated on
               // DBMS_ANALYZE_SUPPORTED — confirm this asymmetry is intended.
               if (
                  query &&
                  DBMS_EXPLAIN_SUPPORTED.includes(dbms) &&
                  (mode === RunaMode.Optimize || analyze)
               ) {
                  tasks.push(
                     explain(query).then((plan) => {
                        step.execPlan = plan;
                     })
                  );
               }

               if (
                  suggestedQuery &&
                  ((mode === RunaMode.CompareExecPlans && DBMS_EXPLAIN_SUPPORTED.includes(dbms)) ||
                     (analyze && DBMS_ANALYZE_SUPPORTED.includes(dbms)))
               ) {
                  tasks.push(
                     explain(suggestedQuery).then((plan) => {
                        step.suggestedExecPlan = plan;
                     })
                  );
               }

               // Do not fail the request if explain plans fail. `allSettled` waits for every
               // task, so a plan that succeeds is attached to the step before the steps are
               // sent, even when its sibling query fails.
               const outcomes = await Promise.allSettled(tasks);
               for (const outcome of outcomes) {
                  if (outcome.status === 'rejected') {
                     console.warn('Failed to fetch execution plan', outcome.reason);
                  }
               }
            }
         }

         const response = await askRunaRequest.mutateAsync({
            client: isDesktop() ? 'desktop' : 'web',
            mode,
            prompt,
            previous: history,
            steps,
         });
         if (!response) return undefined;

         if (excludeQuery) {
            response.steps = [];
         } else if (response.steps?.length) {
            for (const responseStep of response.steps) {
               try {
                  responseStep.currentQuery = formatQuery(responseStep.currentQuery || '');
               } catch {
                  // Formatting is cosmetic; keep the query as returned by the service.
               }
            }
         }

         // Build the next conversation from the history that was actually sent, and return
         // it so that callers see the response they just requested.
         const nextConversation = [...history, response];
         setConversation(nextConversation);
         return nextConversation;
      },
      [askRunaRequest, conversation, run]
   );

   const clearConversation = useCallback(() => {
      setConversation([]);
   }, []);

   return (
      <AskRunaContext.Provider
         value={{
            askRuna,
            clearConversation,
            conversation,
            isLoading: askRunaRequest.isLoading,
         }}
      >
         {children}
      </AskRunaContext.Provider>
   );
}

/**
 * Reads the Ask Runa context.
 *
 * @returns The value supplied by the nearest `AskRunaProvider`, or the inert default
 *          context value when no provider is mounted above the caller.
 */
export function useAskRunaContext(): AskRunaContextType {
   const context = React.useContext(AskRunaContext);
   return context;
}
