# eagerpy.*

from typing import overload, Sequence, Callable, Tuple, Any, Optional, cast, Union
from typing_extensions import Literal

from .types import Axes, AxisAxes, Shape, ShapeOrScalar

from .tensor import Tensor
from .tensor import TensorType
from .tensor import TensorOrScalar

# ``newaxis`` is simply ``None`` (presumably mirroring ``numpy.newaxis`` for
# use in indexing expressions -- confirm against callers).
newaxis = None
# Floating-point positive infinity.
inf = float("inf")
# Floating-point not-a-number.
nan = float("nan")


def clip(t: TensorType, min_: float, max_: float) -> TensorType:
    """Function-style spelling of ``t.clip(min_, max_)``.

    Both bounds are forwarded positionally to the ``clip`` method of ``t``.
    """
    clipped = t.clip(min_, max_)
    return clipped


def abs(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.abs()``."""
    result = t.abs()
    return result


def sign(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.sign()``."""
    result = t.sign()
    return result


def sqrt(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.sqrt()``."""
    result = t.sqrt()
    return result


def square(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.square()``."""
    result = t.square()
    return result


def pow(t: TensorType, exponent: TensorOrScalar) -> TensorType:
    """Function-style spelling of ``t.pow(exponent)``.

    ``exponent`` is forwarded positionally.
    """
    powered = t.pow(exponent)
    return powered


def sin(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.sin()``."""
    result = t.sin()
    return result


def cos(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.cos()``."""
    result = t.cos()
    return result


def tan(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.tan()``."""
    result = t.tan()
    return result


def arcsin(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.arcsin()``."""
    result = t.arcsin()
    return result


def arccos(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.arccos()``."""
    result = t.arccos()
    return result


def arctan(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.arctan()``."""
    result = t.arctan()
    return result


def sinh(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.sinh()``."""
    result = t.sinh()
    return result


def cosh(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.cosh()``."""
    result = t.cosh()
    return result


def tanh(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.tanh()``."""
    result = t.tanh()
    return result


def arcsinh(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.arcsinh()``."""
    result = t.arcsinh()
    return result


def arccosh(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.arccosh()``."""
    result = t.arccosh()
    return result


def arctanh(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.arctanh()``."""
    result = t.arctanh()
    return result


def sum(
    t: TensorType,
    axis: Optional[AxisAxes] = None,
    keepdims: bool = False,
) -> TensorType:
    """Delegate to the ``sum`` method of ``t``.

    ``axis`` and ``keepdims`` are forwarded as keyword arguments.
    """
    reduced = t.sum(axis=axis, keepdims=keepdims)
    return reduced


def prod(
    t: TensorType,
    axis: Optional[AxisAxes] = None,
    keepdims: bool = False,
) -> TensorType:
    """Delegate to the ``prod`` method of ``t``.

    ``axis`` and ``keepdims`` are forwarded as keyword arguments.
    """
    reduced = t.prod(axis=axis, keepdims=keepdims)
    return reduced


def mean(
    t: TensorType,
    axis: Optional[AxisAxes] = None,
    keepdims: bool = False,
) -> TensorType:
    """Delegate to the ``mean`` method of ``t``.

    ``axis`` and ``keepdims`` are forwarded as keyword arguments.
    """
    reduced = t.mean(axis=axis, keepdims=keepdims)
    return reduced


def min(
    t: TensorType,
    axis: Optional[AxisAxes] = None,
    keepdims: bool = False,
) -> TensorType:
    """Delegate to the ``min`` method of ``t``.

    ``axis`` and ``keepdims`` are forwarded as keyword arguments.
    """
    reduced = t.min(axis=axis, keepdims=keepdims)
    return reduced


def max(
    t: TensorType,
    axis: Optional[AxisAxes] = None,
    keepdims: bool = False,
) -> TensorType:
    """Delegate to the ``max`` method of ``t``.

    ``axis`` and ``keepdims`` are forwarded as keyword arguments.
    """
    reduced = t.max(axis=axis, keepdims=keepdims)
    return reduced


@overload
def minimum(x: TensorType, y: TensorOrScalar) -> TensorType:
    ...


@overload
def minimum(x: TensorOrScalar, y: TensorType) -> TensorType:
    ...


def minimum(x: TensorOrScalar, y: TensorOrScalar) -> Tensor:
    """Dispatch ``minimum`` to whichever argument is a ``Tensor``.

    When ``x`` is a ``Tensor`` this returns ``x.minimum(y)``. Otherwise ``y``
    is treated as the tensor and ``y.minimum(x)`` is returned; the ``cast``
    only informs the type checker and performs no runtime check.
    """
    if isinstance(x, Tensor):
        return x.minimum(y)
    # x is a scalar here, so y is the operand that provides the method.
    other = cast(Tensor, y)
    return other.minimum(x)


@overload
def maximum(x: TensorType, y: TensorOrScalar) -> TensorType:
    ...


@overload
def maximum(x: TensorOrScalar, y: TensorType) -> TensorType:
    ...


def maximum(x: TensorOrScalar, y: TensorOrScalar) -> Tensor:
    """Dispatch ``maximum`` to whichever argument is a ``Tensor``.

    When ``x`` is a ``Tensor`` this returns ``x.maximum(y)``. Otherwise ``y``
    is treated as the tensor and ``y.maximum(x)`` is returned; the ``cast``
    only informs the type checker and performs no runtime check.
    """
    if isinstance(x, Tensor):
        return x.maximum(y)
    # x is a scalar here, so y is the operand that provides the method.
    other = cast(Tensor, y)
    return other.maximum(x)


def argmin(t: TensorType, axis: Optional[int] = None) -> TensorType:
    """Delegate to ``t.argmin``, forwarding ``axis`` by keyword."""
    indices = t.argmin(axis=axis)
    return indices


def argmax(t: TensorType, axis: Optional[int] = None) -> TensorType:
    """Delegate to ``t.argmax``, forwarding ``axis`` by keyword."""
    indices = t.argmax(axis=axis)
    return indices


def argsort(t: TensorType, axis: int = -1) -> TensorType:
    """Delegate to ``t.argsort``, forwarding ``axis`` (default ``-1``) by keyword."""
    order = t.argsort(axis=axis)
    return order


def sort(t: TensorType, axis: int = -1) -> TensorType:
    """Delegate to ``t.sort``, forwarding ``axis`` (default ``-1``) by keyword."""
    ordered = t.sort(axis=axis)
    return ordered


def topk(t: TensorType, k: int, sorted: bool = True) -> Tuple[TensorType, TensorType]:
    """Delegate to ``t.topk``.

    ``k`` is forwarded positionally and ``sorted`` by keyword; the pair
    returned by the method is passed through unchanged.
    """
    pair = t.topk(k, sorted=sorted)
    return pair


def uniform(
    t: TensorType,
    shape: ShapeOrScalar,
    low: float = 0.0,
    high: float = 1.0,
) -> TensorType:
    """Delegate to ``t.uniform``.

    ``shape`` is forwarded positionally; ``low`` and ``high`` by keyword.
    """
    sample = t.uniform(shape, low=low, high=high)
    return sample


def normal(
    t: TensorType,
    shape: ShapeOrScalar,
    mean: float = 0.0,
    stddev: float = 1.0,
) -> TensorType:
    """Delegate to ``t.normal``.

    ``shape`` is forwarded positionally; ``mean`` and ``stddev`` by keyword.
    """
    sample = t.normal(shape, mean=mean, stddev=stddev)
    return sample


def ones(t: TensorType, shape: ShapeOrScalar) -> TensorType:
    """Function-style spelling of ``t.ones(shape)``."""
    created = t.ones(shape)
    return created


def zeros(t: TensorType, shape: ShapeOrScalar) -> TensorType:
    """Function-style spelling of ``t.zeros(shape)``."""
    created = t.zeros(shape)
    return created


def ones_like(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.ones_like()``."""
    created = t.ones_like()
    return created


def zeros_like(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.zeros_like()``."""
    created = t.zeros_like()
    return created


def full_like(t: TensorType, fill_value: float) -> TensorType:
    """Function-style spelling of ``t.full_like(fill_value)``."""
    created = t.full_like(fill_value)
    return created


def onehot_like(t: TensorType, indices: TensorType, *, value: float = 1) -> TensorType:
    """Delegate to ``t.onehot_like``.

    ``indices`` is forwarded positionally and the keyword-only ``value``
    (default ``1``) by keyword.
    """
    encoded = t.onehot_like(indices, value=value)
    return encoded


def from_numpy(t: TensorType, a: Any) -> TensorType:
    """Function-style spelling of ``t.from_numpy(a)``."""
    converted = t.from_numpy(a)
    return converted


def concatenate(tensors: Sequence[TensorType], axis: int = 0) -> TensorType:
    """Concatenate ``tensors`` via the first element's ``_concatenate`` method.

    The whole sequence (including the first element) is forwarded
    positionally and ``axis`` by keyword. Indexing ``tensors[0]`` means an
    empty sequence raises whatever error that indexing produces.
    """
    return tensors[0]._concatenate(tensors, axis=axis)


def transpose(t: TensorType, axes: Optional[Axes] = None) -> TensorType:
    """Delegate to ``t.transpose``, forwarding ``axes`` by keyword."""
    permuted = t.transpose(axes=axes)
    return permuted


@overload
def logical_and(x: TensorType, y: TensorOrScalar) -> TensorType:
    ...


@overload
def logical_and(x: TensorOrScalar, y: TensorType) -> TensorType:
    ...


def logical_and(x: TensorOrScalar, y: TensorOrScalar) -> Tensor:
    """Dispatch ``logical_and`` to whichever argument is a ``Tensor``.

    When ``x`` is a ``Tensor`` this returns ``x.logical_and(y)``. Otherwise
    ``y`` is treated as the tensor and ``y.logical_and(x)`` is returned; the
    ``cast`` only informs the type checker and performs no runtime check.
    """
    if isinstance(x, Tensor):
        return x.logical_and(y)
    # x is a scalar here, so y is the operand that provides the method.
    other = cast(Tensor, y)
    return other.logical_and(x)


@overload
def logical_or(x: TensorType, y: TensorOrScalar) -> TensorType:
    ...


@overload
def logical_or(x: TensorOrScalar, y: TensorType) -> TensorType:
    ...


def logical_or(x: TensorOrScalar, y: TensorOrScalar) -> Tensor:
    """Dispatch ``logical_or`` to whichever argument is a ``Tensor``.

    When ``x`` is a ``Tensor`` this returns ``x.logical_or(y)``. Otherwise
    ``y`` is treated as the tensor and ``y.logical_or(x)`` is returned; the
    ``cast`` only informs the type checker and performs no runtime check.
    """
    if isinstance(x, Tensor):
        return x.logical_or(y)
    # x is a scalar here, so y is the operand that provides the method.
    other = cast(Tensor, y)
    return other.logical_or(x)


def logical_not(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.logical_not()``."""
    negated = t.logical_not()
    return negated


def exp(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.exp()``."""
    result = t.exp()
    return result


def log(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.log()``."""
    result = t.log()
    return result


def log2(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.log2()``."""
    result = t.log2()
    return result


def log10(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.log10()``."""
    result = t.log10()
    return result


def log1p(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.log1p()``."""
    result = t.log1p()
    return result


def where(condition: TensorType, x: TensorOrScalar, y: TensorOrScalar) -> TensorType:
    """Function-style spelling of ``condition.where(x, y)``.

    ``x`` and ``y`` are forwarded positionally, in that order.
    """
    selected = condition.where(x, y)
    return selected


def tile(t: TensorType, multiples: Axes) -> TensorType:
    """Function-style spelling of ``t.tile(multiples)``."""
    tiled = t.tile(multiples)
    return tiled


def matmul(x: TensorType, y: TensorType) -> TensorType:
    """Function-style spelling of ``x.matmul(y)``."""
    product = x.matmul(y)
    return product


def softmax(t: TensorType, axis: int = -1) -> TensorType:
    """Delegate to ``t.softmax``, forwarding ``axis`` (default ``-1``) by keyword."""
    probabilities = t.softmax(axis=axis)
    return probabilities


def log_softmax(t: TensorType, axis: int = -1) -> TensorType:
    """Delegate to ``t.log_softmax``, forwarding ``axis`` (default ``-1``) by keyword."""
    log_probabilities = t.log_softmax(axis=axis)
    return log_probabilities


def stack(tensors: Sequence[TensorType], axis: int = 0) -> TensorType:
    """Stack ``tensors`` via the first element's ``_stack`` method.

    The whole sequence (including the first element) is forwarded
    positionally and ``axis`` by keyword. Indexing ``tensors[0]`` means an
    empty sequence raises whatever error that indexing produces.
    """
    return tensors[0]._stack(tensors, axis=axis)


def squeeze(t: TensorType, axis: Optional[AxisAxes] = None) -> TensorType:
    """Delegate to ``t.squeeze``, forwarding ``axis`` by keyword."""
    squeezed = t.squeeze(axis=axis)
    return squeezed


def expand_dims(t: TensorType, axis: int) -> TensorType:
    """Delegate to ``t.expand_dims``, forwarding ``axis`` by keyword."""
    expanded = t.expand_dims(axis=axis)
    return expanded


def full(t: TensorType, shape: ShapeOrScalar, value: float) -> TensorType:
    """Function-style spelling of ``t.full(shape, value)``.

    Both arguments are forwarded positionally.
    """
    created = t.full(shape, value)
    return created


def index_update(t: TensorType, indices: Any, values: TensorOrScalar) -> TensorType:
    """Function-style spelling of ``t.index_update(indices, values)``.

    Both arguments are forwarded positionally.
    """
    updated = t.index_update(indices, values)
    return updated


def arange(
    t: TensorType,
    start: int,
    stop: Optional[int] = None,
    step: Optional[int] = None,
) -> TensorType:
    """Function-style spelling of ``t.arange(start, stop, step)``.

    All three values are forwarded positionally, including ``None`` defaults.
    """
    sequence = t.arange(start, stop, step)
    return sequence


def cumsum(t: TensorType, axis: Optional[int] = None) -> TensorType:
    """Delegate to ``t.cumsum``, forwarding ``axis`` by keyword."""
    running = t.cumsum(axis=axis)
    return running


def flip(t: TensorType, axis: Optional[AxisAxes] = None) -> TensorType:
    """Delegate to ``t.flip``, forwarding ``axis`` by keyword."""
    flipped = t.flip(axis=axis)
    return flipped


def meshgrid(
    t: TensorType,
    *tensors: TensorType,
    indexing: str = "xy",
) -> Tuple[TensorType, ...]:
    """Delegate to ``t.meshgrid``.

    The extra ``tensors`` are unpacked as positional arguments and
    ``indexing`` (default ``"xy"``) is forwarded by keyword.
    """
    grids = t.meshgrid(*tensors, indexing=indexing)
    return grids


def pad(
    t: TensorType,
    paddings: Tuple[Tuple[int, int], ...],
    mode: str = "constant",
    value: float = 0,
) -> TensorType:
    """Delegate to ``t.pad``.

    ``paddings`` is forwarded positionally; ``mode`` (default ``"constant"``)
    and ``value`` (default ``0``) are forwarded by keyword.
    """
    padded = t.pad(paddings, mode=mode, value=value)
    return padded


def isnan(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.isnan()``."""
    mask = t.isnan()
    return mask


def isinf(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.isinf()``."""
    mask = t.isinf()
    return mask


def all(
    t: TensorType,
    axis: Optional[AxisAxes] = None,
    keepdims: bool = False,
) -> TensorType:
    """Delegate to the ``all`` method of ``t``.

    ``axis`` and ``keepdims`` are forwarded as keyword arguments.
    """
    reduced = t.all(axis=axis, keepdims=keepdims)
    return reduced


def any(
    t: TensorType,
    axis: Optional[AxisAxes] = None,
    keepdims: bool = False,
) -> TensorType:
    """Delegate to the ``any`` method of ``t``.

    ``axis`` and ``keepdims`` are forwarded as keyword arguments.
    """
    reduced = t.any(axis=axis, keepdims=keepdims)
    return reduced


def crossentropy(logits: TensorType, labels: TensorType) -> TensorType:
    """Function-style spelling of ``logits.crossentropy(labels)``."""
    loss = logits.crossentropy(labels)
    return loss


def slogdet(matrix: TensorType) -> Tuple[TensorType, TensorType]:
    """Function-style spelling of ``matrix.slogdet()``.

    The pair returned by the method is passed through unchanged.
    """
    pair = matrix.slogdet()
    return pair


@overload
def value_and_grad_fn(
    t: TensorType,
    f: Callable[..., TensorType],
) -> Callable[..., Tuple[TensorType, TensorType]]:
    ...


@overload
def value_and_grad_fn(
    t: TensorType,
    f: Callable[..., TensorType],
    has_aux: Literal[False],
) -> Callable[..., Tuple[TensorType, TensorType]]:
    ...


@overload
def value_and_grad_fn(
    t: TensorType,
    f: Callable[..., Tuple[TensorType, Any]],
    has_aux: Literal[True],
) -> Callable[..., Tuple[TensorType, Any, TensorType]]:
    ...


def value_and_grad_fn(t: Any, f: Any, has_aux: bool = False) -> Any:
    """Delegate to ``t._value_and_grad_fn``.

    ``f`` is forwarded positionally and ``has_aux`` by keyword. Per the
    overload signatures, the returned callable yields a two-tuple of tensors
    when ``has_aux`` is false (or omitted) and a three-tuple
    ``(tensor, Any, tensor)`` when ``has_aux`` is true.
    """
    wrapped = t._value_and_grad_fn(f, has_aux=has_aux)
    return wrapped


def value_and_grad(
    f: Callable[..., TensorType],
    t: TensorType,
    *args: Any,
    **kwargs: Any,
) -> Tuple[TensorType, TensorType]:
    """Delegate to ``t.value_and_grad``.

    ``f`` is forwarded as the first positional argument, followed by any
    extra positional and keyword arguments unchanged.
    """
    outcome = t.value_and_grad(f, *args, **kwargs)
    return outcome


def value_aux_and_grad(
    f: Callable[..., Tuple[TensorType, Any]],
    t: TensorType,
    *args: Any,
    **kwargs: Any,
) -> Tuple[TensorType, Any, TensorType]:
    """Delegate to ``t.value_aux_and_grad``.

    ``f`` is forwarded as the first positional argument, followed by any
    extra positional and keyword arguments unchanged.
    """
    outcome = t.value_aux_and_grad(f, *args, **kwargs)
    return outcome


def reshape(t: TensorType, shape: Union[Shape, int]) -> TensorType:
    """Function-style spelling of ``t.reshape(shape)``."""
    reshaped = t.reshape(shape)
    return reshaped


def take_along_axis(t: TensorType, indices: TensorType, axis: int) -> TensorType:
    """Function-style spelling of ``t.take_along_axis(indices, axis)``.

    Both arguments are forwarded positionally.
    """
    gathered = t.take_along_axis(indices, axis)
    return gathered


def flatten(t: TensorType, start: int = 0, end: int = -1) -> TensorType:
    """Delegate to ``t.flatten``.

    ``start`` (default ``0``) and ``end`` (default ``-1``) are forwarded by
    keyword.
    """
    flat = t.flatten(start=start, end=end)
    return flat


def inv(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.inv()``."""
    inverse = t.inv()
    return inverse


def round(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.round()``."""
    result = t.round()
    return result


def ceil(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.ceil()``."""
    result = t.ceil()
    return result


def floor(t: TensorType) -> TensorType:
    """Function-style spelling of ``t.floor()``."""
    result = t.floor()
    return result
from .tensor import TensorType


def kl_div_with_logits(
    logits_p: TensorType, logits_q: TensorType, axis: int = -1, keepdims: bool = False
) -> TensorType:
    """Compute ``sum(p * (log_p - log_q))`` along ``axis`` from two sets of logits.

    ``log_p`` and ``log_q`` are ``log_softmax`` of ``logits_p`` and
    ``logits_q``, and ``p`` is ``softmax`` of ``logits_p``; all three are
    taken along the same ``axis``, which is also the axis that is summed.

    Args:
        logits_p: Logits of the first distribution.
        logits_q: Logits of the second distribution.
        axis: Axis along which the distributions are normalized and summed.
        keepdims: Forwarded to the final ``sum``.

    Returns:
        The result of ``(p * (log_p - log_q)).sum(axis=axis, keepdims=keepdims)``.
    """
    log_p = logits_p.log_softmax(axis=axis)
    log_q = logits_q.log_softmax(axis=axis)
    # p must be normalized along the same axis as log_p/log_q; a hard-coded
    # axis=-1 here produced inconsistent results whenever ``axis`` referred
    # to a different dimension.
    p = logits_p.softmax(axis=axis)
    return (p * (log_p - log_q)).sum(axis=axis, keepdims=keepdims)