In this vignette, we consider optimal tensor transport (OTT), an extension of optimal transport (OT) that handles tensors of any order by learning one or more transport plans.
Here, we reproduce the experiments from the original paper (Kerdoncuff et al., 2022); see that paper for details of the methodology.
# Experiment 1: order-1 tensors (D = 1) with a single transport plan (A = 1) ----
D <- 1
A <- 1
Is <- c(4)
Ks <- c(7)
f <- c(1)

# Every cell holds the sum of its indices: arrX[i1] = i1, arrY[k1] = k1.
# arrayInd() lists each cell's subscripts in column-major (storage) order,
# so rowSums() yields the cell values in the order array() fills them.
arrX <- array(rowSums(arrayInd(seq_len(prod(Is)), Is)), dim = Is)
arrY <- array(rowSums(arrayInd(seq_len(prod(Ks)), Ks)), dim = Ks)

# One marginal weight vector per plan a, sized by the first mode d with
# f[d] == a. ps[[a]] puts almost all mass on X-indices 1 and 3; qs[[a]] is
# uniform over the Y-indices except 2 and 3, which get zero mass.
ps <- lapply(seq_len(A), function(a) {
  d <- which(f == a)[1]
  w <- rep(0.01, dim(arrX)[d])
  w[c(1, 3)] <- 1
  w / sum(w)
})
qs <- lapply(seq_len(A), function(a) {
  d <- which(f == a)[1]
  w <- rep(1, dim(arrY)[d])
  w[c(2, 3)] <- 0
  w / sum(w)
})
# Wrap the arrays as rTensor tensors and fit OTT with an L1 ground loss.
X <- as.tensor(arrX)
Y <- as.tensor(arrY)
out <- OTT(
  X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
  ps = ps, qs = qs,
  num.sample = 1000,
  loss = function(x, y) abs(x - y),
  num.iter = 200,
  epsilon = 1e-10
)

# Panels fill a 3 x 2 grid row by row: marginals ps[[1]] / qs[[1]], the
# learned out$Ts[[1]] (presumably the transport plan), then the inputs.
# String literals must use ASCII quotes; curly quotes are a syntax error.
# repr.plot.* only sizes plots under Jupyter's IRkernel.
options(repr.plot.width = 6, repr.plot.height = 10)
old_mfrow <- par("mfrow")
par(mfrow = c(3, 2))
plot(ps[[1]], type = "h", col = "red", ylim = c(0, 1), main = "ps[[1]]")
plot(qs[[1]], type = "h", col = "red", ylim = c(0, 1), main = "qs[[1]]")
.show_matrix(out$Ts[[1]], main = "Ts[[1]]")
plot(arrX, type = "h", col = "black", main = "arrX")
plot(arrY, type = "h", col = "black", main = "arrY")
# Leave the user's panel layout as we found it.
par(mfrow = old_mfrow)
# Experiment 2: order-2 tensors (D = 2), one plan per mode (A = 2) ----
D <- 2
A <- 2
Is <- c(4, 5)
Ks <- c(7, 8)
f <- c(1, 2)

# Every cell holds the sum of its indices: arrX[i1, i2] = i1 + i2 (same for
# arrY). arrayInd() lists each cell's subscripts in column-major (storage)
# order, so rowSums() yields the cell values in the order array() fills them.
arrX <- array(rowSums(arrayInd(seq_len(prod(Is)), Is)), dim = Is)
arrY <- array(rowSums(arrayInd(seq_len(prod(Ks)), Ks)), dim = Ks)

# One marginal weight vector per plan a, sized by the first mode d with
# f[d] == a. ps[[a]] puts almost all mass on X-indices 1 and 3; qs[[a]] is
# uniform over the Y-indices except 2 and 3, which get zero mass.
ps <- lapply(seq_len(A), function(a) {
  d <- which(f == a)[1]
  w <- rep(0.01, dim(arrX)[d])
  w[c(1, 3)] <- 1
  w / sum(w)
})
qs <- lapply(seq_len(A), function(a) {
  d <- which(f == a)[1]
  w <- rep(1, dim(arrY)[d])
  w[c(2, 3)] <- 0
  w / sum(w)
})
# Wrap the arrays as rTensor tensors and fit OTT with an L1 ground loss.
X <- as.tensor(arrX)
Y <- as.tensor(arrY)
out <- OTT(
  X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
  ps = ps, qs = qs,
  num.sample = 1000,
  loss = function(x, y) abs(x - y),
  num.iter = 200,
  epsilon = 1e-10
)

# One 3 x 2 figure per plan a: marginals ps[[a]] / qs[[a]], the learned
# out$Ts[[a]] (presumably the a-th transport plan), then both inputs.
# String literals must use ASCII quotes; curly quotes are a syntax error.
# repr.plot.* only sizes plots under Jupyter's IRkernel.
options(repr.plot.width = 6, repr.plot.height = 10)
old_mfrow <- par("mfrow")
for (a in seq_len(A)) {
  # Setting mfrow again starts a fresh page for each plan.
  par(mfrow = c(3, 2))
  plot(ps[[a]], type = "h", col = "red", ylim = c(0, 1),
       main = paste0("ps[[", a, "]]"))
  plot(qs[[a]], type = "h", col = "red", ylim = c(0, 1),
       main = paste0("qs[[", a, "]]"))
  .show_matrix(out$Ts[[a]], main = paste0("Ts[[", a, "]]"))
  .show_matrix(arrX, main = "arrX")
  .show_matrix(arrY, main = "arrY")
}
# Leave the user's panel layout as we found it.
par(mfrow = old_mfrow)
# Experiment 3: order-2 tensors (D = 2) sharing a single plan (A = 1) ----
D <- 2
A <- 1
Is <- c(4, 4)
Ks <- c(6, 6)
f <- c(1, 1)

# Every cell holds the sum of its indices: arrX[i1, i2] = i1 + i2 (same for
# arrY). arrayInd() lists each cell's subscripts in column-major (storage)
# order, so rowSums() yields the cell values in the order array() fills them.
arrX <- array(rowSums(arrayInd(seq_len(prod(Is)), Is)), dim = Is)
arrY <- array(rowSums(arrayInd(seq_len(prod(Ks)), Ks)), dim = Ks)

# One marginal weight vector per plan a, sized by the first mode d with
# f[d] == a. ps[[a]] puts almost all mass on X-indices 1 and 3; qs[[a]] is
# uniform over the Y-indices except 2 and 3, which get zero mass.
ps <- lapply(seq_len(A), function(a) {
  d <- which(f == a)[1]
  w <- rep(0.01, dim(arrX)[d])
  w[c(1, 3)] <- 1
  w / sum(w)
})
qs <- lapply(seq_len(A), function(a) {
  d <- which(f == a)[1]
  w <- rep(1, dim(arrY)[d])
  w[c(2, 3)] <- 0
  w / sum(w)
})
# Wrap the arrays as rTensor tensors and fit OTT with an L1 ground loss.
X <- as.tensor(arrX)
Y <- as.tensor(arrY)
out <- OTT(
  X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
  ps = ps, qs = qs,
  num.sample = 1000,
  loss = function(x, y) abs(x - y),
  num.iter = 200,
  epsilon = 1e-10
)

# Panels fill a 3 x 2 grid row by row: marginals ps[[1]] / qs[[1]], the
# learned out$Ts[[1]] (presumably the transport plan), then both inputs.
# String literals must use ASCII quotes; curly quotes are a syntax error.
# repr.plot.* only sizes plots under Jupyter's IRkernel.
options(repr.plot.width = 6, repr.plot.height = 10)
old_mfrow <- par("mfrow")
par(mfrow = c(3, 2))
plot(ps[[1]], type = "h", col = "red", ylim = c(0, 1), main = "ps[[1]]")
plot(qs[[1]], type = "h", col = "red", ylim = c(0, 1), main = "qs[[1]]")
.show_matrix(out$Ts[[1]], main = "Ts[[1]]")
.show_matrix(arrX, main = "arrX")
.show_matrix(arrY, main = "arrY")
# Leave the user's panel layout as we found it.
par(mfrow = old_mfrow)
# Experiment 4: order-3 tensors (D = 3) sharing a single plan (A = 1) ----
D <- 3
A <- 1
Is <- c(4, 4, 4)
Ks <- c(6, 6, 6)
f <- c(1, 1, 1)

# Every cell holds the sum of its indices: arrX[i1, i2, i3] = i1 + i2 + i3
# (same for arrY). arrayInd() lists each cell's subscripts in column-major
# (storage) order, so rowSums() yields values in the order array() fills them.
arrX <- array(rowSums(arrayInd(seq_len(prod(Is)), Is)), dim = Is)
arrY <- array(rowSums(arrayInd(seq_len(prod(Ks)), Ks)), dim = Ks)

# One marginal weight vector per plan a, sized by the first mode d with
# f[d] == a. ps[[a]] puts almost all mass on X-indices 1 and 3; qs[[a]] is
# uniform over the Y-indices except 2 and 3, which get zero mass.
ps <- lapply(seq_len(A), function(a) {
  d <- which(f == a)[1]
  w <- rep(0.01, dim(arrX)[d])
  w[c(1, 3)] <- 1
  w / sum(w)
})
qs <- lapply(seq_len(A), function(a) {
  d <- which(f == a)[1]
  w <- rep(1, dim(arrY)[d])
  w[c(2, 3)] <- 0
  w / sum(w)
})
# Wrap the arrays as rTensor tensors and fit OTT with an L1 ground loss.
X <- as.tensor(arrX)
Y <- as.tensor(arrY)
out <- OTT(
  X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
  ps = ps, qs = qs,
  num.sample = 1000,
  loss = function(x, y) abs(x - y),
  num.iter = 200,
  epsilon = 1e-10
)

# Panels fill a 3 x 2 grid row by row: marginals ps[[1]] / qs[[1]], the
# learned out$Ts[[1]] (presumably the transport plan), then the first
# frontal slice of each input tensor.
# String literals must use ASCII quotes; curly quotes are a syntax error.
# repr.plot.* only sizes plots under Jupyter's IRkernel.
options(repr.plot.width = 6, repr.plot.height = 10)
old_mfrow <- par("mfrow")
par(mfrow = c(3, 2))
plot(ps[[1]], type = "h", col = "red", ylim = c(0, 1), main = "ps[[1]]")
plot(qs[[1]], type = "h", col = "red", ylim = c(0, 1), main = "qs[[1]]")
.show_matrix(out$Ts[[1]], main = "Ts[[1]]")
.show_matrix(arrX[, , 1], main = "arrX[,,1]")
.show_matrix(arrY[, , 1], main = "arrY[,,1]")
# Leave the user's panel layout as we found it.
par(mfrow = old_mfrow)
# Experiment 5: order-3 tensors (D = 3), one plan per mode (A = 3) ----
D <- 3
A <- 3
Is <- c(4, 5, 6)
Ks <- c(7, 8, 9)
f <- c(1, 2, 3)

# Every cell holds the sum of its indices: arrX[i1, i2, i3] = i1 + i2 + i3
# (same for arrY). arrayInd() lists each cell's subscripts in column-major
# (storage) order, so rowSums() yields values in the order array() fills them.
arrX <- array(rowSums(arrayInd(seq_len(prod(Is)), Is)), dim = Is)
arrY <- array(rowSums(arrayInd(seq_len(prod(Ks)), Ks)), dim = Ks)

# One marginal weight vector per plan a, sized by the first mode d with
# f[d] == a. ps[[a]] puts almost all mass on X-indices 1 and 3; qs[[a]] is
# uniform over the Y-indices except 2 and 3, which get zero mass.
ps <- lapply(seq_len(A), function(a) {
  d <- which(f == a)[1]
  w <- rep(0.01, dim(arrX)[d])
  w[c(1, 3)] <- 1
  w / sum(w)
})
qs <- lapply(seq_len(A), function(a) {
  d <- which(f == a)[1]
  w <- rep(1, dim(arrY)[d])
  w[c(2, 3)] <- 0
  w / sum(w)
})
# Wrap the arrays as rTensor tensors and fit OTT with an L1 ground loss.
X <- as.tensor(arrX)
Y <- as.tensor(arrY)
out <- OTT(
  X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
  ps = ps, qs = qs,
  num.sample = 1000,
  loss = function(x, y) abs(x - y),
  num.iter = 200,
  epsilon = 1e-10
)

# One 3 x 2 figure per plan a: marginals ps[[a]] / qs[[a]], the learned
# out$Ts[[a]] (presumably the a-th transport plan), then frontal slice a of
# each input tensor.
# String literals must use ASCII quotes; curly quotes are a syntax error.
# repr.plot.* only sizes plots under Jupyter's IRkernel.
options(repr.plot.width = 6, repr.plot.height = 10)
old_mfrow <- par("mfrow")
for (a in seq_len(A)) {
  # Setting mfrow again starts a fresh page for each plan.
  par(mfrow = c(3, 2))
  plot(ps[[a]], type = "h", col = "red", ylim = c(0, 1),
       main = paste0("ps[[", a, "]]"))
  plot(qs[[a]], type = "h", col = "red", ylim = c(0, 1),
       main = paste0("qs[[", a, "]]"))
  .show_matrix(out$Ts[[a]], main = paste0("Ts[[", a, "]]"))
  .show_matrix(arrX[, , a], main = paste0("arrX[,,", a, "]"))
  .show_matrix(arrY[, , a], main = paste0("arrY[,,", a, "]"))
}
# Leave the user's panel layout as we found it.
par(mfrow = old_mfrow)
# Experiment 6: order-3 tensors (D = 3); modes 1-2 share plan 1 and mode 3
# uses plan 2 (A = 2) ----
D <- 3
A <- 2
Is <- c(4, 4, 5)
Ks <- c(6, 6, 7)
f <- c(1, 1, 2)

# Every cell holds the sum of its indices: arrX[i1, i2, i3] = i1 + i2 + i3
# (same for arrY). arrayInd() lists each cell's subscripts in column-major
# (storage) order, so rowSums() yields values in the order array() fills them.
arrX <- array(rowSums(arrayInd(seq_len(prod(Is)), Is)), dim = Is)
arrY <- array(rowSums(arrayInd(seq_len(prod(Ks)), Ks)), dim = Ks)

# One marginal weight vector per plan a, sized by the first mode d with
# f[d] == a. ps[[a]] puts almost all mass on X-indices 1 and 3; qs[[a]] is
# uniform over the Y-indices except 2 and 3, which get zero mass.
ps <- lapply(seq_len(A), function(a) {
  d <- which(f == a)[1]
  w <- rep(0.01, dim(arrX)[d])
  w[c(1, 3)] <- 1
  w / sum(w)
})
qs <- lapply(seq_len(A), function(a) {
  d <- which(f == a)[1]
  w <- rep(1, dim(arrY)[d])
  w[c(2, 3)] <- 0
  w / sum(w)
})
# Wrap the arrays as rTensor tensors and fit OTT with an L1 ground loss.
X <- as.tensor(arrX)
Y <- as.tensor(arrY)
out <- OTT(
  X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
  ps = ps, qs = qs,
  num.sample = 1000,
  loss = function(x, y) abs(x - y),
  num.iter = 200,
  epsilon = 1e-10
)

# One 3 x 2 figure per plan a: marginals ps[[a]] / qs[[a]], the learned
# out$Ts[[a]] (presumably the a-th transport plan), then frontal slice a of
# each input tensor.
# String literals must use ASCII quotes; curly quotes are a syntax error.
# repr.plot.* only sizes plots under Jupyter's IRkernel.
options(repr.plot.width = 6, repr.plot.height = 10)
old_mfrow <- par("mfrow")
for (a in seq_len(A)) {
  # Setting mfrow again starts a fresh page for each plan.
  par(mfrow = c(3, 2))
  plot(ps[[a]], type = "h", col = "red", ylim = c(0, 1),
       main = paste0("ps[[", a, "]]"))
  plot(qs[[a]], type = "h", col = "red", ylim = c(0, 1),
       main = paste0("qs[[", a, "]]"))
  .show_matrix(out$Ts[[a]], main = paste0("Ts[[", a, "]]"))
  .show_matrix(arrX[, , a], main = paste0("arrX[,,", a, "]"))
  .show_matrix(arrY[, , a], main = paste0("arrY[,,", a, "]"))
}
# Leave the user's panel layout as we found it.
par(mfrow = old_mfrow)
## R version 4.6.0 (2026-04-24)
## Platform: x86_64-pc-linux-gnu
## Running under: Ubuntu 24.04.4 LTS
##
## Matrix products: default
## BLAS: /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3
## LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.26.so; LAPACK version 3.12.0
##
## locale:
## [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
## [3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8
## [5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
## [7] LC_PAPER=en_US.UTF-8 LC_NAME=C
## [9] LC_ADDRESS=C LC_TELEPHONE=C
## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
##
## time zone: Etc/UTC
## tzcode source: system (glibc)
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] rTensor_1.5.0 otTensor_0.99.0 rmarkdown_2.31
##
## loaded via a namespace (and not attached):
## [1] digest_0.6.39 R6_2.6.1 fastmap_1.2.0 xfun_0.57
## [5] maketools_1.3.2 cachem_1.1.0 knitr_1.51 htmltools_0.5.9
## [9] buildtools_1.0.0 lifecycle_1.0.5 cli_3.6.6 sass_0.4.10
## [13] jquerylib_0.1.4 compiler_4.6.0 sys_3.4.3 tools_4.6.0
## [17] evaluate_1.0.5 bslib_0.10.0 yaml_2.3.12 jsonlite_2.0.0
## [21] rlang_1.2.0