import {
  Array,
  Boolean,
  Literal,
  Null,
  Number,
  Optional,
  Record as RRecord,
  Static,
  String,
  Union,
} from "runtypes";

import { getNormalEnum } from "../runtypeEnums";

/**
 * Runtime validator for the built-in chat model identifiers. These are
 * internal keys, not OpenAI model names; `OpenAiChatModelNameMap` maps each
 * one to a concrete OpenAI model id.
 *
 * Adding a literal here makes the compiler require a matching entry in both
 * `OpenAiChatModelNameMap` and `OPENAI_CHAT_MODEL_DETAILS`, since both are
 * typed as `Record<OpenAiChatModel, …>`.
 */
export const OpenAiChatModelLiteral = Union(
  Literal("GPT_4"),
  Literal("GPT_4_FUNCTIONS"),
  Literal("GPT_4_32K"),
  Literal("GPT_4_TURBO"),
  Literal("GPT_4_OMNI"),
  Literal("GPT_4_OMNI_HEX_FINETUNED"),
  Literal("GPT_4_OMNI_MINI"),
  Literal("GPT_4_TURBO_0409"),
  Literal("GPT_35_TURBO"),
  Literal("GPT_35_TURBO_FUNCTIONS"),
);
/** Static string-literal union of the identifiers accepted above. */
export type OpenAiChatModel = Static<typeof OpenAiChatModelLiteral>;
/**
 * Enum-like value object derived from the literal union, used as
 * `OpenAiChatModel.GPT_4` etc. (presumably each key maps to its own literal
 * string — see `getNormalEnum` to confirm).
 */
export const OpenAiChatModel = getNormalEnum(OpenAiChatModelLiteral);

/**
 * OpenAI model id for each built-in chat model. Typed as
 * `Record<OpenAiChatModel, string>` so the compiler flags any model without
 * an entry. Every id here is a dated snapshot rather than a floating alias
 * (presumably so the underlying model cannot change without a code change).
 */
export const OpenAiChatModelNameMap: Record<OpenAiChatModel, string> = {
  [OpenAiChatModel.GPT_4]: "gpt-4-0314",
  [OpenAiChatModel.GPT_4_FUNCTIONS]: "gpt-4-0613",
  [OpenAiChatModel.GPT_4_TURBO]: "gpt-4-1106-preview",
  [OpenAiChatModel.GPT_4_TURBO_0409]: "gpt-4-turbo-2024-04-09",
  [OpenAiChatModel.GPT_4_32K]: "gpt-4-32k-0314",
  [OpenAiChatModel.GPT_4_OMNI]: "gpt-4o-2024-05-13",
  // Fine-tuned ("ft:") model built on the gpt-4o-2024-08-06 snapshot.
  [OpenAiChatModel.GPT_4_OMNI_HEX_FINETUNED]:
    "ft:gpt-4o-2024-08-06:hex:generate-sql-v5:AdWbfL2z",
  [OpenAiChatModel.GPT_4_OMNI_MINI]: "gpt-4o-mini-2024-07-18",
  [OpenAiChatModel.GPT_35_TURBO]: "gpt-3.5-turbo-0301",
  [OpenAiChatModel.GPT_35_TURBO_FUNCTIONS]: "gpt-3.5-turbo-1106",
};

// OpenAI's documented token-counting rules don't always match what the API
// actually counts, so every budget below keeps this many tokens in reserve.
const OPENAI_TOKEN_SAFETY_MARGIN = 10;

/**
 * Limits for a model whose prompt and completion share one context window:
 * the whole window, minus the safety margin, is a single `maxTokens` budget.
 */
const sharedContextLimits = (
  contextWindow: number,
  tokensPerMessage: number,
) => ({
  maxTokens: contextWindow - OPENAI_TOKEN_SAFETY_MARGIN,
  tokensPerMessage,
});

/**
 * Limits for a model with a separate completion cap: `maxOutputTokens` is
 * reserved for the response, and what remains of the context window (minus
 * the safety margin) is the prompt budget `maxInputTokens`.
 */
const reservedOutputLimits = (
  contextWindow: number,
  maxOutputTokens: number,
  tokensPerMessage = 3,
) => ({
  maxInputTokens: contextWindow - maxOutputTokens - OPENAI_TOKEN_SAFETY_MARGIN,
  maxOutputTokens,
  tokensPerMessage,
});

/**
 * Token budgets for every built-in chat model. Older models expose one
 * combined `maxTokens` budget. Newer ones split it into `maxInputTokens` and
 * `maxOutputTokens`. `tokensPerMessage` is presumably the fixed per-message
 * overhead added when estimating prompt size (as in OpenAI's token-counting
 * cookbook) — confirm against the counting code.
 */
export const OPENAI_CHAT_MODEL_DETAILS: Record<
  OpenAiChatModel,
  | { maxTokens: number; tokensPerMessage: number }
  | {
      maxInputTokens: number;
      maxOutputTokens: number;
      tokensPerMessage: number;
    }
> = {
  [OpenAiChatModel.GPT_4]: sharedContextLimits(8192, 3),
  [OpenAiChatModel.GPT_4_FUNCTIONS]: sharedContextLimits(8192, 3),
  [OpenAiChatModel.GPT_4_TURBO]: reservedOutputLimits(128000, 4096),
  [OpenAiChatModel.GPT_4_TURBO_0409]: reservedOutputLimits(128000, 4096),
  [OpenAiChatModel.GPT_4_OMNI]: reservedOutputLimits(128000, 4096),
  [OpenAiChatModel.GPT_4_OMNI_MINI]: reservedOutputLimits(128000, 16384),
  [OpenAiChatModel.GPT_4_OMNI_HEX_FINETUNED]: reservedOutputLimits(
    128000,
    16384,
  ),
  [OpenAiChatModel.GPT_4_32K]: sharedContextLimits(32768, 3),
  [OpenAiChatModel.GPT_35_TURBO]: sharedContextLimits(4096, 4),
  // NOTE(review): OpenAI documents gpt-3.5-turbo-1106 with a larger context
  // window than 4096 tokens; this budget looks deliberately conservative —
  // confirm before changing it.
  [OpenAiChatModel.GPT_35_TURBO_FUNCTIONS]: sharedContextLimits(4096, 4),
};

/**
 * A model outside the built-in `OpenAiChatModel` list, described by its model
 * string (presumably the id passed to OpenAI) plus explicit prompt and
 * completion token limits. These mirror the `maxInputTokens` /
 * `maxOutputTokens` shape used for newer models in
 * `OPENAI_CHAT_MODEL_DETAILS`.
 */
export const CustomOpenAiModelConfig = RRecord({
  model: String,
  maxInputTokens: Number,
  maxOutputTokens: Number,
});
export type CustomOpenAiModelConfig = Static<typeof CustomOpenAiModelConfig>;

/**
 * Which model a request targets: either one of the built-in chat model
 * identifiers, or a custom model carrying its own token limits.
 */
const OpenAiModelSelection = Union(
  OpenAiChatModelLiteral,
  CustomOpenAiModelConfig,
);

/**
 * Generation settings for an OpenAI chat request. The field names appear to
 * be camelCase versions of the OpenAI chat-completion parameters
 * (`frequency_penalty`, `n`, `presence_penalty`, `stop`, `temperature`,
 * `top_p`) — confirm where they are mapped onto the outgoing request.
 */
export const OpenAiModelParams = RRecord({
  frequencyPenalty: Number,
  n: Number,
  presencePenalty: Number,
  stopWords: Array(String),
  temperature: Number,
  topP: Number,
  model: OpenAiModelSelection,
});
export type OpenAiModelParams = Static<typeof OpenAiModelParams>;

// Based on `ChatCompletionRequestMessageRoleEnum` from openai package.
/** Runtime validator for the message roles accepted here; only these three are allowed. */
export const OpenAiChatMessageRoleLiteral = Union(
  Literal("system"),
  Literal("user"),
  Literal("assistant"),
);
/** Static string-literal union of the accepted roles. */
export type OpenAiChatMessageRole = Static<typeof OpenAiChatMessageRoleLiteral>;
/**
 * Enum-like value object derived from the role union (presumably each key
 * maps to its own literal string — see `getNormalEnum` to confirm).
 */
export const OpenAiChatMessageRole = getNormalEnum(
  OpenAiChatMessageRoleLiteral,
);

/**
 * A function invocation requested by the assistant: the function's name and
 * its arguments as a string (presumably JSON-encoded, as in the OpenAI API).
 */
const OpenAiFunctionCall = RRecord({ arguments: String, name: String });

/**
 * One chat message. Based on `ChatCompletionRequestMessage` from the openai
 * package, with the function call exposed as camelCase `functionCall`.
 * `content` may be null (presumably for assistant turns that only carry a
 * function call).
 */
export const OpenAiChatMessage = RRecord({
  role: OpenAiChatMessageRoleLiteral,
  content: Union(String, Null),
  functionCall: Optional(OpenAiFunctionCall),
});
export type OpenAiChatMessage = Static<typeof OpenAiChatMessage>;

/**
 * Result of the "conformer" step. This is presumably the routine that fits a
 * chat history into a model's token budget — confirm against its
 * implementation.
 */
export const OpenAIConformerResult = RRecord({
  // The messages produced by the conformer (presumably the ones to send).
  messages: Array(OpenAiChatMessage),
  // Token total for the context, as computed by the conformer.
  totalContextTokens: Number,
  // Presumably true when `totalContextTokens` is an estimate rather than an
  // exact tokenizer count — confirm with the producer.
  usedTokenCountApproximation: Boolean,
});
export type OpenAIConformerResult = Static<typeof OpenAIConformerResult>;
