diff --git a/datafu-pig/src/main/java/datafu/pig/sampling/UniformRandomSample.java b/datafu-pig/src/main/java/datafu/pig/sampling/UniformRandomSample.java new file mode 100644 index 00000000..d4a0dbd3 --- /dev/null +++ b/datafu-pig/src/main/java/datafu/pig/sampling/UniformRandomSample.java @@ -0,0 +1,409 @@ +/* + * Copyright 2018 IICOLL + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software + * and associated documentation files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, copy, modify, merge, publish, + * distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or + * substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING + * BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, + * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +package datafu.pig.sampling; + +import java.io.IOException; +import java.util.Comparator; +import java.util.Iterator; +import java.util.SortedSet; +import java.util.TreeSet; + +import org.apache.commons.math.random.RandomDataImpl; +import org.apache.pig.AlgebraicEvalFunc; +import org.apache.pig.EvalFunc; +import org.apache.pig.data.BagFactory; +import org.apache.pig.data.DataBag; +import org.apache.pig.data.DataType; +import org.apache.pig.data.Tuple; +import org.apache.pig.data.TupleFactory; +import org.apache.pig.impl.logicalLayer.FrontendException; +import org.apache.pig.impl.logicalLayer.schema.Schema; + +/** + * Scalable uniform random sampling. + * + *

+ * This UDF implements a uniform random sampling algorithm + *

+ * + *

+ * It takes a bag of n items and either a fraction (p) or an exact number (k) of + * the items to be selected as the input, and returns a bag of k or ceil(p*n) + * items uniformly sampled. + *

+ * + *
+ * DEFINE URS datafu.pig.sampling.UniformRandomSample('p | k & n');
+ *
+ * item    = LOAD 'input' AS (x:);
+ * sampled = FOREACH (GROUP items ALL) GENERATE FLATTEN(URS(items));
+ * 
+ * + */ + +public class UniformRandomSample extends AlgebraicEvalFunc { + /** + * Prefix for the output bag name. + */ + public static final String OUTPUT_BAG_NAME_PREFIX = "URS"; + + private static final TupleFactory _TUPLE_FACTORY = TupleFactory.getInstance(); + private static final BagFactory _BAG_FACTORY = BagFactory.getInstance(); + + public UniformRandomSample(String ps) { + super(ps); + // all parameters calculations should be done in nested classes constructors + // to allow multiple instantiations through calls like: + // DEFINE URS1 datafu.pig.sampling.UniformRandomSample('$p'); + // DEFINE URS2 datafu.pig.sampling.UniformRandomSample('$k, $n'); + // data = LOAD 'input' AS (A_id:chararray, B_id:chararray, C:int); + // sampled = FOREACH (GROUP data ALL) GENERATE URS1(data) as sample_1, URS2(data) AS sample_2; + } + + @Override + public String getInitial() { + return Initial.class.getName(); + } + + @Override + public String getIntermed() { + return Intermed.class.getName(); + } + + @Override + public String getFinal() { + return Final.class.getName(); + } + + @Override + public Schema outputSchema(Schema input) { + try { + Schema.FieldSchema inputFieldSchema = input.getField(0); + + if (inputFieldSchema.type != DataType.BAG) { + throw new RuntimeException("Expected a BAG as input"); + } + return new Schema(new Schema.FieldSchema(super.getSchemaName(OUTPUT_BAG_NAME_PREFIX, input), + inputFieldSchema.schema, DataType.BAG)); + } catch (FrontendException e) { + throw new RuntimeException(e); + } + } + + protected static Tuple getPNK(String pi){ + Tuple pnk = _TUPLE_FACTORY.newTuple(); + long n, k; + double p; + String[] ss = pi.split(","); + if (pi.startsWith("0.") || pi.startsWith(".")){ + try { + p = Double.parseDouble(ss[0].trim()); + pnk.append(p); + } catch (NumberFormatException e) { + throw new RuntimeException("p should be a number, got NumberFormatException:"+pi); + } + } else { + if (ss.length == 2){ + try { + k = Long.parseLong(ss[0].trim(), 10); + n = Long.parseLong(ss[1].trim(), 10); + p = (double) k / n; + pnk.append(p); + pnk.append(k); + } catch (NumberFormatException e) { + throw new RuntimeException("k and n should be numbers, got NumberFormatException:"+pi); + } + } else { + throw new RuntimeException("2 parameters are required k and n, got:"+pi); + } + } + return pnk; + } + + /** + * 1st mapped data processing step, can't be skipped + * + */ + static public class Initial extends EvalFunc { + + public Initial() {} + + private Tuple pnk; + + public Initial(String pi){ + pnk = getPNK(pi); + } + + private static RandomDataImpl _RNG = new RandomDataImpl(); + synchronized private static int nextInt(int n) { + return _RNG.nextInt(0, n); + } + + @Override + public Tuple exec(Tuple input) throws IOException { + DataBag items = (DataBag) input.get(0); + + DataBag selected = _BAG_FACTORY.newDefaultBag(); + Tuple extra = _TUPLE_FACTORY.newTuple(); + Tuple tu; + Tuple output = _TUPLE_FACTORY.newTuple(); + Double p = (Double) pnk.get(0); + + if (items.size() == 0 || p == 0d) { + return _TUPLE_FACTORY.newTuple(); + } else if (items.size() == 1) { + tu = items.iterator().next(); + extra = tu; + output.append(1); + output.append(selected); + output.append(extra); + + return output; + } + + // the set should not exceed int, if initial set is bigger than max_int, + // split into sub-sets + if (items.size() > Integer.MAX_VALUE){ + throw new IndexOutOfBoundsException("bag size is above int maximum"); + } + int numItems = (int) items.size(); + + int k_down = (int) Math.floor(p * numItems); + int k_up = (int) Math.ceil(p * numItems); + if (k_up == 0) return _TUPLE_FACTORY.newTuple(); + + int x; + numItems--; + SortedSet nums = new TreeSet(); + int kk = k_up; + + // if we need to return more than a half of input elements + // insteaed of addition it make sense to make exclusion + // I mean + // p <= 0.5 + // add selected randomly elements to output set + // p > 0.5 + // add all elements except randomly selected + // for exclusion if we need an extra element + // take the 1st, since at the end eventually no + // elements to take from have left + + if (p <= 0.5){ + k_down = numItems - k_down + 1; + k_up = numItems - k_up + 1; + kk = k_down; + } + while (nums.size() < kk){ + x = nextInt(numItems); + nums.add(x); + } + + int i=0; + int j; + Iterator it = nums.iterator(); + Iterator it2 = items.iterator(); + tu = it2.next(); + int ii = it.next(); + + if (p > 0.5){ + while(i == ii){ + i++; + ii = it.next(); + tu = it2.next(); + } + // add the 1st valid element + if (k_down == k_up){ + selected.add(tu); + } else { + extra=tu; + } + tu = it2.next(); + i++; + } + + while (it.hasNext()){ + if (p <= 0.5){ + for ( j=i; j { + + public Intermed(){} + + private Tuple pnk; + + public Intermed(String pi){ + pnk = getPNK(pi); + } + + @Override + public Tuple exec(Tuple input) throws IOException { + DataBag bag = (DataBag) input.get(0); + DataBag selected = _BAG_FACTORY.newDefaultBag(); + DataBag in_extra = _BAG_FACTORY.newDefaultBag(); + Tuple out_extra = _TUPLE_FACTORY.newTuple(); + long numItems = 0L; + long gotItems = 0L; + long required; + + for (Tuple tuple : bag){ + numItems += ((Number) tuple.get(0)).longValue(); + gotItems += ((Number)((DataBag) tuple.get(1)).size()).longValue(); + selected.addAll((DataBag) tuple.get(1)); + in_extra.add((Tuple) tuple.get(2)); + } + + Double p = (Double) pnk.get(0); + long required_up = (long) Math.ceil(p * numItems); + long required_down = (long) Math.floor(p * numItems); + + if (in_extra.size() > 0) { + Iterator it = in_extra.iterator(); + Tuple tu = it.next(); + if (in_extra.size() == 1 && gotItems < required_down) { + selected.add(tu); + } else { + while (gotItems < required_down && it.hasNext()){ + selected.add(tu); + gotItems++; + tu = it.next(); + } + if (tu != null && required_down < required_up) { out_extra = tu; } + } + } + + Tuple output = _TUPLE_FACTORY.newTuple(); + output.append(numItems); + output.append(selected); + output.append(out_extra); + return output; + } + } + + + /** + * this final should be executed as reducer + * merges all selected bags into the output + * adding extra in case more elements needed + */ + static public class Final extends EvalFunc { + + public Final(){} + + private Tuple pnk; + + public Final(String pi){ + pnk = getPNK(pi); + } + + @Override + public DataBag exec(Tuple input) throws IOException { + DataBag bag = (DataBag) input.get(0); + + if (bag.size() == 0) { return _BAG_FACTORY.newDefaultBag(); } + + long n_total = 0L; + + DataBag selected = _BAG_FACTORY.newDefaultBag(); + DataBag extra = _BAG_FACTORY.newDefaultBag(); + + Iterator it = bag.iterator(); + Tuple tuple = it.next(); + while(it.hasNext()) { + n_total += ((Number) tuple.get(0)).longValue(); + selected.addAll((DataBag) tuple.get(1)); + extra.add((Tuple) tuple.get(2)); + tuple = it.next(); + } + n_total += ((Number) tuple.get(0)).longValue(); + selected.addAll((DataBag) tuple.get(1)); + extra.add((Tuple) tuple.get(2)); + + long s; // final requested sample size + if (pnk.size() > 1){ + s = ((Number) pnk.get(1)).longValue(); + } else { + Double p = (Double) pnk.get(0); + s = (long) Math.ceil(p * n_total); + } + + it = extra.iterator(); + while(it.hasNext() && selected.size() < s ) { + selected.add(it.next()); + } + + return selected; + } + } + +} diff --git a/datafu-pig/src/test/java/datafu/test/pig/sampling/UniformRandomSampleTest.java b/datafu-pig/src/test/java/datafu/test/pig/sampling/UniformRandomSampleTest.java new file mode 100644 index 00000000..51e82b05 --- /dev/null +++ b/datafu-pig/src/test/java/datafu/test/pig/sampling/UniformRandomSampleTest.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package datafu.test.pig.sampling; + +import org.adrianwalker.multilinestring.Multiline; +import org.apache.pig.pigunit.PigTest; +import org.testng.annotations.Test; + +import datafu.pig.sampling.UniformRandomSample; +import datafu.test.pig.PigTests; + +/** + * Tests for {@link UniformRandomSample}. + * + */ +public class UniformRandomSampleTest extends PigTests { + /** + * + * + * DEFINE URS datafu.pig.sampling.UniformRandomSample('$p'); + * + * data = LOAD 'input' AS (A_id:chararray, B_id:chararray, C:int); + * + * sampled = FOREACH (GROUP data ALL) GENERATE URS(data) as sample_data; + * + * sampled = FOREACH sampled GENERATE COUNT(sample_data) AS sample_count; + * + * STORE sampled INTO 'output'; + */ + @Multiline + private String uniformRandomSampleFractionTest; + + /** + * + * + * DEFINE URS datafu.pig.sampling.UniformRandomSample('$k, $n'); + * + * data = LOAD 'input' AS (A_id:chararray, B_id:chararray, C:int); + * + * sampled = FOREACH (GROUP data ALL) GENERATE URS(data) as sample_data; + * + * sampled = FOREACH sampled GENERATE COUNT(sample_data) AS sample_count; + * + * STORE sampled INTO 'output'; + */ + @Multiline + private String uniformRandomSampleIntTest; + + /** + * DEFINE URS1 datafu.pig.sampling.UniformRandomSample('$p'); + * + * DEFINE URS2 datafu.pig.sampling.UniformRandomSample('$k, $n'); + * + * data = LOAD 'input' AS (A_id:chararray, B_id:chararray, C:int); + * + * sampled = FOREACH (GROUP data ALL) GENERATE URS1(data) as sample_1, URS2(data) + * AS sample_2; + * + * sampled = FOREACH sampled GENERATE COUNT(sample_1) AS sample_count_1, COUNT(sample_2) + * AS sample_count_2; + * + * STORE sampled INTO 'output'; + */ + @Multiline + private String uniformRandomSampleWithTwoCallsTest; + + @Test + public void testUniformRandomSample() throws Exception { + writeLinesToFile("input", + "A1\tB1\t1", + "A2\tB1\t4", + "A3\tB3\t4", + "A4\tB4\t4", + "A5\tB1\t4", + "A6\tB2\t4", + "A7\tB1\t3", + "A8\tB1\t1", + "A9\tB3\t77", + "A10\tB1\t3", + "A11\tB2\t3", + "A12\tB3\t59", + "A13\tB4\t29", + "A14\tB1\t4", + "A15\tB2\t3", + "A16\tB2\t55", + "A17\tB3\t1", + "A18\tB1\t39", + "A19\tB2\t27", + "A20\tB3\t85", + "A21\tB1\t4", + "A22\tB2\t45", + "A23\tB3\t92", + "A24\tB3\t0", + "A25\tB6\t42", + "A26\tB5\t1", + "A27\tB1\t7", + "A28\tB2\t23", + "A29\tB2\t1", + "A30\tB2\t31", + "A31\tB6\t41", + "A32\tB7\t52"); + + int n = 32; + double p = 0.3; + int s = (int) Math.ceil(p * n); + PigTest test = createPigTestFromString(uniformRandomSampleFractionTest, "p=" + p); + test.runScript(); + assertOutput(test, "sampled", "(" + s + ")"); + + int k = 10; + PigTest test2 = + createPigTestFromString(uniformRandomSampleIntTest, "k=" + k, "n=" + n); + test2.runScript(); + assertOutput(test2, "sampled", "(" + k + ")"); + + p = 0.05; + k = 15; + s = (int) Math.ceil(p * n); + + PigTest testWithTwoCalls = + createPigTestFromString(uniformRandomSampleWithTwoCallsTest, "p=" + p, "k=" + k, "n=" + n); + testWithTwoCalls.runScript(); + assertOutput(testWithTwoCalls, "sampled", "(" + s + "," + k + ")"); + } + + /** + * DEFINE URS datafu.pig.sampling.UniformRandomSample('$SAMPLING_FRACTION'); + * + * data = LOAD 'input' AS (A_id:chararray, B_id:chararray, C:int); + * + * sampled = FOREACH (GROUP data BY A_id) GENERATE group, URS(data) as sample_data; + * + * sampled = FOREACH sampled GENERATE group, COUNT(sample_data) AS sample_count; + * + * sampled = ORDER sampled BY group; + * + * STORE sampled INTO 'output'; + */ + @Multiline + private String stratifiedSampleTest; + + @Test + public void testStratifiedSample() throws Exception { + writeLinesToFile("input", + "A1\tB1\t1", + "A1\tB1\t4", + "A1\tB3\t4", + "A1\tB4\t4", + "A2\tB1\t4", + "A2\tB2\t4", + "A3\tB1\t3", + "A3\tB1\t1", + "A3\tB3\t77", + "A4\tB1\t3", + "A4\tB2\t3", + "A4\tB3\t59", + "A4\tB4\t29", + "A5\tB1\t4", + "A6\tB2\t3", + "A6\tB2\t55", + "A6\tB3\t1", + "A7\tB1\t39", + "A7\tB2\t27", + "A7\tB3\t85", + "A8\tB1\t4", + "A8\tB2\t45", + "A9\tB3\t92", + "A9\tB3\t0", + "A9\tB6\t42", + "A9\tB5\t1", + "A10\tB1\t7", + "A10\tB2\t23", + "A10\tB2\t1", + "A10\tB2\t31", + "A10\tB6\t41", + "A10\tB7\t52"); + + double p = 0.5; + + PigTest test = + createPigTestFromString(stratifiedSampleTest, "SAMPLING_FRACTION=" + p); + test.runScript(); + assertOutput(test, + "sampled", + "(A1,2)", + "(A10,3)", + "(A2,1)", + "(A3,2)", + "(A4,2)", + "(A5,1)", + "(A6,2)", + "(A7,2)", + "(A8,1)", + "(A9,2)"); + } + +}