# We attempt to use operators defined for '.matrix' in the 'ScaledMatrixSeed'.
# This avoids expensive modifications such as loss of sparsity.
# Centering and scaling are factored out into separate operations.
#
# We assume that the non-'ScaledMatrix' argument is small and can be modified cheaply.
# We also assume that the matrix product is small and can be modified cheaply.
# This allows centering and scaling to be applied *after* multiplication.
#
# Here are some ground rules for how these functions must work:
#
#  - NO arithmetic operations shall be applied to a ScaledMatrix.
#    This includes nested ScaledMatrices that are present in '.matrix'.
#    Such operations collapse the ScaledMatrix to a DelayedMatrix, 
#    resulting in slow block processing during multiplication.
# 
#  - NO addition/subtraction operations shall be applied to '.matrix'.
#    This is necessary to avoid loss of sparsity for sparse '.matrix',
#    as well as to avoid block processing for ScaledMatrix '.matrix'.
# 
#  - NO division/multiplication operations should be applied to '.matrix'.
#    This is largely a consequence of the first point above.
#    Exceptions are only allowed when this is unavoidable, e.g., in '.internal_tcrossprod'.
#
#  - NO calling of %*% or (t)crossprod on a ScaledMatrix of the same nesting depth as an input ScaledMatrix.
#    Internal multiplication should always be applied to '.matrix', to avoid infinite S4 recursion.
#    Each method call should strip away one nesting level, i.e., operate on the seed.
#    Exceptions are allowed for dual ScaledMatrix multiplication,
#    where one argument is allowed to be of the same depth.

#' @export
#' @importFrom Matrix t
#' @importFrom DelayedArray seed DelayedArray
setMethod("%*%", c("ScaledMatrix", "ANY"), function(x, y) {
    # Multiplies a ScaledMatrix 'x' by an arbitrary right-hand operand 'y'.
    # All work is delegated to helpers that operate on the seed of 'x',
    # and the result is wrapped in a DelayedArray.
    sm_seed <- seed(x)

    product <- if (is_transposed(sm_seed)) {
        # The seed is flagged as transposed, so we left-multiply by t(y)
        # and transpose the result back.
        t(.leftmult_ScaledMatrix(t(y), sm_seed))
    } else {
        .rightmult_ScaledMatrix(sm_seed, y)
    }

    DelayedArray(product)
})

#' @importFrom DelayedArray sweep
.rightmult_ScaledMatrix <- function(x_seed, y) {
    # Right-multiplies the matrix described by 'x_seed' with 'y'.
    # Scaling is pushed onto 'y' before the product, and centering is
    # subtracted from the product afterwards, so that the underlying
    # matrix in the seed is only ever used in a plain '%*%'.
    rhs <- y
    if (use_scale(x_seed)) {
        rhs <- rhs / get_scale(x_seed)
    }

    product <- as.matrix(get_matrix2(x_seed) %*% rhs)
    if (!use_center(x_seed)) {
        return(product)
    }

    # One centering value per column of the product.
    shift <- as.numeric(get_center(x_seed) %*% rhs)
    sweep(product, 2, shift, "-", check.margin=FALSE)
}

#' @export
#' @importFrom Matrix t 
#' @importFrom DelayedArray seed DelayedArray
setMethod("%*%", c("ANY", "ScaledMatrix"), function(x, y) {
    # Multiplies an arbitrary left-hand operand 'x' by a ScaledMatrix 'y'.
    # The result is wrapped in a DelayedArray.
    sm_seed <- seed(y)

    if (!is_transposed(sm_seed)) {
        return(DelayedArray(.leftmult_ScaledMatrix(x, sm_seed)))
    }

    # Transposed seed: right-multiply by t(x) and transpose the result back.
    # Vectors don't quite behave as 1-column matrices here, so objects
    # without dimensions are passed through untouched rather than transposed.
    lhs <- if (is.null(dim(x))) x else t(x)
    DelayedArray(t(.rightmult_ScaledMatrix(sm_seed, lhs)))
})

#' @importFrom Matrix rowSums
#' @importFrom DelayedArray sweep
.leftmult_ScaledMatrix <- function(x, y_seed) { 
    # Left-multiplies the matrix described by 'y_seed' with 'x'.
    # The underlying matrix is only used in a plain '%*%'; centering and
    # scaling are then applied to the (assumed small) product.
    #
    # 'x' may be a matrix-like object or a bare vector. Returns an ordinary
    # matrix.
    out <- as.matrix(x %*% get_matrix2(y_seed))

    if (use_center(y_seed)) {
        if (!is.null(dim(x))) {
            out <- out - outer(rowSums(x), get_center(y_seed), "*")
        } else if (nrow(out) == 1L) {
            # '%*%' treated the vector as a single row, so the row sum
            # is simply the sum of all its elements.
            out <- out - get_center(y_seed) * sum(x)
        } else {
            # '%*%' promoted the vector to a single column (only possible
            # when the underlying matrix has one row), so every element of
            # 'x' is its own row sum.
            out <- out - outer(x, get_center(y_seed), "*")
        }
    }

    if (use_scale(y_seed)) {
        out <- sweep(out, 2, get_scale(y_seed), "/", check.margin=FALSE)
    }

    out
}

#' @export
#' @importFrom DelayedArray seed DelayedArray
setMethod("%*%", c("ScaledMatrix", "ScaledMatrix"), function(x, y) {
    # Product of two ScaledMatrix objects. Each seed's own transposition
    # flag is forwarded unchanged to the dual-multiplication dispatcher.
    left_seed <- seed(x)
    right_seed <- seed(y)
    DelayedArray(.dual_mult_dispatcher(
        left_seed, 
        right_seed, 
        is_transposed(left_seed), 
        is_transposed(right_seed)
    ))
})

#' @importFrom Matrix t
.dual_mult_dispatcher <- function(x_seed, y_seed, x_trans, y_trans) {
    # Chooses the ScMat %*% ScMat routine matching the requested
    # transposition state of each operand.
    # 'x_trans' and 'y_trans' are logical scalars supplied by the caller.
    if (x_trans && y_trans) {
        # Both transposed: swap the operands, multiply untransposed,
        # and transpose the product.
        return(t(.multiply_u2u(y_seed, x_seed)))
    }
    if (x_trans) {
        return(.multiply_t2u(x_seed, y_seed))
    }
    if (y_trans) {
        return(.multiply_u2t(x_seed, y_seed))
    }
    .multiply_u2u(x_seed, y_seed)
}

###################################
# ScMat %*% ScMat utilities.

# We do not implement ScMat %*% ScMat in terms of left/right %*%.
# This would cause scaling to be applied on one of the ScMats,
# collapsing it into a DelayedMatrix. Subsequent multiplication 
# would use block processing, which would be too slow.

#' @importFrom Matrix drop rowSums
#' @importFrom DelayedArray sweep
.multiply_u2u <- function(x_seed, y_seed) 
# Evaluates (X - C_x)S_x (Y - C_y)S_y by expanding the product into
# four terms, none of which requires centering either underlying matrix.
{
    has_x_scale <- use_scale(x_seed)
    has_y_scale <- use_scale(y_seed)
    has_x_center <- use_center(x_seed)
    has_y_center <- use_center(y_seed)
    y_mat <- get_matrix2(y_seed)

    # Term 1: X S_x Y S_y. The scaling of 'x' is deferred by wrapping the
    # underlying matrix in a scale-only ScaledMatrix.
    lhs <- get_matrix2(x_seed)
    if (has_x_scale) {
        lhs <- ScaledMatrix(lhs, scale=get_scale(x_seed))
    }

    total <- as.matrix(lhs %*% y_mat)
    if (has_y_scale) {
        total <- sweep(total, 2, get_scale(y_seed), "/", check.margin=FALSE)
    }

    # Centers divided by their own scaling, as used by the remaining terms.
    if (has_x_center) {
        cx <- get_center(x_seed)
        if (has_x_scale) {
            cx <- cx / get_scale(x_seed)
        }
    }
    if (has_y_center) {
        cy <- get_center(y_seed)
        if (has_y_scale) {
            cy <- cy / get_scale(y_seed)
        }
    }

    # Term 2: C_x S_x Y S_y, subtracted from every row.
    if (has_x_center) {
        term2 <- drop(cx %*% y_mat)
        if (has_y_scale) {
            term2 <- term2 / get_scale(y_seed)
        }
        total <- sweep(total, 2, term2, "-", check.margin=FALSE)
    }

    # Term 4: C_x S_x C_y S_y, added to every row.
    if (has_x_center && has_y_center) {
        term4 <- sum(cx) * cy
        total <- sweep(total, 2, term4, "+", check.margin=FALSE)
    }

    # Term 3: X S_x C_y S_y, subtracted last to avoid subtracting
    # large values.
    if (has_y_center) {
        term3 <- outer(rowSums(lhs), cy)
        total <- total - term3
    }

    total
}

#' @importFrom Matrix tcrossprod drop
#' @importFrom DelayedArray sweep
.multiply_u2t <- function(x_seed, y_seed) 
# Evaluates (X - C_x)S_x S_y(Y' - C_y') by expanding the product into
# four terms, none of which requires centering either underlying matrix.
{
    has_x_scale <- use_scale(x_seed)
    has_y_scale <- use_scale(y_seed)
    has_x_center <- use_center(x_seed)
    has_y_center <- use_center(y_seed)
    y_mat <- get_matrix2(y_seed)

    # Term 1: X S_x S_y Y'. Both scalings act on the shared dimension, so
    # they are combined and deferred inside a scale-only ScaledMatrix.
    lhs <- get_matrix2(x_seed)
    if (has_x_scale || has_y_scale) {
        combined <- 1
        if (has_x_scale) {
            combined <- combined * get_scale(x_seed)
        }
        if (has_y_scale) {
            combined <- combined * get_scale(y_seed)
        }
        lhs <- ScaledMatrix(lhs, scale=combined)
    }
    total <- as.matrix(tcrossprod(lhs, y_mat))

    # Center of 'x' divided by the scaling of 'x' (if any).
    if (has_x_center) {
        cx <- get_center(x_seed)
        if (has_x_scale) {
            cx <- cx / get_scale(x_seed)
        }
    }

    # Term 2: C_x S_x S_y Y', subtracted from every row.
    if (has_x_center) {
        cx_both <- cx
        if (has_y_scale) {
            cx_both <- cx_both / get_scale(y_seed)
        }
        term2 <- drop(tcrossprod(cx_both, y_mat))
        total <- sweep(total, 2, term2, "-", check.margin=FALSE)
    }

    # Term 4: C_x S_x S_y C_y', a single value added to every entry.
    if (has_x_center && has_y_center) {
        cy <- get_center(y_seed)
        if (has_y_scale) {
            cy <- cy / get_scale(y_seed)
        }
        total <- total + sum(cx*cy)
    }

    # Term 3: X S_x S_y C_y', subtracted last to avoid subtracting
    # large values. 'lhs' already carries the combined scaling.
    if (has_y_center) {
        term3 <- drop(lhs %*% get_center(y_seed))
        total <- total - term3
    }

    total
}

#' @importFrom Matrix crossprod colSums
#' @importFrom DelayedArray sweep
.multiply_t2u <- function(x_seed, y_seed) 
# Evaluates S_x(X' - C_x') (Y - C_y)S_y by expanding the inner product
# into four terms and applying both scalings to the final result.
{
    has_x_center <- use_center(x_seed)
    has_y_center <- use_center(y_seed)

    # Term 1: X' Y.
    x_mat <- get_matrix2(x_seed)
    y_mat <- get_matrix2(y_seed)
    total <- as.matrix(crossprod(x_mat, y_mat))

    # Term 2: C_x' Y, subtracted.
    if (has_x_center) {
        total <- total - outer(get_center(x_seed), colSums(y_mat))
    }

    # Term 4: C_x' C_y, added.
    if (has_x_center && has_y_center) {
        both <- outer(get_center(x_seed), get_center(y_seed)) * nrow(y_mat)
        total <- total + both
    }

    # Term 3: X' C_y, subtracted last to avoid subtracting large values.
    if (has_y_center) {
        total <- total - outer(colSums(x_mat), get_center(y_seed))
    }

    # The scaling of 'x' acts on the rows of the product;
    # the scaling of 'y' acts on its columns.
    if (use_scale(x_seed)) {
        total <- total / get_scale(x_seed)
    } 
    if (use_scale(y_seed)) {
        total <- sweep(total, 2, get_scale(y_seed), "/", check.margin=FALSE)
    }

    total
}

###################################
# Cross-product. 

# Technically, we could implement this in terms of '%*%',
# but we use specializations to exploit native crossprod() for '.matrix',
# which is probably more efficient.

#' @export
#' @importFrom Matrix crossprod
#' @importFrom DelayedArray seed DelayedArray
setMethod("crossprod", c("ScaledMatrix", "missing"), function(x, y) {
    # Self cross-product of a ScaledMatrix, returned as a DelayedArray.
    sm_seed <- seed(x)

    gram <- if (is_transposed(sm_seed)) {
        # A transposed seed is handled by the tcrossprod helper.
        # No need to t(), the output is symmetric anyway.
        .tcp_ScaledMatrix(sm_seed)
    } else {
        .cross_ScaledMatrix(sm_seed)
    }

    DelayedArray(gram)
})

#' @importFrom Matrix crossprod colSums
#' @importFrom DelayedArray sweep
.cross_ScaledMatrix <- function(x_seed) {
    # Cross-product of the matrix described by 'x_seed' with itself.
    # The native crossprod() is applied to the underlying matrix, and the
    # centering and scaling corrections are applied to the result.
    base_mat <- get_matrix2(x_seed)
    gram <- as.matrix(crossprod(base_mat))

    if (use_center(x_seed)) {
        ctr <- get_center(x_seed)
        totals <- colSums(base_mat)

        # Minus, then add, then minus, to mitigate cancellation.
        gram <- gram - outer(ctr, totals)
        gram <- gram + outer(ctr, ctr) * nrow(base_mat)
        gram <- gram - outer(totals, ctr)
    }

    if (!use_scale(x_seed)) {
        return(gram)
    }

    # Divide rows (via recycling) and then columns (via sweep) by the scaling.
    s <- get_scale(x_seed)
    sweep(gram / s, 2, s, "/", check.margin=FALSE)
}

#' @export
#' @importFrom Matrix crossprod
#' @importFrom DelayedArray seed DelayedArray
setMethod("crossprod", c("ScaledMatrix", "ANY"), function(x, y) {
    # Cross-product of a ScaledMatrix 'x' with an arbitrary 'y',
    # returned as a DelayedArray.
    sm_seed <- seed(x)

    product <- if (is_transposed(sm_seed)) {
        # Transposed seed: handled as an ordinary right-multiplication.
        .rightmult_ScaledMatrix(sm_seed, y)
    } else {
        .rightcross_ScaledMatrix(sm_seed, y)
    }

    DelayedArray(product)
})

#' @importFrom Matrix crossprod colSums
.rightcross_ScaledMatrix <- function(x_seed, y) {
    # Cross-product of the matrix described by 'x_seed' with 'y'.
    # The native crossprod() is applied to the underlying matrix, and
    # centering and scaling are applied to the product afterwards.
    product <- as.matrix(crossprod(get_matrix2(x_seed), y))

    if (use_center(x_seed)) {
        ctr <- get_center(x_seed)
        # A dimensionless 'y' contributes a single total;
        # otherwise there is one total per column of 'y'.
        shift <- if (is.null(dim(y))) ctr * sum(y) else outer(ctr, colSums(y))
        product <- product - shift
    }
    
    if (!use_scale(x_seed)) {
        return(product)
    }

    # The scaling acts on the rows of the product.
    product / get_scale(x_seed)
}

#' @export
#' @importFrom Matrix crossprod
#' @importFrom DelayedArray seed DelayedArray
setMethod("crossprod", c("ANY", "ScaledMatrix"), function(x, y) {
    # Cross-product of an arbitrary 'x' with a ScaledMatrix 'y',
    # returned as a DelayedArray.
    sm_seed <- seed(y)

    product <- if (is_transposed(sm_seed)) {
        # Transposed seed: right-multiply by 'x' and transpose the result.
        t(.rightmult_ScaledMatrix(sm_seed, x))
    } else {
        .leftcross_ScaledMatrix(x, sm_seed)
    }

    DelayedArray(product)
})

#' @importFrom Matrix crossprod colSums
#' @importFrom DelayedArray sweep
.leftcross_ScaledMatrix <- function(x, y_seed) {
    # Cross-product of 'x' with the matrix described by 'y_seed'.
    # The native crossprod() is applied to the underlying matrix, and
    # centering and scaling are applied to the product afterwards.
    product <- as.matrix(crossprod(x, get_matrix2(y_seed)))

    if (use_center(y_seed)) {
        ctr <- get_center(y_seed)
        if (is.null(dim(x))) {
            # Dimensionless 'x': a single total, swept across the columns.
            product <- sweep(product, 2, sum(x) * ctr, "-", check.margin=FALSE)
        } else {
            product <- product - outer(colSums(x), ctr)
        }
    }

    if (!use_scale(y_seed)) {
        return(product)
    }

    # The scaling acts on the columns of the product.
    sweep(product, 2, get_scale(y_seed), "/", check.margin=FALSE)
}

#' @export
#' @importFrom Matrix crossprod
#' @importFrom DelayedArray DelayedArray seed
setMethod("crossprod", c("ScaledMatrix", "ScaledMatrix"), function(x, y) {
    # Cross-product of two ScaledMatrix objects. The transposition flag of
    # 'x' is inverted before dispatch, as crossprod() transposes its first
    # argument.
    left_seed <- seed(x)
    right_seed <- seed(y)
    DelayedArray(.dual_mult_dispatcher(
        left_seed, 
        right_seed, 
        !is_transposed(left_seed), 
        is_transposed(right_seed)
    ))
})

###################################
# Transposed cross-product. 

# Technically, we could implement this in terms of '%*%',
# but we use specializations to exploit native tcrossprod() for '.matrix',
# which is probably more efficient.

#' @export
#' @importFrom Matrix tcrossprod
#' @importFrom DelayedArray seed DelayedArray sweep
setMethod("tcrossprod", c("ScaledMatrix", "missing"), function(x, y) {
    # Self transposed cross-product of a ScaledMatrix,
    # returned as a DelayedArray.
    sm_seed <- seed(x)

    gram <- if (is_transposed(sm_seed)) {
        # A transposed seed is handled by the crossprod helper.
        .cross_ScaledMatrix(sm_seed)
    } else {
        .tcp_ScaledMatrix(sm_seed)
    }

    DelayedArray(gram)
})

#' @importFrom Matrix tcrossprod
.tcp_ScaledMatrix <- function(x_seed) {
    # Transposed cross-product of the matrix described by 'x_seed' with
    # itself. Returns an ordinary matrix.
    base_mat <- get_matrix2(x_seed)
    scaled <- use_scale(x_seed)

    # Uncentered part; with scaling, this goes through the dedicated helper.
    gram <- if (scaled) {
        as.matrix(.internal_tcrossprod(base_mat, get_scale(x_seed)))
    } else {
        as.matrix(tcrossprod(base_mat))
    }
    
    if (!use_center(x_seed)) {
        return(gram)
    }

    # 'once' is the center divided by the scaling one time, 'twice' is the
    # center divided by the scaling two times. Using 'twice' mimics
    # sweep(base_mat, 2, get_scale(x), "/"), except that the scaling is
    # applied to the center rather than directly to 'base_mat'.
    # Without scaling, both are simply the center.
    once <- get_center(x_seed)
    twice <- once
    if (scaled) {
        once <- once / get_scale(x_seed)
        twice <- once / get_scale(x_seed)
    }

    row_term <- tcrossprod(twice, base_mat)

    # Minus, then add, then minus, to mitigate cancellation.
    gram <- sweep(gram, 2, as.numeric(row_term), "-", check.margin=FALSE)
    gram <- gram + sum(once^2)
    gram - as.numeric(base_mat %*% twice)
}

#' @export
#' @importFrom Matrix tcrossprod t
#' @importFrom DelayedArray seed DelayedArray sweep
setMethod("tcrossprod", c("ScaledMatrix", "ANY"), function(x, y) {
    # Transposed cross-product of a ScaledMatrix 'x' with an arbitrary 'y',
    # returned as a DelayedArray.

    # Dimensionless 'y' is rejected, for consistency with base::tcrossprod.
    if (is.null(dim(y))) {
        stop("non-conformable arguments")
    }

    sm_seed <- seed(x)

    product <- if (is_transposed(sm_seed)) {
        # Transposed seed: left-multiply by 'y' and transpose the result.
        t(.leftmult_ScaledMatrix(y, sm_seed))
    } else {
        .righttcp_ScaledMatrix(sm_seed, y)
    }

    DelayedArray(product)
})

#' @importFrom Matrix tcrossprod
.righttcp_ScaledMatrix <- function(x_seed, y) {
    # Transposed cross-product of the matrix described by 'x_seed' with 'y'.
    # Scaling is pushed onto the columns of 'y' before the product, and
    # centering is subtracted from the product afterwards.
    rhs <- y
    if (use_scale(x_seed)) {
        # 'y' is expected to have dimensions here; the calling method
        # rejects dimensionless input.
        rhs <- sweep(rhs, 2, get_scale(x_seed), "/", check.margin=FALSE)
    }

    product <- as.matrix(tcrossprod(get_matrix2(x_seed), rhs))
    if (!use_center(x_seed)) {
        return(product)
    }

    # One centering value per column of the product.
    shift <- as.numeric(tcrossprod(get_center(x_seed), rhs))
    sweep(product, 2, shift, "-", check.margin=FALSE)
}

#' @export
#' @importFrom Matrix tcrossprod t
#' @importFrom DelayedArray seed DelayedArray
setMethod("tcrossprod", c("ANY", "ScaledMatrix"), function(x, y) {
    # Transposed cross-product of an arbitrary 'x' with a ScaledMatrix 'y',
    # returned as a DelayedArray.
    sm_seed <- seed(y) 

    product <- if (is_transposed(sm_seed)) {
        # Transposed seed: handled as an ordinary left-multiplication.
        .leftmult_ScaledMatrix(x, sm_seed)
    } else {
        .lefttcp_ScaledMatrix(x, sm_seed)
    }

    DelayedArray(product)
})

#' @importFrom Matrix tcrossprod 
.lefttcp_ScaledMatrix <- function(x, y_seed) {
    # Transposed cross-product of 'x' with the matrix described by 'y_seed'.
    # Scaling is pushed onto 'x' before the product, and centering is
    # subtracted from the product afterwards.
    lhs <- x
    if (use_scale(y_seed)) {
        s <- get_scale(y_seed)
        # Dimensionless input is divided elementwise;
        # otherwise the scaling is swept across the columns.
        lhs <- if (is.null(dim(x))) {
            x / s
        } else { 
            sweep(x, 2, s, "/", check.margin=FALSE)
        }
    }

    product <- as.matrix(tcrossprod(lhs, get_matrix2(y_seed)))
    if (!use_center(y_seed)) {
        return(product)
    }

    product - as.numeric(lhs %*% get_center(y_seed))
}

#' @export
#' @importFrom Matrix tcrossprod
#' @importFrom DelayedArray DelayedArray seed
setMethod("tcrossprod", c("ScaledMatrix", "ScaledMatrix"), function(x, y) {
    # Transposed cross-product of two ScaledMatrix objects. The
    # transposition flag of 'y' is inverted before dispatch, as tcrossprod()
    # transposes its second argument.
    left_seed <- seed(x)
    right_seed <- seed(y)
    DelayedArray(.dual_mult_dispatcher(
        left_seed, 
        right_seed, 
        is_transposed(left_seed), 
        !is_transposed(right_seed)
    ))
})

###################################
# Extra code for corner-case calculations of the transposed cross-product.

#' @importFrom DelayedArray seed DelayedArray
.update_scale <- function(x, s) {
    # Returns a copy of 'x' whose seed carries an extra scaling 's'.
    # If the seed already uses scaling, 's' is multiplied by the existing
    # scale; otherwise 's' becomes the scale. Scaling is then switched on.
    updated <- seed(x)
    combined <- if (use_scale(updated)) s * get_scale(updated) else s
    updated@scale <- combined
    updated@use_scale <- TRUE
    DelayedArray(updated)
}

#' @importFrom Matrix tcrossprod 
#' @importFrom methods is
#' @importFrom DelayedArray seed
.internal_tcrossprod <- function(x, scale.) 
# Computes tcrossprod(sweep(x, 2, scale., "/")) when 'x' is a matrix-like object.
# 'scale.' can be assumed to be non-NULL here.
# This will always return a dense ordinary matrix.
{
    # Anything that is not a ScaledMatrix is scaled directly.
    if (!is(x, "ScaledMatrix")) {
        scaled_x <- sweep(x, 2, scale., "/", check.margin=FALSE) 
        return(as.matrix(tcrossprod(scaled_x)))
    }

    # Untransposed ScaledMatrix: fold 'scale.' into the seed's own scaling.
    outer_seed <- seed(x)
    if (!is_transposed(outer_seed)) {
        return(as.matrix(tcrossprod(.update_scale(x, scale.))))
    }

    # Transposed ScaledMatrix: 'scale.' now acts on the rows of the
    # underlying matrix, so the product is expanded term by term.
    core <- get_matrix2(outer_seed)
    if (!is(core, "ScaledMatrix")) {
        gram <- as.matrix(crossprod(core/scale.))
    } else if (is_transposed(seed(core))) {
        gram <- as.matrix(crossprod(.update_scale(core, scale.)))
    } else {
        gram <- .internal_tcrossprod(t(core), scale.) # recurses. 
    }
           
    if (use_center(outer_seed)) {
        ctr <- get_center(outer_seed)
        cross_term <- .internal_mult_special(ctr, scale., core)
        center_term <- outer(ctr, ctr) * sum(1/scale.^2)
        # Grouped in pairs, each a difference of terms.
        gram <- (gram - cross_term) + (center_term - t(cross_term))
    }

    if (use_scale(outer_seed)) {
        own_scale <- get_scale(outer_seed)
        # Divide rows (via recycling) and then columns (via sweep).
        gram <- gram / own_scale
        gram <- sweep(gram, 2, own_scale, "/", check.margin=FALSE) 
    }

    gram 
}

#' @importFrom methods is
#' @importFrom Matrix colSums
#' @importFrom DelayedArray seed sweep
.internal_mult_special <- function(center, scale., Z)
# Computes C^T * S^2 * Z where C is a matrix of 'center' copied byrow=TRUE;
# S is a diagonal matrix filled with '1/scale.'; and 'Z' is a matrix-like
# object, possibly a (nested) ScaledMatrix.
# This will always return a dense ordinary matrix.
{
    # Ordinary matrices are scaled directly.
    if (!is(Z, "ScaledMatrix")) {
        return(outer(center, colSums(Z / scale.^2)))
    }
    
    # Transposed ScaledMatrix: fold the squared scaling into the seed.
    Z_seed <- seed(Z)
    if (is_transposed(Z_seed)) {
        Z <- .update_scale(Z, scale.^2)
        return(outer(center, colSums(Z)))
    }

    # Untransposed ScaledMatrix: strip one nesting level, then apply this
    # level's centering and scaling to the small result.
    output <- .internal_mult_special(center, scale., get_matrix2(Z_seed)) # recurses.

    if (use_center(Z_seed)) {
        output <- output - outer(center, get_center(Z_seed)) * sum(1/scale.^2)
    }

    if (use_scale(Z_seed)) {
        # check.margin=FALSE for consistency with every other sweep() call
        # in this file.
        output <- sweep(output, 2, get_scale(Z_seed), "/", check.margin=FALSE)
    }

    output
}