001/*-
002 *******************************************************************************
003 * Copyright (c) 2011, 2016 Diamond Light Source Ltd.
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the Eclipse Public License v1.0
006 * which accompanies this distribution, and is available at
007 * http://www.eclipse.org/legal/epl-v10.html
008 *
009 * Contributors:
010 *    Peter Chang - initial API and implementation and/or initial documentation
011 *******************************************************************************/
012
013package org.eclipse.january.dataset;
014
015import java.util.Arrays;
016import java.util.List;
017
018import org.apache.commons.math3.complex.Complex;
019import org.apache.commons.math3.linear.Array2DRowRealMatrix;
020import org.apache.commons.math3.linear.ArrayRealVector;
021import org.apache.commons.math3.linear.CholeskyDecomposition;
022import org.apache.commons.math3.linear.ConjugateGradient;
023import org.apache.commons.math3.linear.EigenDecomposition;
024import org.apache.commons.math3.linear.LUDecomposition;
025import org.apache.commons.math3.linear.MatrixUtils;
026import org.apache.commons.math3.linear.QRDecomposition;
027import org.apache.commons.math3.linear.RealLinearOperator;
028import org.apache.commons.math3.linear.RealMatrix;
029import org.apache.commons.math3.linear.RealVector;
030import org.apache.commons.math3.linear.SingularValueDecomposition;
031
032
033public class LinearAlgebra {
034
        /**
         * Summing-axis length at which {@link #tensorDotProduct(Dataset, Dataset, int, int)} switches strategy:
         * axes shorter than this are delegated to the position-iterating multi-axis overload, while longer ones
         * use slice iterators for the inner loop (which the original author found to be faster).
         */
        private static final int CROSSOVERPOINT = 16; // point at which using slice iterators for inner loop is faster
036
037        /**
038         * Calculate the tensor dot product over given axes. This is the sum of products of elements selected
039         * from the given axes in each dataset
040         * @param a
041         * @param b
042         * @param axisa axis dimension in a to sum over (can be -ve)
043         * @param axisb axis dimension in b to sum over (can be -ve)
044         * @return tensor dot product
045         */
046        public static Dataset tensorDotProduct(final Dataset a, final Dataset b, final int axisa, final int axisb) {
047                // this is slower for summing lengths < ~15
048                final int[] ashape = a.getShapeRef();
049                final int[] bshape = b.getShapeRef();
050                final int arank = ashape.length;
051                final int brank = bshape.length;
052                int aaxis = axisa;
053                if (aaxis < 0)
054                        aaxis += arank;
055                if (aaxis < 0 || aaxis >= arank)
056                        throw new IllegalArgumentException("Summing axis outside valid rank of 1st dataset");
057
058                if (ashape[aaxis] < CROSSOVERPOINT) { // faster to use position iteration
059                        return tensorDotProduct(a, b, new int[] {axisa}, new int[] {axisb});
060                }
061                int baxis = axisb;
062                if (baxis < 0)
063                        baxis += arank;
064                if (baxis < 0 || baxis >= arank)
065                        throw new IllegalArgumentException("Summing axis outside valid rank of 2nd dataset");
066
067                final boolean[] achoice = new boolean[arank];
068                final boolean[] bchoice = new boolean[brank];
069                Arrays.fill(achoice, true);
070                Arrays.fill(bchoice, true);
071                achoice[aaxis] = false; // flag which axes not to iterate over
072                bchoice[baxis] = false;
073
074                final boolean[] notachoice = new boolean[arank];
075                final boolean[] notbchoice = new boolean[brank];
076                notachoice[aaxis] = true; // flag which axes to iterate over
077                notbchoice[baxis] = true;
078
079                int drank = arank + brank - 2;
080                int[] dshape = new int[drank];
081                int d = 0;
082                for (int i = 0; i < arank; i++) {
083                        if (achoice[i])
084                                dshape[d++] = ashape[i];
085                }
086                for (int i = 0; i < brank; i++) {
087                        if (bchoice[i])
088                                dshape[d++] = bshape[i];
089                }
090                int dtype = DTypeUtils.getBestDType(a.getDType(), b.getDType());
091                @SuppressWarnings("deprecation")
092                Dataset data = DatasetFactory.zeros(dshape, dtype);
093
094                SliceIterator ita = a.getSliceIteratorFromAxes(null, achoice);
095                int l = 0;
096                final int[] apos = ita.getPos();
097                while (ita.hasNext()) {
098                        SliceIterator itb = b.getSliceIteratorFromAxes(null, bchoice);
099                        final int[] bpos = itb.getPos();
100                        while (itb.hasNext()) {
101                                SliceIterator itaa = a.getSliceIteratorFromAxes(apos, notachoice);
102                                SliceIterator itba = b.getSliceIteratorFromAxes(bpos, notbchoice);
103                                double sum = 0.0;
104                                double com = 0.0;
105                                while (itaa.hasNext() && itba.hasNext()) {
106                                        final double y = a.getElementDoubleAbs(itaa.index) * b.getElementDoubleAbs(itba.index) - com;
107                                        final double t = sum + y;
108                                        com = (t - sum) - y;
109                                        sum = t;
110                                }
111                                data.setObjectAbs(l++, sum);
112                        }
113                }
114
115                return data;
116        }
117
118        /**
119         * Calculate the tensor dot product over given axes. This is the sum of products of elements selected
120         * from the given axes in each dataset
121         * @param a
122         * @param b
123         * @param axisa axis dimensions in a to sum over (can be -ve)
124         * @param axisb axis dimensions in b to sum over (can be -ve)
125         * @return tensor dot product
126         */
127        public static Dataset tensorDotProduct(final Dataset a, final Dataset b, final int[] axisa, final int[] axisb) {
128                if (axisa.length != axisb.length) {
129                        throw new IllegalArgumentException("Numbers of summing axes must be same");
130                }
131                final int[] ashape = a.getShapeRef();
132                final int[] bshape = b.getShapeRef();
133                final int arank = ashape.length;
134                final int brank = bshape.length;
135                final int[] aaxes = new int[axisa.length];
136                final int[] baxes = new int[axisa.length];
137                for (int i = 0; i < axisa.length; i++) {
138                        int n;
139
140                        n = axisa[i];
141                        if (n < 0) n += arank;
142                        if (n < 0 || n >= arank)
143                                throw new IllegalArgumentException("Summing axis outside valid rank of 1st dataset");
144                        aaxes[i] = n;
145
146                        n = axisb[i];
147                        if (n < 0) n += brank;
148                        if (n < 0 || n >= brank)
149                                throw new IllegalArgumentException("Summing axis outside valid rank of 2nd dataset");
150                        baxes[i] = n;
151
152                        if (ashape[aaxes[i]] != bshape[n])
153                                throw new IllegalArgumentException("Summing axes do not have matching lengths");
154                }
155
156                final boolean[] achoice = new boolean[arank];
157                final boolean[] bchoice = new boolean[brank];
158                Arrays.fill(achoice, true);
159                Arrays.fill(bchoice, true);
160                for (int i = 0; i < aaxes.length; i++) { // flag which axes to iterate over
161                        achoice[aaxes[i]] = false;
162                        bchoice[baxes[i]] = false;
163                }
164
165                int drank = arank + brank - 2*aaxes.length;
166                int[] dshape = new int[drank];
167                int d = 0;
168                for (int i = 0; i < arank; i++) {
169                        if (achoice[i])
170                                dshape[d++] = ashape[i];
171                }
172                for (int i = 0; i < brank; i++) {
173                        if (bchoice[i])
174                                dshape[d++] = bshape[i];
175                }
176                int dtype = DTypeUtils.getBestDType(a.getDType(), b.getDType());
177                @SuppressWarnings("deprecation")
178                Dataset data = DatasetFactory.zeros(dshape, dtype);
179
180                SliceIterator ita = a.getSliceIteratorFromAxes(null, achoice);
181                int l = 0;
182                final int[] apos = ita.getPos();
183                while (ita.hasNext()) {
184                        SliceIterator itb = b.getSliceIteratorFromAxes(null, bchoice);
185                        final int[] bpos = itb.getPos();
186                        while (itb.hasNext()) {
187                                double sum = 0.0;
188                                double com = 0.0;
189                                apos[aaxes[aaxes.length - 1]] = -1;
190                                bpos[baxes[aaxes.length - 1]] = -1;
191                                while (true) { // step through summing axes
192                                        int e = aaxes.length - 1;
193                                        for (; e >= 0; e--) {
194                                                int ai = aaxes[e];
195                                                int bi = baxes[e];
196
197                                                apos[ai]++;
198                                                bpos[bi]++;
199                                                if (apos[ai] == ashape[ai]) {
200                                                        apos[ai] = 0;
201                                                        bpos[bi] = 0;
202                                                } else
203                                                        break;
204                                        }
205                                        if (e == -1) break;
206                                        final double y = a.getDouble(apos) * b.getDouble(bpos) - com;
207                                        final double t = sum + y;
208                                        com = (t - sum) - y;
209                                        sum = t;
210                                }
211                                data.setObjectAbs(l++, sum);
212                        }
213                }
214
215                return data;
216        }
217
218        /**
219         * Calculate the dot product of two datasets. When <b>b</b> is a 1D dataset, the sum product over
220         * the last axis of <b>a</b> and <b>b</b> is returned. Where <b>a</b> is also a 1D dataset, a zero-rank dataset
221         * is returned. If <b>b</b> is 2D or higher, its second-to-last axis is used
222         * @param a
223         * @param b
224         * @return dot product
225         */
226        public static Dataset dotProduct(Dataset a, Dataset b) {
227                if (b.getRank() < 2)
228                        return tensorDotProduct(a, b, -1, 0);
229                return tensorDotProduct(a, b, -1, -2);
230        }
231
232        /**
233         * Calculate the outer product of two datasets
234         * @param a
235         * @param b
236         * @return outer product
237         */
238        public static Dataset outerProduct(Dataset a, Dataset b) {
239                int[] as = a.getShapeRef();
240                int[] bs = b.getShapeRef();
241                int rank = as.length + bs.length;
242                int[] shape = new int[rank];
243                for (int i = 0; i < as.length; i++) {
244                        shape[i] = as[i];
245                }
246                for (int i = 0; i < bs.length; i++) {
247                        shape[as.length + i] = bs[i];
248                }
249                int isa = a.getElementsPerItem();
250                int isb = b.getElementsPerItem();
251                if (isa != 1 || isb != 1) {
252                        throw new UnsupportedOperationException("Compound datasets not supported");
253                }
254                @SuppressWarnings("deprecation")
255                Dataset o = DatasetFactory.zeros(shape, DTypeUtils.getBestDType(a.getDType(), b.getDType()));
256
257                IndexIterator ita = a.getIterator();
258                IndexIterator itb = b.getIterator();
259                int j = 0;
260                while (ita.hasNext()) {
261                        double va = a.getElementDoubleAbs(ita.index);
262                        while (itb.hasNext()) {
263                                o.setObjectAbs(j++, va * b.getElementDoubleAbs(itb.index));
264                        }
265                        itb.reset();
266                }
267                return o;
268        }
269
270        /**
271         * Calculate the cross product of two datasets. Datasets must be broadcastable and
272         * possess last dimensions of length 2 or 3
273         * @param a
274         * @param b
275         * @return cross product
276         */
277        public static Dataset crossProduct(Dataset a, Dataset b) {
278                return crossProduct(a, b, -1, -1, -1);
279        }
280
281        /**
282         * Calculate the cross product of two datasets. Datasets must be broadcastable and
283         * possess dimensions of length 2 or 3. The axis parameters can be negative to indicate
284         * dimensions from the end of their shapes
285         * @param a
286         * @param b
287         * @param axisA dimension to be used a vector (must have length of 2 or 3)
288         * @param axisB dimension to be used a vector (must have length of 2 or 3)
289         * @param axisC dimension to assign as cross-product
290         * @return cross product
291         */
292        public static Dataset crossProduct(Dataset a, Dataset b, int axisA, int axisB, int axisC) {
293                final int rankA = a.getRank();
294                final int rankB = b.getRank();
295                if (rankA == 0 || rankB == 0) {
296                        throw new IllegalArgumentException("Datasets must have one or more dimensions");
297                }
298                if (axisA < 0) {
299                        axisA += rankA;
300                }
301                if (axisA < 0 || axisA >= rankA) {
302                        throw new IllegalArgumentException("Axis A argument exceeds rank");
303                }
304                if (axisB < 0) {
305                        axisB += rankB;
306                }
307                if (axisB < 0 || axisB >= rankB) {
308                        throw new IllegalArgumentException("Axis B argument exceeds rank");
309                }
310
311                final int[] shapeA = a.getShape();
312                final int[] shapeB = b.getShape();
313                int la = shapeA[axisA];
314                int lb = shapeB[axisB];
315                if (Math.min(la,  lb) < 2 || Math.max(la, lb) > 3) {
316                        throw new IllegalArgumentException("Chosen dimension of A & B must be 2 or 3");
317                }
318
319                if (Math.max(la,  lb) == 2) {
320                        return crossProduct2D(a, b, axisA, axisB);
321                }
322
323                return crossProduct3D(a, b, axisA, axisB, axisC);
324        }
325
326        private static int[] removeAxisFromShape(int[] shape, int axis) {
327                int[] s = new int[shape.length - 1];
328                int i = 0;
329                int j = 0;
330                while (i < axis) {
331                        s[j++] = shape[i++];
332                }
333                i++;
334                while (i < shape.length) {
335                        s[j++] = shape[i++];
336                }
337                return s;
338        }
339
340        // assume axes is in increasing order
341        private static int[] removeAxesFromShape(int[] shape, int... axes) {
342                int n = axes.length;
343                int[] s = new int[shape.length - n];
344                int i = 0;
345                int j = 0;
346                for (int k = 0; k < n; k++) {
347                        int a = axes[k];
348                        while (i < a) {
349                                s[j++] = shape[i++];
350                        }
351                        i++;
352                }
353                while (i < shape.length) {
354                        s[j++] = shape[i++];
355                }
356                return s;
357        }
358
359        private static int[] addAxisToShape(int[] shape, int axis, int length) {
360                int[] s = new int[shape.length + 1];
361                int i = 0;
362                int j = 0;
363                while (i < axis) {
364                        s[j++] = shape[i++];
365                }
366                s[j++] = length;
367                while (i < shape.length) {
368                        s[j++] = shape[i++];
369                }
370                return s;
371        }
372
373        // assume axes is in increasing order
374        private static int[] addAxesToShape(int[] shape, int[] axes, int[] lengths) {
375                int n = axes.length;
376                if (lengths.length != n) {
377                        throw new IllegalArgumentException("Axes and lengths arrays must be same size");
378                }
379                int[] s = new int[shape.length + n];
380                int i = 0;
381                int j = 0;
382                for (int k = 0; k < n; k++) {
383                        int a = axes[k];
384                        while (i < a) {
385                                s[j++] = shape[i++];
386                        }
387                        s[j++] = lengths[k];
388                }
389                while (i < shape.length) {
390                        s[j++] = shape[i++];
391                }
392                return s;
393        }
394
        /**
         * Compute the 2D cross product, a scalar per position: {@code a[0]*b[1] - a[1]*b[0]}, where the
         * vector components are taken along {@code axisA} of a and {@code axisB} of b.
         * <p>
         * The result shape is the broadcast of a's shape without axisA and b's shape without axisB; the
         * vector axis is dropped. The dtype is given by {@code DTypeUtils.getBestDType} of the two inputs.
         * NOTE(review): the iterators over a and b are simply restarted when exhausted rather than properly
         * broadcast (see TODO below). This presumably only pairs elements correctly when each input either
         * matches the output positions or has a single position — confirm. Also, after {@code reset()} the
         * position array is read without another {@code hasNext()} call; verify that reset leaves it valid.
         * @param a
         * @param b
         * @param axisA non-negative axis of a holding the 2 vector components (normalised by the caller)
         * @param axisB non-negative axis of b holding the 2 vector components (normalised by the caller)
         * @return cross product
         */
        private static Dataset crossProduct2D(Dataset a, Dataset b, int axisA, int axisB) {
                // need to broadcast and omit given axes
                int[] shapeA = removeAxisFromShape(a.getShapeRef(), axisA);
                int[] shapeB = removeAxisFromShape(b.getShapeRef(), axisB);

                List<int[]> fullShapes = BroadcastUtils.broadcastShapes(shapeA, shapeB);

                int[] maxShape = fullShapes.get(0);
                @SuppressWarnings("deprecation")
                Dataset c = DatasetFactory.zeros(maxShape, DTypeUtils.getBestDType(a.getDType(), b.getDType()));

                // position iterators that skip the vector axis; the vector component is set explicitly below
                PositionIterator ita = a.getPositionIterator(axisA);
                PositionIterator itb = b.getPositionIterator(axisB);
                IndexIterator itc = c.getIterator();

                final int[] pa = ita.getPos();
                final int[] pb = itb.getPos();
                while (itc.hasNext()) {
                        if (!ita.hasNext()) // TODO use broadcasting...
                                ita.reset();
                        if (!itb.hasNext())
                                itb.reset();
                        // c = a0*b1 - a1*b0
                        pa[axisA] = 0;
                        pb[axisB] = 1;
                        double cv = a.getDouble(pa) * b.getDouble(pb);
                        pa[axisA] = 1;
                        pb[axisB] = 0;
                        cv -= a.getDouble(pa) * b.getDouble(pb);

                        c.setObjectAbs(itc.index, cv);
                }
                return c;
        }
428
        /**
         * Compute the 3D cross product {@code c = a x b}, with the vector components of a along {@code axisA},
         * of b along {@code axisB}, and of the result along {@code axisC}.
         * <p>
         * If one input has only 2 components along its vector axis, its third component is treated as zero
         * (the caller only routes here when at least one input has 3 components). The result shape is the
         * broadcast of a's shape without axisA and b's shape without axisB, with an axis of length 3 inserted
         * at axisC. The dtype is given by {@code DTypeUtils.getBestDType} of the two inputs.
         * NOTE(review): as in crossProduct2D, the input iterators are restarted when exhausted instead of
         * being properly broadcast (see TODO) — confirm pairing is correct for general broadcast shapes.
         * @param a
         * @param b
         * @param axisA non-negative axis of a holding the vector components (normalised by the caller)
         * @param axisB non-negative axis of b holding the vector components (normalised by the caller)
         * @param axisC axis of the result for the vector components (can be -ve)
         * @return cross product
         * @throws IllegalArgumentException if axisC is outside the rank of the result
         */
        private static Dataset crossProduct3D(Dataset a, Dataset b, int axisA, int axisB, int axisC) {
                int[] shapeA = removeAxisFromShape(a.getShapeRef(), axisA);
                int[] shapeB = removeAxisFromShape(b.getShapeRef(), axisB);

                List<int[]> fullShapes = BroadcastUtils.broadcastShapes(shapeA, shapeB);

                int[] maxShape = fullShapes.get(0);
                int rankC = maxShape.length + 1; // broadcast rank plus the vector axis
                if (axisC < 0) {
                        axisC += rankC;
                }
                if (axisC < 0 || axisC >= rankC) {
                        throw new IllegalArgumentException("Axis C argument exceeds rank");
                }
                maxShape = addAxisToShape(maxShape, axisC, 3);
                @SuppressWarnings("deprecation")
                Dataset c = DatasetFactory.zeros(maxShape, DTypeUtils.getBestDType(a.getDType(), b.getDType()));

                // position iterators that skip each vector axis; the component index is set explicitly below
                PositionIterator ita = a.getPositionIterator(axisA);
                PositionIterator itb = b.getPositionIterator(axisB);
                PositionIterator itc = c.getPositionIterator(axisC);

                final int[] pa = ita.getPos();
                final int[] pb = itb.getPos();
                final int[] pc = itc.getPos();
                final int la = a.getShapeRef()[axisA];
                final int lb = b.getShapeRef()[axisB];

                if (la == 2) {
                        // a = (a0, a1, 0): c = (a1*b2, -a0*b2, a0*b1 - a1*b0)
                        while (itc.hasNext()) {
                                if (!ita.hasNext()) // TODO use broadcasting...
                                        ita.reset();
                                if (!itb.hasNext())
                                        itb.reset();
                                double cv;
                                pa[axisA] = 1;
                                pb[axisB] = 2;
                                cv = a.getDouble(pa) * b.getDouble(pb);
                                pc[axisC] = 0;
                                c.set(cv, pc);

                                pa[axisA] = 0;
                                pb[axisB] = 2;
                                cv = -a.getDouble(pa) * b.getDouble(pb);
                                pc[axisC] = 1;
                                c.set(cv, pc);

                                pa[axisA] = 0;
                                pb[axisB] = 1;
                                cv = a.getDouble(pa) * b.getDouble(pb);
                                pa[axisA] = 1;
                                pb[axisB] = 0;
                                cv -= a.getDouble(pa) * b.getDouble(pb);
                                pc[axisC] = 2;
                                c.set(cv, pc);
                        }
                } else if (lb == 2) {
                        // b = (b0, b1, 0): c = (-a2*b1, a2*b0, a0*b1 - a1*b0)
                        while (itc.hasNext()) {
                                if (!ita.hasNext()) // TODO use broadcasting...
                                        ita.reset();
                                if (!itb.hasNext())
                                        itb.reset();
                                double cv;
                                pa[axisA] = 2;
                                pb[axisB] = 1;
                                cv = -a.getDouble(pa) * b.getDouble(pb);
                                pc[axisC] = 0;
                                c.set(cv, pc);

                                pa[axisA] = 2;
                                pb[axisB] = 0;
                                cv = a.getDouble(pa) * b.getDouble(pb);
                                pc[axisC] = 1;
                                c.set(cv, pc);

                                pa[axisA] = 0;
                                pb[axisB] = 1;
                                cv = a.getDouble(pa) * b.getDouble(pb);
                                pa[axisA] = 1;
                                pb[axisB] = 0;
                                cv -= a.getDouble(pa) * b.getDouble(pb);
                                pc[axisC] = 2;
                                c.set(cv, pc);
                        }
                        
                } else {
                        // full 3-vectors: c = (a1*b2 - a2*b1, a2*b0 - a0*b2, a0*b1 - a1*b0)
                        while (itc.hasNext()) {
                                if (!ita.hasNext()) // TODO use broadcasting...
                                        ita.reset();
                                if (!itb.hasNext())
                                        itb.reset();
                                double cv;
                                pa[axisA] = 1;
                                pb[axisB] = 2;
                                cv = a.getDouble(pa) * b.getDouble(pb);
                                pa[axisA] = 2;
                                pb[axisB] = 1;
                                cv -= a.getDouble(pa) * b.getDouble(pb);
                                pc[axisC] = 0;
                                c.set(cv, pc);

                                pa[axisA] = 2;
                                pb[axisB] = 0;
                                cv = a.getDouble(pa) * b.getDouble(pb);
                                pa[axisA] = 0;
                                pb[axisB] = 2;
                                cv -= a.getDouble(pa) * b.getDouble(pb);
                                pc[axisC] = 1;
                                c.set(cv, pc);

                                pa[axisA] = 0;
                                pb[axisB] = 1;
                                cv = a.getDouble(pa) * b.getDouble(pb);
                                pa[axisA] = 1;
                                pb[axisB] = 0;
                                cv -= a.getDouble(pa) * b.getDouble(pb);
                                pc[axisC] = 2;
                                c.set(cv, pc);
                        }
                }
                return c;
        }
551
552        /**
553         * Raise dataset to given power by matrix multiplication
554         * @param a
555         * @param n power
556         * @return a ** n
557         */
558        public static Dataset power(Dataset a, int n) {
559                if (n < 0) {
560                        LUDecomposition lud = new LUDecomposition(createRealMatrix(a));
561                        return createDataset(lud.getSolver().getInverse().power(-n));
562                }
563                Dataset p = createDataset(createRealMatrix(a).power(n));
564                if (!a.hasFloatingPointElements())
565                        return p.cast(a.getDType());
566                return p;
567        }
568
569        /**
570         * Create the Kronecker product as defined by 
571         * kron[k0,...,kN] = a[i0,...,iN] * b[j0,...,jN]
572         * where kn = sn * in + jn for n = 0...N and s is shape of b
573         * @param a
574         * @param b
575         * @return Kronecker product of a and b
576         */
577        public static Dataset kroneckerProduct(Dataset a, Dataset b) {
578                if (a.getElementsPerItem() != 1 || b.getElementsPerItem() != 1) {
579                        throw new UnsupportedOperationException("Compound datasets (including complex ones) are not currently supported");
580                }
581                int ar = a.getRank();
582                int br = b.getRank();
583                int[] aShape;
584                int[] bShape;
585                aShape = a.getShapeRef();
586                bShape = b.getShapeRef();
587                int r = ar;
588                // pre-pad if ranks are not same
589                if (ar < br) {
590                        r = br;
591                        int[] shape = new int[br];
592                        int j = 0;
593                        for (int i = ar; i < br; i++) {
594                                shape[j++] = 1;
595                        }
596                        int i = 0;
597                        while (j < br) {
598                                shape[j++] = aShape[i++];
599                        }
600                        a = a.reshape(shape);
601                        aShape = shape;
602                } else if (ar > br) {
603                        int[] shape = new int[ar];
604                        int j = 0;
605                        for (int i = br; i < ar; i++) {
606                                shape[j++] = 1;
607                        }
608                        int i = 0;
609                        while (j < ar) {
610                                shape[j++] = bShape[i++];
611                        }
612                        b = b.reshape(shape);
613                        bShape = shape;
614                }
615
616                int[] nShape = new int[r];
617                for (int i = 0; i < r; i++) {
618                        nShape[i] = aShape[i] * bShape[i];
619                }
620                @SuppressWarnings("deprecation")
621                Dataset kron = DatasetFactory.zeros(nShape, DTypeUtils.getBestDType(a.getDType(), b.getDType()));
622                IndexIterator ita = a.getIterator(true);
623                IndexIterator itb = b.getIterator(true);
624                int[] pa = ita.getPos();
625                int[] pb = itb.getPos();
626                int[] off = new int[1];
627                int[] stride = AbstractDataset.createStrides(1, nShape, null, 0, off);
628                if (kron.getDType() == Dataset.INT64) {
629                        while (ita.hasNext()) {
630                                long av = a.getElementLongAbs(ita.index);
631
632                                int ka = 0; 
633                                for (int i = 0; i < r; i++) {
634                                        ka += stride[i] * bShape[i] * pa[i];
635                                }
636                                itb.reset();
637                                while (itb.hasNext()) {
638                                        long bv = b.getElementLongAbs(itb.index);
639                                        int kb = ka;
640                                        for (int i = 0; i < r; i++) {
641                                                kb += stride[i] * pb[i];
642                                        }
643                                        kron.setObjectAbs(kb, av * bv);
644                                }
645                        }
646                } else {
647                        while (ita.hasNext()) {
648                                double av = a.getElementDoubleAbs(ita.index);
649
650                                int ka = 0; 
651                                for (int i = 0; i < r; i++) {
652                                        ka += stride[i] * bShape[i] * pa[i];
653                                }
654                                itb.reset();
655                                while (itb.hasNext()) {
656                                        double bv = b.getElementLongAbs(itb.index);
657                                        int kb = ka;
658                                        for (int i = 0; i < r; i++) {
659                                                kb += stride[i] * pb[i];
660                                        }
661                                        kron.setObjectAbs(kb, av * bv);
662                                }
663                        }
664                }
665
666                return kron;
667        }
668
669        /**
670         * Calculate trace of dataset - sum of values over 1st axis and 2nd axis
671         * @param a
672         * @return trace of dataset
673         */
674        public static Dataset trace(Dataset a) {
675                return trace(a, 0, 0, 1);
676        }
677
678        /**
679         * Calculate trace of dataset - sum of values over axis1 and axis2 where axis2 is offset
680         * @param a
681         * @param offset
682         * @param axis1
683         * @param axis2
684         * @return trace of dataset
685         */
686        public static Dataset trace(Dataset a, int offset, int axis1, int axis2) {
687                int[] shape = a.getShapeRef();
688                int[] axes = new int[] { a.checkAxis(axis1), a.checkAxis(axis2) };
689                Arrays.sort(axes);
690                int is = a.getElementsPerItem();
691                @SuppressWarnings("deprecation")
692                Dataset trace = DatasetFactory.zeros(is, removeAxesFromShape(shape, axes), a.getDType());
693
694                int am = axes[0];
695                int mmax = shape[am];
696                int an = axes[1];
697                int nmax = shape[an];
698                PositionIterator it = new PositionIterator(shape, axes);
699                int[] pos = it.getPos();
700                int i = 0;
701                int mmin;
702                int nmin;
703                if (offset >= 0) {
704                        mmin = 0;
705                        nmin = offset;
706                } else {
707                        mmin = -offset;
708                        nmin = 0;
709                }
710                if (is == 1) {
711                        if (a.getDType() == Dataset.INT64) {
712                                while (it.hasNext()) {
713                                        int m = mmin;
714                                        int n = nmin;
715                                        long s = 0;
716                                        while (m < mmax && n < nmax) {
717                                                pos[am] = m++;
718                                                pos[an] = n++;
719                                                s += a.getLong(pos);
720                                        }
721                                        trace.setObjectAbs(i++, s);
722                                }
723                        } else {
724                                while (it.hasNext()) {
725                                        int m = mmin;
726                                        int n = nmin;
727                                        double s = 0;
728                                        while (m < mmax && n < nmax) {
729                                                pos[am] = m++;
730                                                pos[an] = n++;
731                                                s += a.getDouble(pos);
732                                        }
733                                        trace.setObjectAbs(i++, s);
734                                }
735                        }
736                } else {
737                        AbstractCompoundDataset ca = (AbstractCompoundDataset) a;
738                        if (ca instanceof CompoundLongDataset) {
739                                long[] t = new long[is];
740                                long[] s = new long[is];
741                                while (it.hasNext()) {
742                                        int m = mmin;
743                                        int n = nmin;
744                                        Arrays.fill(s, 0);
745                                        while (m < mmax && n < nmax) {
746                                                pos[am] = m++;
747                                                pos[an] = n++;
748                                                ((CompoundLongDataset)ca).getAbs(ca.get1DIndex(pos), t);
749                                                for (int k = 0; k < is; k++) {
750                                                        s[k] += t[k];
751                                                }
752                                        }
753                                        trace.setObjectAbs(i++, s);
754                                }
755                        } else {
756                                double[] t = new double[is];
757                                double[] s = new double[is];
758                                while (it.hasNext()) {
759                                        int m = mmin;
760                                        int n = nmin;
761                                        Arrays.fill(s, 0);
762                                        while (m < mmax && n < nmax) {
763                                                pos[am] = m++;
764                                                pos[an] = n++;
765                                                ca.getDoubleArray(t, pos);
766                                                for (int k = 0; k < is; k++) {
767                                                        s[k] += t[k];
768                                                }
769                                        }
770                                        trace.setObjectAbs(i++, s);
771                                }
772                        }
773                }
774
775                return trace;
776        }
777
778        /**
779         * Order value for norm
780         */
781        public enum NormOrder {
782                /**
783                 * 2-norm for vectors and Frobenius for matrices
784                 */
785                DEFAULT,
786                /**
787                 * Frobenius (not allowed for vectors)
788                 */
789                FROBENIUS,
790                /**
791                 * Zero-order (not allowed for matrices)
792                 */
793                ZERO,
794                /**
795                 * Positive infinity
796                 */
797                POS_INFINITY,
798                /**
799                 * Negative infinity
800                 */
801                NEG_INFINITY;
802        }
803
804        /**
805         * @param a
806         * @return norm of dataset
807         */
808        public static double norm(Dataset a) {
809                return norm(a, NormOrder.DEFAULT);
810        }
811
812        /**
813         * @param a
814         * @param order
815         * @return norm of dataset
816         */
817        public static double norm(Dataset a, NormOrder order) {
818                int r = a.getRank();
819                if (r == 1) {
820                        return vectorNorm(a, order);
821                } else if (r == 2) {
822                        return matrixNorm(a, order);
823                }
824                throw new IllegalArgumentException("Rank of dataset must be one or two");
825        }
826
827        private static double vectorNorm(Dataset a, NormOrder order) {
828                double n;
829                IndexIterator it;
830                switch (order) {
831                case FROBENIUS:
832                        throw new IllegalArgumentException("Not allowed for vectors");
833                case NEG_INFINITY:
834                case POS_INFINITY:
835                        it = a.getIterator();
836                        if (order == NormOrder.POS_INFINITY) {
837                                n = Double.NEGATIVE_INFINITY;
838                                if (a.isComplex()) {
839                                        while (it.hasNext()) {
840                                                double v = ((Complex) a.getObjectAbs(it.index)).abs();
841                                                n = Math.max(n, v);
842                                        }
843                                } else {
844                                        while (it.hasNext()) {
845                                                double v = Math.abs(a.getElementDoubleAbs(it.index));
846                                                n = Math.max(n, v);
847                                        }
848                                }
849                        } else {
850                                n = Double.POSITIVE_INFINITY;
851                                if (a.isComplex()) {
852                                        while (it.hasNext()) {
853                                                double v = ((Complex) a.getObjectAbs(it.index)).abs();
854                                                n = Math.min(n, v);
855                                        }
856                                } else {
857                                        while (it.hasNext()) {
858                                                double v = Math.abs(a.getElementDoubleAbs(it.index));
859                                                n = Math.min(n, v);
860                                        }
861                                }
862                        }
863                        break;
864                case ZERO:
865                        it = a.getIterator();
866                        n = 0;
867                        if (a.isComplex()) {
868                                while (it.hasNext()) {
869                                        if (!((Complex) a.getObjectAbs(it.index)).equals(Complex.ZERO))
870                                                n++;
871                                }
872                        } else {
873                                while (it.hasNext()) {
874                                        if (a.getElementBooleanAbs(it.index))
875                                                n++;
876                                }
877                        }
878                        
879                        break;
880                default:
881                        n = vectorNorm(a, 2);
882                        break;
883                }
884                return n;
885        }
886
887        private static double matrixNorm(Dataset a, NormOrder order) {
888                double n;
889                IndexIterator it;
890                switch (order) {
891                case NEG_INFINITY:
892                case POS_INFINITY:
893                        n = maxMinMatrixNorm(a, 1, order == NormOrder.POS_INFINITY);
894                        break;
895                case ZERO:
896                        throw new IllegalArgumentException("Not allowed for matrices");
897                default:
898                case FROBENIUS:
899                        it = a.getIterator();
900                        n = 0;
901                        if (a.isComplex()) {
902                                while (it.hasNext()) {
903                                        double v = ((Complex) a.getObjectAbs(it.index)).abs();
904                                        n += v*v;
905                                }
906                        } else {
907                                while (it.hasNext()) {
908                                        double v = a.getElementDoubleAbs(it.index);
909                                        n += v*v;
910                                }
911                        }
912                        n = Math.sqrt(n);
913                        break;
914                }
915                return n;
916        }
917
918        /**
919         * @param a
920         * @param p
921         * @return p-norm of dataset
922         */
923        public static double norm(Dataset a, final double p) {
924                if (p == 0) {
925                        return norm(a, NormOrder.ZERO);
926                }
927                int r = a.getRank();
928                if (r == 1) {
929                        return vectorNorm(a, p);
930                } else if (r == 2) {
931                        return matrixNorm(a, p);
932                }
933                throw new IllegalArgumentException("Rank of dataset must be one or two");
934        }
935
936        private static double vectorNorm(Dataset a, final double p) {
937                IndexIterator it = a.getIterator();
938                double n = 0;
939                if (a.isComplex()) {
940                        while (it.hasNext()) {
941                                double v = ((Complex) a.getObjectAbs(it.index)).abs();
942                                if (p == 2) {
943                                        v *= v;
944                                } else if (p != 1) {
945                                        v = Math.pow(v, p);
946                                }
947                                n += v;
948                        }
949                } else {
950                        while (it.hasNext()) {
951                                double v = a.getElementDoubleAbs(it.index);
952                                if (p == 1) {
953                                        v = Math.abs(v);
954                                } else if (p == 2) {
955                                        v *= v;
956                                } else {
957                                        v = Math.pow(Math.abs(v), p);
958                                }
959                                n += v;
960                        }
961                }
962                return Math.pow(n, 1./p);
963        }
964
965        private static double matrixNorm(Dataset a, final double p) {
966                double n;
967                if (Math.abs(p) == 1) {
968                        n = maxMinMatrixNorm(a, 0, p > 0);
969                } else if (Math.abs(p) == 2) {
970                        double[] s = calcSingularValues(a);
971                        n = p > 0 ? s[0] : s[s.length - 1];
972                } else {
973                        throw new IllegalArgumentException("Order not allowed");
974                }
975
976                return n;
977        }
978
979        private static double maxMinMatrixNorm(Dataset a, int d, boolean max) {
980                double n;
981                IndexIterator it;
982                int[] pos;
983                int l;
984                it = a.getPositionIterator(d);
985                pos = it.getPos();
986                l = a.getShapeRef()[d];
987                if (max) {
988                        n = Double.NEGATIVE_INFINITY;
989                        if (a.isComplex()) {
990                                while (it.hasNext()) {
991                                        double v = ((Complex) a.getObject(pos)).abs();
992                                        for (int i = 1; i < l; i++) {
993                                                pos[d] = i;
994                                                v += ((Complex) a.getObject(pos)).abs();
995                                        }
996                                        pos[d] = 0;
997                                        n = Math.max(n, v);
998                                }
999                        } else {
1000                                while (it.hasNext()) {
1001                                        double v = Math.abs(a.getDouble(pos));
1002                                        for (int i = 1; i < l; i++) {
1003                                                pos[d] = i;
1004                                                v += Math.abs(a.getDouble(pos));
1005                                        }
1006                                        pos[d] = 0;
1007                                        n = Math.max(n, v);
1008                                }
1009                        }
1010                } else {
1011                        n = Double.POSITIVE_INFINITY;
1012                        if (a.isComplex()) {
1013                                while (it.hasNext()) {
1014                                        double v = ((Complex) a.getObject(pos)).abs();
1015                                        for (int i = 1; i < l; i++) {
1016                                                pos[d] = i;
1017                                                v += ((Complex) a.getObject(pos)).abs();
1018                                        }
1019                                        pos[d] = 0;
1020                                        n = Math.min(n, v);
1021                                }
1022                        } else {
1023                                while (it.hasNext()) {
1024                                        double v = Math.abs(a.getDouble(pos));
1025                                        for (int i = 1; i < l; i++) {
1026                                                pos[d] = i;
1027                                                v += Math.abs(a.getDouble(pos));
1028                                        }
1029                                        pos[d] = 0;
1030                                        n = Math.min(n, v);
1031                                }
1032                        }
1033                }
1034                return n;
1035        }
1036
1037        /**
1038         * @param a
1039         * @return array of singular values
1040         */
1041        public static double[] calcSingularValues(Dataset a) {
1042                SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
1043                return svd.getSingularValues();
1044        }
1045
1046
1047        /**
1048         * Calculate singular value decomposition A = U S V^T
1049         * @param a
1050         * @return array of U - orthogonal matrix, s - singular values vector, V - orthogonal matrix
1051         */
1052        public static Dataset[] calcSingularValueDecomposition(Dataset a) {
1053                SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
1054                return new Dataset[] {createDataset(svd.getU()), DatasetFactory.createFromObject(svd.getSingularValues()),
1055                                createDataset(svd.getV())};
1056        }
1057
1058        /**
1059         * Calculate (Moore-Penrose) pseudo-inverse
1060         * @param a
1061         * @return pseudo-inverse
1062         */
1063        public static Dataset calcPseudoInverse(Dataset a) {
1064                SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
1065                return createDataset(svd.getSolver().getInverse());
1066        }
1067
1068        /**
1069         * Calculate matrix rank by singular value decomposition method
1070         * @param a
1071         * @return effective numerical rank of matrix
1072         */
1073        public static int calcMatrixRank(Dataset a) {
1074                SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
1075                return svd.getRank();
1076        }
1077
1078        /**
1079         * Calculate condition number of matrix by singular value decomposition method
1080         * @param a
1081         * @return condition number
1082         */
1083        public static double calcConditionNumber(Dataset a) {
1084                SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
1085                return svd.getConditionNumber();
1086        }
1087
1088        /**
1089         * @param a
1090         * @return determinant of dataset
1091         */
1092        public static double calcDeterminant(Dataset a) {
1093                EigenDecomposition evd = new EigenDecomposition(createRealMatrix(a));
1094                return evd.getDeterminant();
1095        }
1096
1097        /**
1098         * @param a
1099         * @return dataset of eigenvalues (can be double or complex double)
1100         */
1101        public static Dataset calcEigenvalues(Dataset a) {
1102                EigenDecomposition evd = new EigenDecomposition(createRealMatrix(a));
1103                double[] rev = evd.getRealEigenvalues();
1104
1105                if (evd.hasComplexEigenvalues()) {
1106                        double[] iev = evd.getImagEigenvalues();
1107                        return DatasetFactory.createComplexDataset(ComplexDoubleDataset.class, rev, iev);
1108                }
1109                return DatasetFactory.createFromObject(rev);
1110        }
1111
1112        /**
1113         * Calculate eigen-decomposition A = V D V^T
1114         * @param a
1115         * @return array of D eigenvalues (can be double or complex double) and V eigenvectors
1116         */
1117        public static Dataset[] calcEigenDecomposition(Dataset a) {
1118                EigenDecomposition evd = new EigenDecomposition(createRealMatrix(a));
1119                Dataset[] results = new Dataset[2];
1120
1121                double[] rev = evd.getRealEigenvalues();
1122                if (evd.hasComplexEigenvalues()) {
1123                        double[] iev = evd.getImagEigenvalues();
1124                        results[0] = DatasetFactory.createComplexDataset(ComplexDoubleDataset.class, rev, iev);
1125                } else {
1126                        results[0] = DatasetFactory.createFromObject(rev);
1127                }
1128                results[1] = createDataset(evd.getV());
1129                return results;
1130        }
1131
1132        /**
1133         * Calculate QR decomposition A = Q R
1134         * @param a
1135         * @return array of Q and R
1136         */
1137        public static Dataset[] calcQRDecomposition(Dataset a) {
1138                QRDecomposition qrd = new QRDecomposition(createRealMatrix(a));
1139                return new Dataset[] {createDataset(qrd.getQT()).getTransposedView(), createDataset(qrd.getR())};
1140        }
1141
1142        /**
1143         * Calculate LU decomposition A = P^-1 L U
1144         * @param a
1145         * @return array of L, U and P
1146         */
1147        public static Dataset[] calcLUDecomposition(Dataset a) {
1148                LUDecomposition lud = new LUDecomposition(createRealMatrix(a));
1149                return new Dataset[] {createDataset(lud.getL()), createDataset(lud.getU()),
1150                                createDataset(lud.getP())};
1151        }
1152
1153        /**
1154         * Calculate inverse of square dataset
1155         * @param a
1156         * @return inverse
1157         */
1158        public static Dataset calcInverse(Dataset a) {
1159                LUDecomposition lud = new LUDecomposition(createRealMatrix(a));
1160                return createDataset(lud.getSolver().getInverse());
1161        }
1162
1163        /**
1164         * Solve linear matrix equation A x = v
1165         * @param a
1166         * @param v
1167         * @return x
1168         */
1169        public static Dataset solve(Dataset a, Dataset v) {
1170                LUDecomposition lud = new LUDecomposition(createRealMatrix(a));
1171                if (v.getRank() == 1) {
1172                        RealVector x = createRealVector(v);
1173                        return createDataset(lud.getSolver().solve(x));
1174                }
1175                RealMatrix x = createRealMatrix(v);
1176                return createDataset(lud.getSolver().solve(x));
1177        }
1178
1179        
1180        /**
1181         * Solve least squares matrix equation A x = v by SVD
1182         * @param a
1183         * @param v
1184         * @return x
1185         */
1186        public static Dataset solveSVD(Dataset a, Dataset v) {
1187                SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
1188                if (v.getRank() == 1) {
1189                        RealVector x = createRealVector(v);
1190                        return createDataset(svd.getSolver().solve(x));
1191                }
1192                RealMatrix x = createRealMatrix(v);
1193                return createDataset(svd.getSolver().solve(x));
1194        }
1195        
1196        /**
1197         * Calculate Cholesky decomposition A = L L^T
1198         * @param a
1199         * @return L
1200         */
1201        public static Dataset calcCholeskyDecomposition(Dataset a) {
1202                CholeskyDecomposition cd = new CholeskyDecomposition(createRealMatrix(a));
1203                return createDataset(cd.getL());
1204        }
1205
1206        /**
1207         * Calculation A x = v by conjugate gradient method with the stopping criterion being
1208         * that the estimated residual r = v - A x satisfies ||r|| < ||v|| with maximum of 100 iterations
1209         * @param a
1210         * @param v
1211         * @return solution of A^-1 v by conjugate gradient method
1212         */
1213        public static Dataset calcConjugateGradient(Dataset a, Dataset v) {
1214                return calcConjugateGradient(a, v, 100, 1);
1215        }
1216
1217        /**
1218         * Calculation A x = v by conjugate gradient method with the stopping criterion being
1219         * that the estimated residual r = v - A x satisfies ||r|| < delta ||v||
1220         * @param a
1221         * @param v
1222         * @param maxIterations
1223         * @param delta parameter used by stopping criterion
1224         * @return solution of A^-1 v by conjugate gradient method
1225         */
1226        public static Dataset calcConjugateGradient(Dataset a, Dataset v, int maxIterations, double delta) {
1227                ConjugateGradient cg = new ConjugateGradient(maxIterations, delta, false);
1228                return createDataset(cg.solve((RealLinearOperator) createRealMatrix(a), createRealVector(v)));
1229        }
1230
1231        private static RealMatrix createRealMatrix(Dataset a) {
1232                if (a.getRank() != 2) {
1233                        throw new IllegalArgumentException("Dataset must be rank 2");
1234                }
1235                int[] shape = a.getShapeRef();
1236                IndexIterator it = a.getIterator(true);
1237                int[] pos = it.getPos();
1238                RealMatrix m = MatrixUtils.createRealMatrix(shape[0], shape[1]);
1239                while (it.hasNext()) {
1240                        m.setEntry(pos[0], pos[1], a.getElementDoubleAbs(it.index));
1241                }
1242                return m;
1243        }
1244
1245        private static RealVector createRealVector(Dataset a) {
1246                if (a.getRank() != 1) {
1247                        throw new IllegalArgumentException("Dataset must be rank 1");
1248                }
1249                int size = a.getSize();
1250                IndexIterator it = a.getIterator(true);
1251                int[] pos = it.getPos();
1252                RealVector m = new ArrayRealVector(size);
1253                while (it.hasNext()) {
1254                        m.setEntry(pos[0], a.getElementDoubleAbs(it.index));
1255                }
1256                return m;
1257        }
1258
1259        private static Dataset createDataset(RealVector v) {
1260                DoubleDataset r = DatasetFactory.zeros(DoubleDataset.class, v.getDimension());
1261                int size = r.getSize();
1262                if (v instanceof ArrayRealVector) {
1263                        double[] data = ((ArrayRealVector) v).getDataRef();
1264                        for (int i = 0; i < size; i++) {
1265                                r.setAbs(i, data[i]);
1266                        }
1267                } else {
1268                        for (int i = 0; i < size; i++) {
1269                                r.setAbs(i, v.getEntry(i));
1270                        }
1271                }
1272                return r;
1273        }
1274
1275        private static Dataset createDataset(RealMatrix m) {
1276                DoubleDataset r = DatasetFactory.zeros(DoubleDataset.class, m.getRowDimension(), m.getColumnDimension());
1277                if (m instanceof Array2DRowRealMatrix) {
1278                        double[][] data = ((Array2DRowRealMatrix) m).getDataRef();
1279                        IndexIterator it = r.getIterator(true);
1280                        int[] pos = it.getPos();
1281                        while (it.hasNext()) {
1282                                r.setAbs(it.index, data[pos[0]][pos[1]]);
1283                        }
1284                } else {
1285                        IndexIterator it = r.getIterator(true);
1286                        int[] pos = it.getPos();
1287                        while (it.hasNext()) {
1288                                r.setAbs(it.index, m.getEntry(pos[0], pos[1]));
1289                        }
1290                }
1291                return r;
1292        }
1293}