diff --git a/datafu-pig/src/main/java/datafu/pig/sampling/UniformRandomSample.java b/datafu-pig/src/main/java/datafu/pig/sampling/UniformRandomSample.java
new file mode 100644
index 00000000..d4a0dbd3
--- /dev/null
+++ b/datafu-pig/src/main/java/datafu/pig/sampling/UniformRandomSample.java
@@ -0,0 +1,409 @@
+/*
+ * Copyright 2018 IICOLL
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of this software
+ * and associated documentation files (the "Software"), to deal in the Software without
+ * restriction, including without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
+ * Software is furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all copies or
+ * substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
+ * BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+ * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+ * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ */
+
+package datafu.pig.sampling;
+
+import java.io.IOException;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.SortedSet;
+import java.util.TreeSet;
+
+import org.apache.commons.math.random.RandomDataImpl;
+import org.apache.pig.AlgebraicEvalFunc;
+import org.apache.pig.EvalFunc;
+import org.apache.pig.data.BagFactory;
+import org.apache.pig.data.DataBag;
+import org.apache.pig.data.DataType;
+import org.apache.pig.data.Tuple;
+import org.apache.pig.data.TupleFactory;
+import org.apache.pig.impl.logicalLayer.FrontendException;
+import org.apache.pig.impl.logicalLayer.schema.Schema;
+
+/**
+ * Scalable uniform random sampling.
+ *
+ *
+ * This UDF implements a uniform random sampling algorithm
+ *
+ *
+ *
+ * It takes a bag of n items and either a fraction (p) or an exact number (k) of
+ * the items to be selected as the input, and returns a bag of k or ceil(p*n)
+ * items uniformly sampled.
+ *
+ *
+ *
+ * DEFINE URS datafu.pig.sampling.UniformRandomSample('p | k & n');
+ *
+ * item = LOAD 'input' AS (x:);
+ * sampled = FOREACH (GROUP items ALL) GENERATE FLATTEN(URS(items));
+ *
+ *
+ */
+
+public class UniformRandomSample extends AlgebraicEvalFunc {
+ /**
+ * Prefix for the output bag name.
+ */
+ public static final String OUTPUT_BAG_NAME_PREFIX = "URS";
+
+ private static final TupleFactory _TUPLE_FACTORY = TupleFactory.getInstance();
+ private static final BagFactory _BAG_FACTORY = BagFactory.getInstance();
+
+ public UniformRandomSample(String ps) {
+ super(ps);
+ // all parameters calculations should be done in nested classes constructors
+ // to allow multiple instantiations through calls like:
+ // DEFINE URS1 datafu.pig.sampling.UniformRandomSample('$p');
+ // DEFINE URS2 datafu.pig.sampling.UniformRandomSample('$k, $n');
+ // data = LOAD 'input' AS (A_id:chararray, B_id:chararray, C:int);
+ // sampled = FOREACH (GROUP data ALL) GENERATE URS1(data) as sample_1, URS2(data) AS sample_2;
+ }
+
+ @Override
+ public String getInitial() {
+ return Initial.class.getName();
+ }
+
+ @Override
+ public String getIntermed() {
+ return Intermed.class.getName();
+ }
+
+ @Override
+ public String getFinal() {
+ return Final.class.getName();
+ }
+
+ @Override
+ public Schema outputSchema(Schema input) {
+ try {
+ Schema.FieldSchema inputFieldSchema = input.getField(0);
+
+ if (inputFieldSchema.type != DataType.BAG) {
+ throw new RuntimeException("Expected a BAG as input");
+ }
+ return new Schema(new Schema.FieldSchema(super.getSchemaName(OUTPUT_BAG_NAME_PREFIX, input),
+ inputFieldSchema.schema, DataType.BAG));
+ } catch (FrontendException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ protected static Tuple getPNK(String pi){
+ Tuple pnk = _TUPLE_FACTORY.newTuple();
+ long n, k;
+ double p;
+ String[] ss = pi.split(",");
+ if (pi.startsWith("0.") || pi.startsWith(".")){
+ try {
+ p = Double.parseDouble(ss[0].trim());
+ pnk.append(p);
+ } catch (NumberFormatException e) {
+ throw new RuntimeException("p should be a number, got NumberFormatException:"+pi);
+ }
+ } else {
+ if (ss.length == 2){
+ try {
+ k = Long.parseLong(ss[0].trim(), 10);
+ n = Long.parseLong(ss[1].trim(), 10);
+ p = (double) k / n;
+ pnk.append(p);
+ pnk.append(k);
+ } catch (NumberFormatException e) {
+ throw new RuntimeException("k and n should be numbers, got NumberFormatException:"+pi);
+ }
+ } else {
+ throw new RuntimeException("2 parameters are required k and n, got:"+pi);
+ }
+ }
+ return pnk;
+ }
+
+ /**
+ * 1st mapped data processing step, can't be skipped
+ *
+ */
+ static public class Initial extends EvalFunc {
+
+ public Initial() {}
+
+ private Tuple pnk;
+
+ public Initial(String pi){
+ pnk = getPNK(pi);
+ }
+
+ private static RandomDataImpl _RNG = new RandomDataImpl();
+ synchronized private static int nextInt(int n) {
+ return _RNG.nextInt(0, n);
+ }
+
+ @Override
+ public Tuple exec(Tuple input) throws IOException {
+ DataBag items = (DataBag) input.get(0);
+
+ DataBag selected = _BAG_FACTORY.newDefaultBag();
+ Tuple extra = _TUPLE_FACTORY.newTuple();
+ Tuple tu;
+ Tuple output = _TUPLE_FACTORY.newTuple();
+ Double p = (Double) pnk.get(0);
+
+ if (items.size() == 0 || p == 0d) {
+ return _TUPLE_FACTORY.newTuple();
+ } else if (items.size() == 1) {
+ tu = items.iterator().next();
+ extra = tu;
+ output.append(1);
+ output.append(selected);
+ output.append(extra);
+
+ return output;
+ }
+
+ // the set should not exceed int, if initial set is bigger than max_int,
+ // split into sub-sets
+ if (items.size() > Integer.MAX_VALUE){
+ throw new IndexOutOfBoundsException("bag size is above int maximum");
+ }
+ int numItems = (int) items.size();
+
+ int k_down = (int) Math.floor(p * numItems);
+ int k_up = (int) Math.ceil(p * numItems);
+ if (k_up == 0) return _TUPLE_FACTORY.newTuple();
+
+ int x;
+ numItems--;
+ SortedSet nums = new TreeSet();
+ int kk = k_up;
+
+ // if we need to return more than a half of input elements
+ // insteaed of addition it make sense to make exclusion
+ // I mean
+ // p <= 0.5
+ // add selected randomly elements to output set
+ // p > 0.5
+ // add all elements except randomly selected
+ // for exclusion if we need an extra element
+ // take the 1st, since at the end eventually no
+ // elements to take from have left
+
+ if (p <= 0.5){
+ k_down = numItems - k_down + 1;
+ k_up = numItems - k_up + 1;
+ kk = k_down;
+ }
+ while (nums.size() < kk){
+ x = nextInt(numItems);
+ nums.add(x);
+ }
+
+ int i=0;
+ int j;
+ Iterator it = nums.iterator();
+ Iterator it2 = items.iterator();
+ tu = it2.next();
+ int ii = it.next();
+
+ if (p > 0.5){
+ while(i == ii){
+ i++;
+ ii = it.next();
+ tu = it2.next();
+ }
+ // add the 1st valid element
+ if (k_down == k_up){
+ selected.add(tu);
+ } else {
+ extra=tu;
+ }
+ tu = it2.next();
+ i++;
+ }
+
+ while (it.hasNext()){
+ if (p <= 0.5){
+ for ( j=i; j {
+
+ public Intermed(){}
+
+ private Tuple pnk;
+
+ public Intermed(String pi){
+ pnk = getPNK(pi);
+ }
+
+ @Override
+ public Tuple exec(Tuple input) throws IOException {
+ DataBag bag = (DataBag) input.get(0);
+ DataBag selected = _BAG_FACTORY.newDefaultBag();
+ DataBag in_extra = _BAG_FACTORY.newDefaultBag();
+ Tuple out_extra = _TUPLE_FACTORY.newTuple();
+ long numItems = 0L;
+ long gotItems = 0L;
+ long required;
+
+ for (Tuple tuple : bag){
+ numItems += ((Number) tuple.get(0)).longValue();
+ gotItems += ((Number)((DataBag) tuple.get(1)).size()).longValue();
+ selected.addAll((DataBag) tuple.get(1));
+ in_extra.add((Tuple) tuple.get(2));
+ }
+
+ Double p = (Double) pnk.get(0);
+ long required_up = (long) Math.ceil(p * numItems);
+ long required_down = (long) Math.floor(p * numItems);
+
+ if (in_extra.size() > 0) {
+ Iterator it = in_extra.iterator();
+ Tuple tu = it.next();
+ if (in_extra.size() == 1 && gotItems < required_down) {
+ selected.add(tu);
+ } else {
+ while (gotItems < required_down && it.hasNext()){
+ selected.add(tu);
+ gotItems++;
+ tu = it.next();
+ }
+ if (tu != null && required_down < required_up) { out_extra = tu; }
+ }
+ }
+
+ Tuple output = _TUPLE_FACTORY.newTuple();
+ output.append(numItems);
+ output.append(selected);
+ output.append(out_extra);
+ return output;
+ }
+ }
+
+
+ /**
+ * this final should be executed as reducer
+ * merges all selected bags into the output
+ * adding extra in case more elements needed
+ */
+ static public class Final extends EvalFunc {
+
+ public Final(){}
+
+ private Tuple pnk;
+
+ public Final(String pi){
+ pnk = getPNK(pi);
+ }
+
+ @Override
+ public DataBag exec(Tuple input) throws IOException {
+ DataBag bag = (DataBag) input.get(0);
+
+ if (bag.size() == 0) { return _BAG_FACTORY.newDefaultBag(); }
+
+ long n_total = 0L;
+
+ DataBag selected = _BAG_FACTORY.newDefaultBag();
+ DataBag extra = _BAG_FACTORY.newDefaultBag();
+
+ Iterator it = bag.iterator();
+ Tuple tuple = it.next();
+ while(it.hasNext()) {
+ n_total += ((Number) tuple.get(0)).longValue();
+ selected.addAll((DataBag) tuple.get(1));
+ extra.add((Tuple) tuple.get(2));
+ tuple = it.next();
+ }
+ n_total += ((Number) tuple.get(0)).longValue();
+ selected.addAll((DataBag) tuple.get(1));
+ extra.add((Tuple) tuple.get(2));
+
+ long s; // final requested sample size
+ if (pnk.size() > 1){
+ s = ((Number) pnk.get(1)).longValue();
+ } else {
+ Double p = (Double) pnk.get(0);
+ s = (long) Math.ceil(p * n_total);
+ }
+
+ it = extra.iterator();
+ while(it.hasNext() && selected.size() < s ) {
+ selected.add(it.next());
+ }
+
+ return selected;
+ }
+ }
+
+}
diff --git a/datafu-pig/src/test/java/datafu/test/pig/sampling/UniformRandomSampleTest.java b/datafu-pig/src/test/java/datafu/test/pig/sampling/UniformRandomSampleTest.java
new file mode 100644
index 00000000..51e82b05
--- /dev/null
+++ b/datafu-pig/src/test/java/datafu/test/pig/sampling/UniformRandomSampleTest.java
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package datafu.test.pig.sampling;
+
+import org.adrianwalker.multilinestring.Multiline;
+import org.apache.pig.pigunit.PigTest;
+import org.testng.annotations.Test;
+
+import datafu.pig.sampling.UniformRandomSample;
+import datafu.test.pig.PigTests;
+
+/**
+ * Tests for {@link UniformRandomSample}.
+ *
+ */
+public class UniformRandomSampleTest extends PigTests {
+ /**
+ *
+ *
+ * DEFINE URS datafu.pig.sampling.UniformRandomSample('$p');
+ *
+ * data = LOAD 'input' AS (A_id:chararray, B_id:chararray, C:int);
+ *
+ * sampled = FOREACH (GROUP data ALL) GENERATE URS(data) as sample_data;
+ *
+ * sampled = FOREACH sampled GENERATE COUNT(sample_data) AS sample_count;
+ *
+ * STORE sampled INTO 'output';
+ */
+ @Multiline
+ private String uniformRandomSampleFractionTest;
+
+ /**
+ *
+ *
+ * DEFINE URS datafu.pig.sampling.UniformRandomSample('$k, $n');
+ *
+ * data = LOAD 'input' AS (A_id:chararray, B_id:chararray, C:int);
+ *
+ * sampled = FOREACH (GROUP data ALL) GENERATE URS(data) as sample_data;
+ *
+ * sampled = FOREACH sampled GENERATE COUNT(sample_data) AS sample_count;
+ *
+ * STORE sampled INTO 'output';
+ */
+ @Multiline
+ private String uniformRandomSampleIntTest;
+
+ /**
+ * DEFINE URS1 datafu.pig.sampling.UniformRandomSample('$p');
+ *
+ * DEFINE URS2 datafu.pig.sampling.UniformRandomSample('$k, $n');
+ *
+ * data = LOAD 'input' AS (A_id:chararray, B_id:chararray, C:int);
+ *
+ * sampled = FOREACH (GROUP data ALL) GENERATE URS1(data) as sample_1, URS2(data)
+ * AS sample_2;
+ *
+ * sampled = FOREACH sampled GENERATE COUNT(sample_1) AS sample_count_1, COUNT(sample_2)
+ * AS sample_count_2;
+ *
+ * STORE sampled INTO 'output';
+ */
+ @Multiline
+ private String uniformRandomSampleWithTwoCallsTest;
+
+ @Test
+ public void testUniformRandomSample() throws Exception {
+ writeLinesToFile("input",
+ "A1\tB1\t1",
+ "A2\tB1\t4",
+ "A3\tB3\t4",
+ "A4\tB4\t4",
+ "A5\tB1\t4",
+ "A6\tB2\t4",
+ "A7\tB1\t3",
+ "A8\tB1\t1",
+ "A9\tB3\t77",
+ "A10\tB1\t3",
+ "A11\tB2\t3",
+ "A12\tB3\t59",
+ "A13\tB4\t29",
+ "A14\tB1\t4",
+ "A15\tB2\t3",
+ "A16\tB2\t55",
+ "A17\tB3\t1",
+ "A18\tB1\t39",
+ "A19\tB2\t27",
+ "A20\tB3\t85",
+ "A21\tB1\t4",
+ "A22\tB2\t45",
+ "A23\tB3\t92",
+ "A24\tB3\t0",
+ "A25\tB6\t42",
+ "A26\tB5\t1",
+ "A27\tB1\t7",
+ "A28\tB2\t23",
+ "A29\tB2\t1",
+ "A30\tB2\t31",
+ "A31\tB6\t41",
+ "A32\tB7\t52");
+
+ int n = 32;
+ double p = 0.3;
+ int s = (int) Math.ceil(p * n);
+ PigTest test = createPigTestFromString(uniformRandomSampleFractionTest, "p=" + p);
+ test.runScript();
+ assertOutput(test, "sampled", "(" + s + ")");
+
+ int k = 10;
+ PigTest test2 =
+ createPigTestFromString(uniformRandomSampleIntTest, "k=" + k, "n=" + n);
+ test2.runScript();
+ assertOutput(test2, "sampled", "(" + k + ")");
+
+ p = 0.05;
+ k = 15;
+ s = (int) Math.ceil(p * n);
+
+ PigTest testWithTwoCalls =
+ createPigTestFromString(uniformRandomSampleWithTwoCallsTest, "p=" + p, "k=" + k, "n=" + n);
+ testWithTwoCalls.runScript();
+ assertOutput(testWithTwoCalls, "sampled", "(" + s + "," + k + ")");
+ }
+
+ /**
+ * DEFINE URS datafu.pig.sampling.UniformRandomSample('$SAMPLING_FRACTION');
+ *
+ * data = LOAD 'input' AS (A_id:chararray, B_id:chararray, C:int);
+ *
+ * sampled = FOREACH (GROUP data BY A_id) GENERATE group, URS(data) as sample_data;
+ *
+ * sampled = FOREACH sampled GENERATE group, COUNT(sample_data) AS sample_count;
+ *
+ * sampled = ORDER sampled BY group;
+ *
+ * STORE sampled INTO 'output';
+ */
+ @Multiline
+ private String stratifiedSampleTest;
+
+ @Test
+ public void testStratifiedSample() throws Exception {
+ writeLinesToFile("input",
+ "A1\tB1\t1",
+ "A1\tB1\t4",
+ "A1\tB3\t4",
+ "A1\tB4\t4",
+ "A2\tB1\t4",
+ "A2\tB2\t4",
+ "A3\tB1\t3",
+ "A3\tB1\t1",
+ "A3\tB3\t77",
+ "A4\tB1\t3",
+ "A4\tB2\t3",
+ "A4\tB3\t59",
+ "A4\tB4\t29",
+ "A5\tB1\t4",
+ "A6\tB2\t3",
+ "A6\tB2\t55",
+ "A6\tB3\t1",
+ "A7\tB1\t39",
+ "A7\tB2\t27",
+ "A7\tB3\t85",
+ "A8\tB1\t4",
+ "A8\tB2\t45",
+ "A9\tB3\t92",
+ "A9\tB3\t0",
+ "A9\tB6\t42",
+ "A9\tB5\t1",
+ "A10\tB1\t7",
+ "A10\tB2\t23",
+ "A10\tB2\t1",
+ "A10\tB2\t31",
+ "A10\tB6\t41",
+ "A10\tB7\t52");
+
+ double p = 0.5;
+
+ PigTest test =
+ createPigTestFromString(stratifiedSampleTest, "SAMPLING_FRACTION=" + p);
+ test.runScript();
+ assertOutput(test,
+ "sampled",
+ "(A1,2)",
+ "(A10,3)",
+ "(A2,1)",
+ "(A3,2)",
+ "(A4,2)",
+ "(A5,1)",
+ "(A6,2)",
+ "(A7,2)",
+ "(A8,1)",
+ "(A9,2)");
+ }
+
+}