001/*-
002 *******************************************************************************
003 * Copyright (c) 2011, 2016 Diamond Light Source Ltd.
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the Eclipse Public License v1.0
006 * which accompanies this distribution, and is available at
007 * http://www.eclipse.org/legal/epl-v10.html
008 *
009 * Contributors:
010 *    Peter Chang - initial API and implementation and/or initial documentation
011 *******************************************************************************/
012
013package org.eclipse.january.dataset;
014
015import java.util.Arrays;
016import java.util.List;
017
018/**
019 * Class to run over a single dataset with NumPy broadcasting to promote shapes
020 * which have lower rank and outputs to a second dataset
021 */
022public class SingleInputBroadcastIterator extends IndexIterator {
023        private int[] maxShape;
024        private int[] aShape;
025        private final Dataset aDataset;
026        private final Dataset oDataset;
027        private int[] aStride;
028        private int[] oStride;
029
030        final private int endrank;
031
032        /**
033         * position in dataset
034         */
035        private final int[] pos;
036        private final int[] aDelta;
037        private final int[] oDelta; // this being non-null means output is different from inputs
038        private final int aStep, oStep;
039        private int aMax;
040        private int aStart, oStart;
041        private final boolean outputA;
042
043        /**
044         * Index in array
045         */
046        public int aIndex, oIndex;
047
048        /**
049         * Current value in array
050         */
051        public double aDouble;
052
053        /**
054         * Current value in array
055         */
056        public long aLong;
057
058        private boolean asDouble = true;
059
060        /**
061         * @param a
062         * @param o (can be null for new dataset, or a)
063         */
064        public SingleInputBroadcastIterator(Dataset a, Dataset o) {
065                this(a, o, false);
066        }
067
068        /**
069         * @param a
070         * @param o (can be null for new dataset, or a)
071         * @param createIfNull (by default, can create float or complex datasets)
072         */
073        public SingleInputBroadcastIterator(Dataset a, Dataset o, boolean createIfNull) {
074                this(a, o, createIfNull, false, true);
075        }
076
077        /**
078         * @param a
079         * @param o (can be null for new dataset, or a)
080         * @param createIfNull
081         * @param allowInteger if true, can create integer datasets
082         * @param allowComplex if true, can create complex datasets
083         */
084        public SingleInputBroadcastIterator(Dataset a, Dataset o, boolean createIfNull, boolean allowInteger, boolean allowComplex) {
085                List<int[]> fullShapes = BroadcastUtils.broadcastShapes(a.getShapeRef(), o == null ? null : o.getShapeRef());
086
087                BroadcastUtils.checkItemSize(a, o);
088
089                maxShape = fullShapes.remove(0);
090
091                oStride = null;
092                if (o != null) {
093                        if (!Arrays.equals(maxShape, o.getShapeRef())) {
094                                throw new IllegalArgumentException("Output does not match broadcasted shape");
095                        }
096                        o.setDirty();
097                }
098
099                aShape = fullShapes.remove(0);
100
101                int rank = maxShape.length;
102                endrank = rank - 1;
103
104                aDataset = a.reshape(aShape);
105                aStride = BroadcastUtils.createBroadcastStrides(aDataset, maxShape);
106                outputA = o == a;
107                if (outputA) {
108                        oStride = aStride;
109                        oDelta = null;
110                        oStep = 0;
111                        oDataset = aDataset;
112                } else if (o != null) {
113                        oStride = BroadcastUtils.createBroadcastStrides(o, maxShape);
114                        oDelta = new int[rank];
115                        oStep = o.getElementsPerItem();
116                        oDataset = o;
117                } else if (createIfNull) {
118                        int is = aDataset.getElementsPerItem();
119                        Class<? extends Dataset> dc = aDataset.getClass();
120                        if (aDataset.isComplex() && !allowComplex) {
121                                is = 1;
122                                dc = InterfaceUtils.getBestFloatInterface(dc);
123                        } else if (!aDataset.hasFloatingPointElements() && !allowInteger) {
124                                dc = InterfaceUtils.getBestFloatInterface(dc);
125                        }
126                        oDataset = DatasetFactory.zeros(is, dc, maxShape);
127                        oStride = BroadcastUtils.createBroadcastStrides(oDataset, maxShape);
128                        oDelta = new int[rank];
129                        oStep = oDataset.getElementsPerItem();
130                } else {
131                        oDelta = null;
132                        oStep = 0;
133                        oDataset = o;
134                }
135
136                pos = new int[rank];
137                aDelta = new int[rank];
138                aStep = aDataset.getElementsPerItem();
139                for (int j = endrank; j >= 0; j--) {
140                        aDelta[j] = aStride[j] * aShape[j];
141                        if (oDelta != null) {
142                                oDelta[j] = oStride[j] * maxShape[j];
143                        }
144                }
145                aStart = aDataset.getOffset();
146                aMax = endrank < 0 ? aStep + aStart: Integer.MIN_VALUE;
147                oStart = oDelta == null ? 0 : oDataset.getOffset();
148                asDouble = aDataset.hasFloatingPointElements();
149                reset();
150        }
151
152        /**
153         * @return true if output from iterator is double
154         */
155        public boolean isOutputDouble() {
156                return asDouble;
157        }
158
159        /**
160         * Set to output doubles
161         * @param asDouble
162         */
163        public void setOutputDouble(boolean asDouble) {
164                if (this.asDouble != asDouble) {
165                        this.asDouble = asDouble;
166                        storeCurrentValues();
167                }
168        }
169
170        @Override
171        public int[] getShape() {
172                return maxShape;
173        }
174
175        @Override
176        public boolean hasNext() {
177                int j = endrank;
178                int oldA = aIndex;
179                for (; j >= 0; j--) {
180                        pos[j]++;
181                        aIndex += aStride[j];
182                        if (oDelta != null) {
183                                oIndex += oStride[j];
184                        }
185                        if (pos[j] >= maxShape[j]) {
186                                pos[j] = 0;
187                                aIndex -= aDelta[j]; // reset these dimensions
188                                if (oDelta != null) {
189                                        oIndex -= oDelta[j];
190                                }
191                        } else {
192                                break;
193                        }
194                }
195                if (j == -1) {
196                        if (endrank >= 0) {
197                                return false;
198                        }
199                        aIndex += aStep;
200                        if (oDelta != null) {
201                                oIndex += oStep;
202                        }
203                }
204                if (outputA) {
205                        oIndex = aIndex;
206                }
207
208                if (aIndex == aMax) {
209                        return false; // used for zero-rank datasets
210                }
211
212                if (oldA != aIndex) {
213                        if (asDouble) {
214                                aDouble = aDataset.getElementDoubleAbs(aIndex);
215                        } else {
216                                aLong = aDataset.getElementLongAbs(aIndex);
217                        }
218                }
219
220                return true;
221        }
222
223        /**
224         * @return output dataset (can be null)
225         */
226        public Dataset getOutput() {
227                return oDataset;
228        }
229
230        @Override
231        public int[] getPos() {
232                return pos;
233        }
234
235        @Override
236        public void reset() {
237                for (int i = 0; i <= endrank; i++) {
238                        pos[i] = 0;
239                }
240
241                if (endrank >= 0) {
242                        pos[endrank] = -1;
243                        aIndex = aStart - aStride[endrank];
244                        oIndex = oStart - (oStride == null ? 0 : oStride[endrank]);
245                } else {
246                        aIndex = -aStep;
247                        oIndex = -oStep;
248                }
249
250                // for zero-ranked datasets
251                if (aIndex == 0) {
252                        storeCurrentValues();
253                }
254        }
255
256        private void storeCurrentValues() {
257                if (aIndex >= 0) {
258                        if (asDouble) {
259                                aDouble = aDataset.getElementDoubleAbs(aIndex);
260                        } else {
261                                aLong = aDataset.getElementLongAbs(aIndex);
262                        }
263                }
264        }
265}