001/*******************************************************************************
002 * Copyright (c) 2016 Diamond Light Source Ltd. and others.
003 * All rights reserved. This program and the accompanying materials
004 * are made available under the terms of the Eclipse Public License v1.0
005 * which accompanies this distribution, and is available at
006 * http://www.eclipse.org/legal/epl-v10.html
007 *
008 * Contributors:
009 *     Diamond Light Source Ltd - initial API and implementation
010 *******************************************************************************/
011package org.eclipse.january.dataset;
012
013import java.lang.reflect.Array;
014import java.util.ArrayList;
015import java.util.List;
016
017public class ShapeUtils {
018
019        private ShapeUtils() {
020        }
021
022        /**
023         * Calculate total number of items in given shape
024         * @param shape
025         * @return size
026         */
027        public static long calcLongSize(final int[] shape) {
028                if (shape == null) { // special case of null-shaped
029                        return 0;
030                }
031
032                final int rank = shape.length;
033                if (rank == 0) { // special case of zero-rank shape 
034                        return 1;
035                }
036        
037                double dsize = 1.0;
038                for (int i = 0; i < rank; i++) {
039                        // make sure the indexes isn't zero or negative
040                        if (shape[i] == 0) {
041                                return 0;
042                        } else if (shape[i] < 0) {
043                                throw new IllegalArgumentException(String.format(
044                                                "The %d-th is %d which is not allowed as it is negative", i, shape[i]));
045                        }
046        
047                        dsize *= shape[i];
048                }
049        
050                // check to see if the size is larger than an integer, i.e. we can't allocate it
051                if (dsize > Long.MAX_VALUE) {
052                        throw new IllegalArgumentException("Size of the dataset is too large to allocate");
053                }
054                return (long) dsize;
055        }
056
057        /**
058         * Calculate total number of items in given shape
059         * @param shape
060         * @return size
061         */
062        public static int calcSize(final int[] shape) {
063                long lsize = calcLongSize(shape);
064        
065                // check to see if the size is larger than an integer, i.e. we can't allocate it
066                if (lsize > Integer.MAX_VALUE) {
067                        throw new IllegalArgumentException("Size of the dataset is too large to allocate");
068                }
069                return (int) lsize;
070        }
071
072        /**
073         * Check if shapes are broadcast compatible
074         * 
075         * @param ashape
076         * @param bshape
077         * @return true if they are compatible
078         */
079        public static boolean areShapesBroadcastCompatible(final int[] ashape, final int[] bshape) {
080                if (ashape == null || bshape == null) {
081                        return ashape == bshape;
082                }
083
084                if (ashape.length < bshape.length) {
085                        return areShapesBroadcastCompatible(bshape, ashape);
086                }
087        
088                for (int a = ashape.length - bshape.length, b = 0; a < ashape.length && b < bshape.length; a++, b++) {
089                        if (ashape[a] != bshape[b] && ashape[a] != 1 && bshape[b] != 1) {
090                                return false;
091                        }
092                }
093        
094                return true;
095        }
096
097        /**
098         * Check if shapes are compatible, ignoring extra axes of length 1
099         * 
100         * @param ashape
101         * @param bshape
102         * @return true if they are compatible
103         */
104        public static boolean areShapesCompatible(final int[] ashape, final int[] bshape) {
105                if (ashape == null || bshape == null) {
106                        return ashape == bshape;
107                }
108
109                List<Integer> alist = new ArrayList<Integer>();
110        
111                for (int a : ashape) {
112                        if (a > 1) alist.add(a);
113                }
114        
115                final int imax = alist.size();
116                int i = 0;
117                for (int b : bshape) {
118                        if (b == 1)
119                                continue;
120                        if (i >= imax || b != alist.get(i++))
121                                return false;
122                }
123        
124                return i == imax;
125        }
126
127        /**
128         * Check if shapes are compatible but skip axis
129         * 
130         * @param ashape
131         * @param bshape
132         * @param axis
133         * @return true if they are compatible
134         */
135        public static boolean areShapesCompatible(final int[] ashape, final int[] bshape, final int axis) {
136                if (ashape == null || bshape == null) {
137                        return ashape == bshape;
138                }
139
140                if (ashape.length != bshape.length) {
141                        return false;
142                }
143        
144                final int rank = ashape.length;
145                for (int i = 0; i < rank; i++) {
146                        if (i != axis && ashape[i] != bshape[i]) {
147                                return false;
148                        }
149                }
150                return true;
151        }
152
153        /**
154         * Remove dimensions of 1 in given shape - from both ends only, if true
155         * 
156         * @param oshape
157         * @param onlyFromEnds
158         * @return newly squeezed shape (or original if unsqueezed)
159         */
160        public static int[] squeezeShape(final int[] oshape, boolean onlyFromEnds) {
161                int unitDims = 0;
162                int rank = oshape.length;
163                int start = 0;
164        
165                if (onlyFromEnds) {
166                        int i = rank - 1;
167                        for (; i >= 0; i--) {
168                                if (oshape[i] == 1) {
169                                        unitDims++;
170                                } else {
171                                        break;
172                                }
173                        }
174                        for (int j = 0; j <= i; j++) {
175                                if (oshape[j] == 1) {
176                                        unitDims++;
177                                } else {
178                                        start = j;
179                                        break;
180                                }
181                        }
182                } else {
183                        for (int i = 0; i < rank; i++) {
184                                if (oshape[i] == 1) {
185                                        unitDims++;
186                                }
187                        }
188                }
189        
190                if (unitDims == 0) {
191                        return oshape;
192                }
193        
194                int[] newDims = new int[rank - unitDims];
195                if (unitDims == rank)
196                        return newDims; // zero-rank dataset
197        
198                if (onlyFromEnds) {
199                        rank = newDims.length;
200                        for (int i = 0; i < rank; i++) {
201                                newDims[i] = oshape[i+start];
202                        }
203                } else {
204                        int j = 0;
205                        for (int i = 0; i < rank; i++) {
206                                if (oshape[i] > 1) {
207                                        newDims[j++] = oshape[i];
208                                        if (j >= newDims.length)
209                                                break;
210                                }
211                        }
212                }
213        
214                return newDims;
215        }
216
217        /**
218         * Remove dimension of 1 in given shape
219         * 
220         * @param oshape
221         * @param axis
222         * @return newly squeezed shape
223         */
224        public static int[] squeezeShape(final int[] oshape, int axis) {
225                if (oshape == null) {
226                        return null;
227                }
228
229                final int rank = oshape.length;
230                if (rank == 0) {
231                        return new int[0];
232                }
233                if (axis < 0) {
234                        axis += rank;
235                }
236                if (axis < 0 || axis >= rank) {
237                        throw new IllegalArgumentException("Axis argument is outside allowed range");
238                }
239                int[] nshape = new int[rank-1];
240                for (int i = 0; i < axis; i++) {
241                        nshape[i] = oshape[i];
242                }
243                for (int i = axis+1; i < rank; i++) {
244                        nshape[i-1] = oshape[i];
245                }
246                return nshape;
247        }
248
249        /**
250         * Get shape from object (array or list supported)
251         * @param obj
252         * @return shape can be null if obj is null
253         */
254        public static int[] getShapeFromObject(final Object obj) {
255                if (obj == null) {
256                        return null;
257                }
258
259                ArrayList<Integer> lshape = new ArrayList<Integer>();
260                getShapeFromObj(lshape, obj, 0);
261
262                final int rank = lshape.size();
263                final int[] shape = new int[rank];
264                for (int i = 0; i < rank; i++) {
265                        shape[i] = lshape.get(i);
266                }
267        
268                return shape;
269        }
270
271        /**
272         * Get shape from object
273         * @param ldims
274         * @param obj
275         * @param depth
276         * @return true if there is a possibility of differing lengths
277         */
278        private static boolean getShapeFromObj(final ArrayList<Integer> ldims, Object obj, int depth) {
279                if (obj == null)
280                        return true;
281        
282                if (obj instanceof List<?>) {
283                        List<?> jl = (List<?>) obj;
284                        int l = jl.size();
285                        updateShape(ldims, depth, l);
286                        for (int i = 0; i < l; i++) {
287                                Object lo = jl.get(i);
288                                if (!getShapeFromObj(ldims, lo, depth + 1)) {
289                                        break;
290                                }
291                        }
292                        return true;
293                }
294                Class<? extends Object> ca = obj.getClass().getComponentType();
295                if (ca != null) {
296                        final int l = Array.getLength(obj);
297                        updateShape(ldims, depth, l);
298                        if (DTypeUtils.isClassSupportedAsElement(ca)) {
299                                return true;
300                        }
301                        for (int i = 0; i < l; i++) {
302                                Object lo = Array.get(obj, i);
303                                if (!getShapeFromObj(ldims, lo, depth + 1)) {
304                                        break;
305                                }
306                        }
307                        return true;
308                } else if (obj instanceof IDataset) {
309                        int[] s = ((IDataset) obj).getShape();
310                        for (int i = 0; i < s.length; i++) {
311                                updateShape(ldims, depth++, s[i]);
312                        }
313                        return true;
314                } else {
315                        return false; // not an array of any type
316                }
317        }
318
319        private static void updateShape(final ArrayList<Integer> ldims, final int depth, final int l) {
320                if (depth >= ldims.size()) {
321                        ldims.add(l);
322                } else if (l > ldims.get(depth)) {
323                        ldims.set(depth, l);
324                }
325        }
326
327        /**
328         * Get n-D position from given index
329         * @param n index
330         * @param shape
331         * @return n-D position
332         */
333        public static int[] getNDPositionFromShape(int n, int[] shape) {
334                if (shape == null) {
335                        return null;
336                }
337
338                int rank = shape.length;
339                if (rank == 0) {
340                        return new int[0];
341                }
342
343                if (rank == 1) {
344                        return new int[] { n };
345                }
346
347                int[] output = new int[rank];
348                for (rank--; rank > 0; rank--) {
349                        output[rank] = n % shape[rank];
350                        n /= shape[rank];
351                }
352                output[0] = n;
353        
354                return output;
355        }
356
357        /**
358         * Get flattened view index of given position 
359         * @param shape
360         * @param pos
361         *            the integer array specifying the n-D position
362         * @return the index on the flattened dataset
363         */
364        public static int getFlat1DIndex(final int[] shape, final int[] pos) {
365                final int imax = pos.length;
366                if (imax == 0) {
367                        return 0;
368                }
369        
370                return AbstractDataset.get1DIndexFromShape(shape, pos);
371        }
372
373        /**
374         * This function takes a dataset and checks its shape against another dataset. If they are both of the same size,
375         * then this returns with no error, if there is a problem, then an error is thrown.
376         * 
377         * @param g
378         *            The first dataset to be compared
379         * @param h
380         *            The second dataset to be compared
381         * @throws IllegalArgumentException
382         *             This will be thrown if there is a problem with the compatibility
383         */
384        public static void checkCompatibility(final ILazyDataset g, final ILazyDataset h) throws IllegalArgumentException {
385                if (!areShapesCompatible(g.getShape(), h.getShape())) {
386                        throw new IllegalArgumentException("Shapes do not match");
387                }
388        }
389
390        
391}