Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mutable union and difference operations on sets #415

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
46 changes: 46 additions & 0 deletions docs/new_sets_doc.md
Expand Up @@ -246,6 +246,29 @@ Returns the union of several sets.
The set union of all sets in `*args`.


<a id="sets.mutable_union"></a>

## sets.mutable_union

<pre>
sets.mutable_union(<a href="#sets.mutable_union-a">a</a>, <a href="#sets.mutable_union-b">b</a>)
</pre>

Modify set `a` adding elements from `b` to it.

**PARAMETERS**


| Name | Description | Default Value |
| :------------- | :------------- | :------------- |
| <a id="sets.mutable_union-a"></a>a | A set, as returned by <code>sets.make()</code>. | none |
| <a id="sets.mutable_union-b"></a>b | A set, as returned by <code>sets.make()</code>. | none |

**RETURNS**

The set `a` with all elements appearing in `b` added to it.


<a id="sets.difference"></a>

## sets.difference
Expand All @@ -269,6 +292,29 @@ Returns the elements in `a` that are not in `b`.
A set containing the elements that are in `a` but not in `b`.


<a id="sets.mutable_difference"></a>

## sets.mutable_difference

<pre>
sets.mutable_difference(<a href="#sets.mutable_difference-a">a</a>, <a href="#sets.mutable_difference-b">b</a>)
</pre>

Modify set `a` removing elements from `b` from it.

**PARAMETERS**


| Name | Description | Default Value |
| :------------- | :------------- | :------------- |
| <a id="sets.mutable_difference-a"></a>a | A set, as returned by <code>sets.make()</code>. | none |
| <a id="sets.mutable_difference-b"></a>b | A set, as returned by <code>sets.make()</code>. | none |

**RETURNS**

The set `a` with all elements appearing in `b` removed from it.


<a id="sets.length"></a>

## sets.length
Expand Down
30 changes: 30 additions & 0 deletions lib/new_sets.bzl
Expand Up @@ -189,6 +189,19 @@ def _union(*args):
"""
return struct(_values = dicts.add(*[s._values for s in args]))

def _mutable_union(a, b):
"""Modify set `a` adding elements from `b` to it.

Args:
a: A set, as returned by `sets.make()`.
b: A set, as returned by `sets.make()`.

Returns:
The set `a` with all elements appearing in `b` added to it.
"""
a._values.update(b._values)
return a

def _difference(a, b):
"""Returns the elements in `a` that are not in `b`.

Expand All @@ -201,6 +214,21 @@ def _difference(a, b):
"""
return struct(_values = {e: None for e in a._values.keys() if e not in b._values})

def _mutable_difference(a, b):
"""Modify set `a` removing elements from `b` from it.

Args:
a: A set, as returned by `sets.make()`.
b: A set, as returned by `sets.make()`.

Returns:
The set `a` with all elements appearing in `b` removed from it.
"""
for item in b._values.keys():
if item in a._values:
a._values.pop(item)
return a

def _length(s):
"""Returns the number of elements in a set.

Expand Down Expand Up @@ -234,7 +262,9 @@ sets = struct(
disjoint = _disjoint,
intersection = _intersection,
union = _union,
mutable_union = _mutable_union,
difference = _difference,
mutable_difference = _mutable_difference,
length = _length,
remove = _remove,
repr = _repr,
Expand Down
66 changes: 66 additions & 0 deletions tests/new_sets_tests.bzl
Expand Up @@ -114,6 +114,38 @@ def _union_test(ctx):

union_test = unittest.make(_union_test)

def _mutable_union_test(ctx):
"""Unit tests for sets.union."""
env = unittest.begin(ctx)

s = sets.make()
s = sets.mutable_union(s, sets.make())
asserts.new_set_equals(env, sets.make(), s)
s = sets.make()
s = sets.mutable_union(s, sets.make([1]))
asserts.new_set_equals(env, sets.make([1]), s)
s = sets.make([1])
s = sets.mutable_union(s, sets.make())
asserts.new_set_equals(env, sets.make([1]), s)
s = sets.make([1])
s = sets.mutable_union(s, sets.make([1]))
asserts.new_set_equals(env, sets.make([1]), s)
s = sets.make([1])
s = sets.mutable_union(s, sets.make([1, 2]))
asserts.new_set_equals(env, sets.make([1, 2]), s)
s = sets.make([1])
s = sets.mutable_union(s, sets.make([2]))
asserts.new_set_equals(env, sets.make([1, 2]), s)

# If passing a list, verify that duplicate elements are ignored.
s = sets.make([1, 1])
s = sets.mutable_union(s, sets.make([1, 2]))
asserts.new_set_equals(env, sets.make([1, 2]), s)

return unittest.end(env)

mutable_union_test = unittest.make(_mutable_union_test)

def _difference_test(ctx):
"""Unit tests for sets.difference."""
env = unittest.begin(ctx)
Expand All @@ -132,6 +164,38 @@ def _difference_test(ctx):

difference_test = unittest.make(_difference_test)

def _mutable_difference_test(ctx):
"""Unit tests for sets.difference."""
env = unittest.begin(ctx)

s = sets.make()
s = sets.mutable_difference(s, sets.make())
asserts.new_set_equals(env, sets.make(), s)
s = sets.make()
s = sets.mutable_difference(s, sets.make([1]))
asserts.new_set_equals(env, sets.make(), s)
s = sets.make([1])
s = sets.mutable_difference(s, sets.make())
asserts.new_set_equals(env, sets.make([1]), s)
s = sets.make([1])
s = sets.mutable_difference(s, sets.make([1]))
asserts.new_set_equals(env, sets.make(), s)
s = sets.make([1])
s = sets.mutable_difference(s, sets.make([1, 2]))
asserts.new_set_equals(env, sets.make(), s)
s = sets.make([1])
s = sets.mutable_difference(s, sets.make([2]))
asserts.new_set_equals(env, sets.make([1]), s)

# If passing a list, verify that duplicate elements are ignored.
s = sets.make([1, 2])
s = sets.mutable_difference(s, sets.make([1, 1]))
asserts.new_set_equals(env, sets.make([2]), s)

return unittest.end(env)

mutable_difference_test = unittest.make(_mutable_difference_test)

def _to_list_test(ctx):
"""Unit tests for sets.to_list."""
env = unittest.begin(ctx)
Expand Down Expand Up @@ -257,7 +321,9 @@ def new_sets_test_suite():
is_equal_test,
is_subset_test,
difference_test,
mutable_difference_test,
union_test,
mutable_union_test,
to_list_test,
make_test,
copy_test,
Expand Down