Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions varopt.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package varopt
import (
"container/heap"
"fmt"
"math"
"math/rand"
)

Expand Down Expand Up @@ -48,6 +49,8 @@ type vsample struct {

type largeHeap []vsample

var ErrInvalidWeight = fmt.Errorf("Negative, zero, or NaN weight")

// New returns a new Varopt sampler with given capacity (i.e.,
// reservoir size) and random number generator.
func New(capacity int, rnd *rand.Rand) *Varopt {
Expand All @@ -58,22 +61,24 @@ func New(capacity int, rnd *rand.Rand) *Varopt {
}

// Add considers a new observation for the sample with given weight.
func (s *Varopt) Add(sample Sample, weight float64) {
//
// An error will be returned if the weight is either negative or NaN.
func (s *Varopt) Add(sample Sample, weight float64) error {
individual := vsample{
sample: sample,
weight: weight,
}

if weight <= 0 {
panic(fmt.Sprint("Invalid weight <= 0: ", weight))
if weight <= 0 || math.IsNaN(weight) {
return ErrInvalidWeight
}

s.totalCount++
s.totalWeight += weight

if s.Size() < s.capacity {
heap.Push(&s.L, individual)
return
return nil
}

// the X <- {} step from the paper is not done here,
Expand Down Expand Up @@ -115,6 +120,7 @@ func (s *Varopt) Add(sample Sample, weight float64) {
}
s.T = append(s.T, s.X...)
s.X = s.X[:0]
return nil
}

func (s *Varopt) uniform() float64 {
Expand Down
14 changes: 14 additions & 0 deletions varopt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,17 @@ func testUnbiased(t *testing.T, bbr, bsr float64) {
[][][]varopt.Sample{smallBlocks, bigBlocks},
)
}

func TestInvalidWeight(t *testing.T) {
rnd := rand.New(rand.NewSource(98887))
v := varopt.New(1, rnd)

err := v.Add(nil, math.NaN())
require.Equal(t, err, varopt.ErrInvalidWeight)

err = v.Add(nil, -1)
require.Equal(t, err, varopt.ErrInvalidWeight)

err = v.Add(nil, 0)
require.Equal(t, err, varopt.ErrInvalidWeight)
}