R/evalHeatmap.R
1e2ccd61
 #' evalHeatmap
 #' 
4ccba5fc
 #' General heatmap representation of aggregated evaluation results. By default,
 #' the actual metric values are printed in the cells, and while the coloring is
 #' determined by \code{\link{colCenterScale}} (number of matrix-median absolute
 #'  deviations from the column means). Unless the total number of analyses is
 #'  small, it is strongly recommended to use the `agg.by` argument to limit the
 #'  size and improve the readability of the heatmap.
1e2ccd61
 #'
 #' @param res Aggregated pipeline results (i.e. the output of `runPipeline` or
 #' `aggregateResults`)
4ccba5fc
 #' @param step Name of the step for which to plot the evaluation results. If 
 #' unspecified, will use the latest step that has evaluation results.
 #' @param what What metric to plot.
1e2ccd61
 #' @param what2 If the step has more than one benchmark data.frame, which one
4ccba5fc
 #' to use. The function will attempt to guess that automatically based on 
 #' `what`, and will notify in case of ambiguity.
1e2ccd61
 #' @param agg.by Aggregate results by these columns (default no aggregation)
 #' @param agg.fn Function for aggregation (default mean)
4ccba5fc
 #' @param filterExpr An optional filtering expression based on the columns of 
 #' the target dataframe, (e.g. `filterExpr=param1=="value1"`).
 #' @param scale Controls the scaling of the columns for the color mapping. Can 
 #' either be a logical (TRUE will use NA-safe column z-scores, FALSE will not 
 #' scale) or a function performing the scaling. The default uses the
 #'  `colCenterScale` function (per-column centering, but per-matrix variance
 #'  scaling).
1e2ccd61
 #' @param value_format Format for displaying cells' values (use 
 #' `value_format=""` to disable)
 #' @param reorder_rows Logical; whether to sort rows (default FALSE). The row 
 #' names themselves can also be passed to specify an order, or a 
 #' `ComplexHeatmap`.
c8f1895f
 #' @param row_split Optional column (included in `agg.by`) by which to split
9c85399b
 #' the rows. Alternatively, an expression using the columns (retained after
 #' aggregation) can be passed.
1e2ccd61
 #' @param show_heatmap_legend Passed to `Heatmap` (default FALSE)
 #' @param show_column_names Passed to `Heatmap` (default FALSE)
86023f6b
 #' @param col Colors for the heatmap, or a color-mapping function as produced by
 #' `colorRamp2`. If passing a vector of colors and the data is scaled, there 
 #' should be an odd number of colors. By default, will apply linear mapping (if
8e382601
 #' the data is not scaled) or signed sqrt mapping (if scaled) on the 
86023f6b
 #' `viridisLite::inferno` palette.
c8f1895f
 #' @param font_factor A scaling factor applied to fontsizes (default 1)
1e2ccd61
 #' @param value_cols A vector of length 2 indicating the colors of the values
 #' (above and below the mean), if printed
 #' @param title Plot title
c8f1895f
 #' @param shortNames Logical; whether to use short row names (with only
 #' the parameter values instead of the parameter name and value pairs), default
 #' TRUE.
066bafc5
 #' @param name Heatmap name (e.g. used for the legend)
1e2ccd61
 #' @param anno_legend Logical; whether to plot the legend for the datasets
 #' @param ... Passed to `Heatmap`
 #' 
 #' @return A Heatmap
 #' @export
066bafc5
 #' @importFrom viridisLite inferno
4ccba5fc
 #' @importFrom grid gpar
1e2ccd61
 #' @import ComplexHeatmap
 #' @examples
4ccba5fc
 #' data("exampleResults", package="pipeComp")
066bafc5
 #' evalHeatmap( exampleResults, step="clustering", what=c("ARI","MI","min_pr"), 
4ccba5fc
 #'              agg.by=c("filt", "norm"), row_split = "norm" ) +
066bafc5
 #' evalHeatmap( exampleResults, step="clustering", what="ARI", 
 #'              agg.by=c("filt", "norm"), filterExpr=n_clus==true.nbClusts, 
 #'              name="ARI at true k", title="ARI at
4ccba5fc
 #' true K" )
1e2ccd61
 evalHeatmap <- function( res, step=NULL, what, what2=NULL, agg.by=NULL, 
4ccba5fc
                          agg.fn=mean, filterExpr=NULL, scale="colCenterScale", 
                          value_format="%.2f", reorder_rows=FALSE, 
                          show_heatmap_legend=FALSE, show_column_names=FALSE, 
9c85399b
                          col=NULL, font_factor=0.9, 
4ccba5fc
                          row_split=NULL, shortNames=TRUE,
1e2ccd61
                          value_cols=c("black","white"), title=NULL, 
066bafc5
                          name=NULL, anno_legend=TRUE, ...){
1e2ccd61
   pd <- NULL
   if(is(res,"SimpleList")) pd <- metadata(res)$PipelineDefinition
b04a719a
   #if(is.null(pd)) stop("Could not find the PipelineDefinition.")
4ccba5fc
   if(is.null(step)){
     step <- rev(names(res$evaluation))[1]
     if(length(res$evaluation)>1) message("Using step '",step,"'.")
   }
1e2ccd61
   step <- match.arg(step, names(res$evaluation))
   if(length(what)>1){
4ccba5fc
     title <- lapply(seq_along(what), FUN=function(i){
       tryCatch( title[[i]], error=function(e) what[[i]] )
     })
     fcall <- match.call()
     H <- NULL
     for(i in seq_along(what)){
       fcall$what <- what[[i]]
       fcall$title <- title[[i]]
       H <- H + eval(fcall)
     }
1e2ccd61
     return(H)
   }
   if(is(res,"SimpleList") && "evaluation" %in% names(res)) res <- res$evaluation
   if(is(res,"list") && step %in% names(res)) res <- res[[step]]
   if(is(res,"list")){
     if(is.null(what2)){
       # try to guess the table
c8f1895f
       test <- which(vapply( res, FUN=function(x) what %in% colnames(x), 
                             logical(1) ))
1e2ccd61
       if(length(test)==0) stop("`what2` is undefined and could not be guessed.")
       if(length(test)>1) message("`what2` is undefined, and `what` is present ",
4ccba5fc
         "in more than one results data.frame (",
         paste(paste0("`",names(res)[test],"`"), collapse=", "), "). ",
         names(res)[test[1]], "' will be used.")
1e2ccd61
       what2 <- names(res)[test[1]]
     }
     what2 <- match.arg(what2, names(res))
     res <- res[[what2]]
   }
b04a719a
   if(!is.null(pd)){
     what_options <- setdiff( colnames(res), 
1e2ccd61
                            c("dataset",unlist(arguments(pd))) )
b04a719a
     what <- match.arg(what, choices=what_options)
   }
1e2ccd61
   res <- .prepRes(res, what=what, agg.by=agg.by, pipDef=pd, 
4ccba5fc
                   filt=substitute(filterExpr), shortNames=shortNames, 
                   returnParams=TRUE)
1e2ccd61
   pp <- res$pp
4ccba5fc
   res <- res$res
   res2 <- .doscale(res, param=scale)
   ro <- .dosort(res2, reorder_rows)
1e2ccd61
   res2 <- as.matrix(res2)
86023f6b
   cellfn <- .getCellFn(res, res2, value_format, value_cols, font_factor, 
                        scaled=!(is.logical(scale) && !scale) )
1e2ccd61
   if(is.null(title)) title <- gsub("\\.","\n",what)
9c85399b
   suppressWarnings({
     if(!tryCatch(is.null(row_split), error=function(e) FALSE)){
       if(tryCatch(is.character(row_split), error=function(e) FALSE)){
         if(row_split %in% colnames(pp)){
           row_split <- pp[,row_split]
         }else{
           warning("`row_split` wasn't found and will be ignored.")
           row_split <- NULL
         }
       }else{
         row_split <- eval(substitute(row_split), envir=pp)
       }
1e2ccd61
     }
9c85399b
   })
066bafc5
   if(is.null(name)) name <- what
86023f6b
   if(is.null(col) || !is.function(col))
     col <- .defaultColorMapping(res2, center=!(is.logical(scale) && !scale),
                                 cols=col)
066bafc5
   Heatmap( res2, name=name, cluster_rows=FALSE, cluster_columns=FALSE, 
1e2ccd61
            show_heatmap_legend=show_heatmap_legend, row_order=ro,
c8f1895f
            bottom_annotation=.ds_anno(colnames(res),anno_legend,font_factor), 
            show_column_names=show_column_names, cell_fun=cellfn, col=col,
1e2ccd61
            column_title=title, row_split=row_split,
c8f1895f
            row_title_gp = gpar(fontsize = (14*font_factor)),
            column_title_gp = gpar(fontsize = (14*font_factor)),
            row_names_gp = gpar(fontsize = (12*font_factor)),
86023f6b
            column_names_gp = gpar(fontsize = (12*font_factor)), ... )
1e2ccd61
 }
 
4ccba5fc
 .doscale <- function(res, param){
   if(is.null(param)) param <- colCenterScale
   if(is.logical(param) && !param) return(res)
   if(is.logical(param) && param) return(.safescale(res))
   if(is.function(param)) return(param(res))
   if(is.character(param) && is.function(get(param))) return(get(param)(res))
   warning("Unknown scaling parameter, will use column z-scores.")
   .safescale(res)
 }
 
b04a719a
 #' @import ComplexHeatmap
4ccba5fc
 .dosort <- function(res, reorder_rows){
   if(is(reorder_rows, "Heatmap")){
b04a719a
     ro <- row_order(reorder_rows)
4ccba5fc
   }else{
     if(length(reorder_rows)>1){
       ro <- reorder_rows
     }else{
       if(reorder_rows){
         ro <- order(rowMeans(res, na.rm=TRUE), decreasing=TRUE)
       }else{
         ro <- seq_len(nrow(res))
       }
     }
   }
   if(!is.numeric(ro))
     ro <- vapply(ro, FUN=function(x) which(row.names(res)==x), integer(1))
   ro
 }
 
1e2ccd61
 # NA-robust scale
 .safescale <- function(x){
   if(!is.null(dim(x))){
     y <- apply(x,2,.safescale)
     row.names(y) <- row.names(x)
     return(y)
   }
   if(all(is.na(x))) return(x)
   if(sd(x,na.rm=TRUE)>0) return(base::scale(x))
   if(sum(!is.na(x))==0) return(base::scale(as.numeric(!is.na(x))))
   rep(0,length(x))
 }
 
86023f6b
 .defaultColorMapping <- function(x, center=TRUE, cols=NULL){
   if(is.null(cols)) cols <- viridisLite::inferno(101)
   if(!center) return(cols)
8e382601
   q <- max(abs(range(x, na.rm=TRUE)))
86023f6b
   b <- c( -seq(from=sqrt(q), to=0, length.out=floor(length(cols)/2)+1)^2,
           seq(from=0, to=sqrt(q), length.out=floor(length(cols)/2)+1)[-1]^2 )
   colorRamp2(b, cols)
8e382601
 }
 
4ccba5fc
 #' colCenterScale
 #'
 #' Matrix scaling by centering columns separately and then performing variance 
 #' scaling on the whole matrix, in a NA-robust fashion. With the default 
 #' arguments, the output will be the number of (matrix-)median absolute 
 #' deviations from the column-median.
 #'
 #' @param x A numeric matrix.
 #' @param centerFn The function for calculating centers. Should accept the 
 #' `na.rm` argument. E.g. `centerFn=mean` or `centerFn=median`.
 #' @param scaleFn The function for calculating the (matrix-wise) scaling
 #' factor. Should accept the `na.rm` argument. Default `median(abs(x))`.
 #'
 #' @return A scaled matrix of the same dimensions as `x`.
 #' @export
 #'
 #' @examples
 #' # random data with column mean differences
 #' d <- cbind(A=rnorm(5, 10, 2), B=rnorm(5, 20, 2), C=rnorm(5,30, 2))
 #' colCenterScale(d)
 colCenterScale <- function(x, centerFn=median, 
                            scaleFn=function(x,na.rm) median(abs(x),na.rm=na.rm)
                            ){
   if(is.null(dim(x))) stop("`x` should be a numeric matrix or data.frame.")
   centers <- apply(x, MARGIN=2, na.rm=TRUE, FUN=centerFn)
   x2 <- t(t(x)-centers)
   x2 <- x2/scaleFn(x2, na.rm=TRUE)
   if(all(is.na(x2))) x2[is.na(x2)] <- 0
   x2
 }
 
1e2ccd61
 #' @importFrom ComplexHeatmap HeatmapAnnotation
c8f1895f
 .ds_anno <- function(x, legend=TRUE, font_factor=1){
   y <- vapply( strsplit(gsub("stepElapsed\\.","",x)," "),
                FUN=function(x) x[[1]], FUN.VALUE=character(1))
1e2ccd61
   cols <- getQualitativePalette(length(unique(y)))
   names(cols) <- sort(unique(y))
   HeatmapAnnotation(dataset=y, col=list(dataset=cols), 
c8f1895f
                     annotation_legend_param=list(
                       labels_gp=gpar(fontsize=10*font_factor),
                       title_gp=gpar(fontsize=10*font_factor, fontface = "bold")
                     ),
1e2ccd61
                     show_annotation_name = FALSE, show_legend=legend)
 }
 
 .prepRes <- function(res, what=NULL, agg.by=NULL, agg.fn=mean, pipDef=NULL, 
4ccba5fc
                      filt=NULL, long=FALSE, shortNames=FALSE, 
                      returnParams=FALSE){
1e2ccd61
   if(is.null(what) && !long) stop("`what` should be defined.")
4ccba5fc
   if(!is.null(filt) && filt!=TRUE){
     if(is.null(pipDef)) stop("Filtering requires providing a `pipDef`.")
     res <- res[which(eval(filt, res)),]
   }
1e2ccd61
   if(!is.null(agg.by)){
     agg.by2 <- c(agg.by, intersect(colnames(res),c("dataset","subpopulation")))
     if(is.null(what)){
       what <- colnames(res)[which(
c8f1895f
         vapply(res, function(x) is.integer(x) | is.numeric(x), logical(1))
1e2ccd61
       )]
       what <- setdiff(what, agg.by)
     }
     res <- aggregate(res[,what,drop=FALSE], by=res[,agg.by2,drop=FALSE], 
                      FUN=agg.fn)
     pp <- res[,agg.by,drop=FALSE]
   }else{
     if(is.null(pipDef)) stop("Either `agg.by` or `pipDef` should be given.")
     pp <- res[,intersect(colnames(res), unlist(arguments(pipDef))),drop=FALSE]
   }
   pp$method <- res$method <- .getReducedNames(pp, short=shortNames)
   if(long) return(res)
   
   if("subpopulation" %in% colnames(res))
     res$dataset <- paste(res$dataset, res$subpopulation)
   r2 <- reshape2::dcast( res, method~dataset, value.var=what,  
                          drop=FALSE, fun.aggregate=agg.fn)
   pp <- pp[!duplicated(pp$method),,drop=FALSE]
   row.names(pp) <- pp$method
   row.names(r2) <- r2[,1]
   res <- as.matrix(r2[,-1,drop=FALSE])
   if(returnParams) return(list(res=res, pp=pp[row.names(res),,drop=FALSE]))
   res
 }
 
 .getReducedNames <- function(res, short=FALSE){
   if(is.character(res)) res <- parsePipNames(res)
c8f1895f
   pp <- vapply(res, FUN=function(x) length(unique(x))>1, logical(1))
   pp <- res[,pp,drop=FALSE]
1e2ccd61
   if(!short && ncol(pp)>1){
     y <- apply(pp,1,FUN=function(x){
       x <- paste0(colnames(pp),"=",x)
       paste(x, collapse="; ")
     })
   }else{
     y <- apply(pp,1,collapse=" > ",FUN=paste)
   }
   y
 }
4ccba5fc
 
 #' @importFrom grid grid.text
c8f1895f
 .getCellFn  <- function( res, res2, value_format, cols=c("black","white"), 
86023f6b
                          font_factor=1, scaled=TRUE ){
   if(scaled){
     resmid <- median(res2, na.rm=TRUE)
   }else{
     resmid <- range(res2, na.rm=TRUE)
     resmid <- resmid[1]+(resmid[2]-resmid[1])/2
   }
   if(is.na(resmid)) resmid <- Inf
1e2ccd61
   function(j, i, x, y, width, height, fill){
     if(value_format=="" || is.null(value_format) || is.na(value_format))
       return(NULL)
     lab <- sprintf(value_format, res[i,j])
     if(!any(abs(res)>1, na.rm=TRUE)){
       lab <- gsub("^0\\.",".",lab)
       lab <- gsub("^-0\\.","-.",lab)
     } 
     lab <- gsub("^1.00$","1",lab)
     lab <- gsub("^.00$","0",lab)
86023f6b
     cols <- ifelse(res2[i,j]>=resmid,cols[1],cols[2])
c8f1895f
     grid.text(lab, x, y, gp = gpar(fontsize = 10*font_factor, col=cols))
1e2ccd61
   }
 }
4ccba5fc
 
b04a719a
 .scaledLegend <- function(){
   ComplexHeatmap::Legend(
7cc6ac5e
     col_fun=circlize::colorRamp2(seq(1,100), viridisLite::inferno(100)), 
b04a719a
     at=c(1,50,100), title="MADs", labels=c("worst","median","best"), 
     direction="horizontal" )
 }