
Core Reference

graphique.core.ListChunk

Bases: BaseListArray

Source code in graphique/core.py
class ListChunk(pa.lib.BaseListArray):
    def from_counts(counts: pa.IntegerArray, values: pa.Array) -> pa.LargeListArray:
        """Return list array by converting counts into offsets."""
        mask = None
        if counts.null_count:
            mask, counts = counts.is_null(), counts.fill_null(0)
        offsets = pa.concat_arrays([pa.array([0], counts.type), pc.cumulative_sum_checked(counts)])
        cls = pa.LargeListArray if offsets.type == 'int64' else pa.ListArray
        return cls.from_arrays(offsets, values, mask=mask)

    def from_scalars(values: Iterable) -> pa.LargeListArray:
        """Return list array from array scalars."""
        return ListChunk.from_counts(pa.array(map(len, values)), pa.concat_arrays(values))

    def element(self, index: int) -> pa.Array:
        """element at index of each list scalar; defaults to null"""
        with contextlib.suppress(ValueError):
            return pc.list_element(self, index)
        size = -index if index < 0 else index + 1
        if isinstance(self, pa.ChunkedArray):
            self = self.combine_chunks()
        mask = np.asarray(Column.fill_null(pc.list_value_length(self), 0)) < size
        offsets = np.asarray(self.offsets[1:] if index < 0 else self.offsets[:-1])
        return pc.list_flatten(self).take(pa.array(offsets + index, mask=mask))

    def first(self) -> pa.Array:
        """first value of each list scalar"""
        return ListChunk.element(self, 0)

    def last(self) -> pa.Array:
        """last value of each list scalar"""
        return ListChunk.element(self, -1)

    def scalars(self) -> Iterable:
        empty = pa.array([], self.type.value_type)
        return (scalar.values or empty for scalar in self)

    def map_list(self, func: Callable, **kwargs) -> pa.lib.BaseListArray:
        """Return list array by mapping function across scalars, with null handling."""
        values = [func(value, **kwargs) for value in ListChunk.scalars(self)]
        return ListChunk.from_scalars(values)

    def inner_flatten(self) -> pa.lib.BaseListArray:
        """Return flattened inner lists from a nested list array."""
        offsets = self.values.offsets.take(self.offsets)
        return type(self).from_arrays(offsets, self.values.values)

    def aggregate(self, **funcs: Optional[pc.FunctionOptions]) -> pa.RecordBatch:
        """Return aggregated scalars by grouping each hash function on the parent indices.

        If there are empty or null scalars, then the result must be padded with null defaults and
        reordered. If the function is a `count`, then the default is 0.
        """
        columns = {'key': pc.list_parent_indices(self), '': pc.list_flatten(self)}
        items = [('', name, funcs[name]) for name in funcs]
        table = pa.table(columns).group_by(['key']).aggregate(items)
        indices, table = table['key'], table.remove_column(table.schema.get_field_index('key'))
        (batch,) = table.to_batches()
        if len(batch) == len(self):  # no empty or null scalars
            return batch
        mask = pc.equal(pc.list_value_length(self), 0)
        empties = pc.indices_nonzero(Column.fill_null(mask, True))
        indices = pa.chunked_array(indices.chunks + [empties.cast(indices.type)])
        columns = {}
        for field in batch.schema:
            scalar = pa.scalar(0 if 'count' in field.name else None, field.type)
            columns[field.name] = pa.repeat(scalar, len(empties))
        table = pa.concat_tables([table, pa.table(columns)]).combine_chunks()
        return table.to_batches()[0].take(pc.sort_indices(indices))

    def min_max(self, **options) -> pa.Array:
        if pa.types.is_dictionary(self.type.value_type):
            (self,) = ListChunk.aggregate(self, distinct=None)
            self = type(self).from_arrays(self.offsets, self.values.dictionary_decode())
        return ListChunk.aggregate(self, min_max=pc.ScalarAggregateOptions(**options))[0]

    def min(self, **options) -> pa.Array:
        """min value of each list scalar"""
        return ListChunk.min_max(self, **options).field('min')

    def max(self, **options) -> pa.Array:
        """max value of each list scalar"""
        return ListChunk.min_max(self, **options).field('max')

    def mode(self, **options) -> pa.Array:
        """modes of each list scalar"""
        return ListChunk.map_list(self, pc.mode, **options)

    def quantile(self, **options) -> pa.Array:
        """quantiles of each list scalar"""
        return ListChunk.map_list(self, pc.quantile, **options)

    def index(self, **options) -> pa.Array:
        """index for first occurrence of each list scalar"""
        return pa.array(pc.index(value, **options) for value in ListChunk.scalars(self))

    @register
    def list_all(ctx, self: pa.list_(pa.bool_())) -> pa.bool_():  # type: ignore
        """Test whether all elements in a boolean array evaluate to true."""
        return ListChunk.aggregate(self, all=None)[0]

    @register
    def list_any(ctx, self: pa.list_(pa.bool_())) -> pa.bool_():  # type: ignore
        """Test whether any element in a boolean array evaluates to true."""
        return ListChunk.aggregate(self, any=None)[0]

aggregate(**funcs)

Return aggregated scalars by grouping each hash function on the parent indices.

If there are empty or null scalars, then the result must be padded with null defaults and reordered. If the function is a count, then the default is 0.

Source code in graphique/core.py
def aggregate(self, **funcs: Optional[pc.FunctionOptions]) -> pa.RecordBatch:
    """Return aggregated scalars by grouping each hash function on the parent indices.

    If there are empty or null scalars, then the result must be padded with null defaults and
    reordered. If the function is a `count`, then the default is 0.
    """
    columns = {'key': pc.list_parent_indices(self), '': pc.list_flatten(self)}
    items = [('', name, funcs[name]) for name in funcs]
    table = pa.table(columns).group_by(['key']).aggregate(items)
    indices, table = table['key'], table.remove_column(table.schema.get_field_index('key'))
    (batch,) = table.to_batches()
    if len(batch) == len(self):  # no empty or null scalars
        return batch
    mask = pc.equal(pc.list_value_length(self), 0)
    empties = pc.indices_nonzero(Column.fill_null(mask, True))
    indices = pa.chunked_array(indices.chunks + [empties.cast(indices.type)])
    columns = {}
    for field in batch.schema:
        scalar = pa.scalar(0 if 'count' in field.name else None, field.type)
        columns[field.name] = pa.repeat(scalar, len(empties))
    table = pa.concat_tables([table, pa.table(columns)]).combine_chunks()
    return table.to_batches()[0].take(pc.sort_indices(indices))
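
A hedged usage sketch (hypothetical data): any hash aggregate function accepted by pyarrow's group_by should work as a keyword, with output field naming following pyarrow's grouping conventions.

import pyarrow as pa
from graphique.core import ListChunk

arr = pa.array([[1, 2, 3], [], [4]])
batch = ListChunk.aggregate(arr, sum=None)  # one aggregated row per list scalar; empty or null lists get null defaults (0 for counts)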

element(index)

element at index of each list scalar; defaults to null

Source code in graphique/core.py
def element(self, index: int) -> pa.Array:
    """element at index of each list scalar; defaults to null"""
    with contextlib.suppress(ValueError):
        return pc.list_element(self, index)
    size = -index if index < 0 else index + 1
    if isinstance(self, pa.ChunkedArray):
        self = self.combine_chunks()
    mask = np.asarray(Column.fill_null(pc.list_value_length(self), 0)) < size
    offsets = np.asarray(self.offsets[1:] if index < 0 else self.offsets[:-1])
    return pc.list_flatten(self).take(pa.array(offsets + index, mask=mask))
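
A minimal usage sketch (hypothetical data): with index 1 and a list shorter than two values, the result is padded with null.

import pyarrow as pa
from graphique.core import ListChunk

arr = pa.array([[1, 2, 3], [4]])
ListChunk.element(arr, 1)  # expected: [2, null]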

first()

first value of each list scalar

Source code in graphique/core.py
def first(self) -> pa.Array:
    """first value of each list scalar"""
    return ListChunk.element(self, 0)

from_counts(counts, values)

Return list array by converting counts into offsets.

Source code in graphique/core.py
def from_counts(counts: pa.IntegerArray, values: pa.Array) -> pa.LargeListArray:
    """Return list array by converting counts into offsets."""
    mask = None
    if counts.null_count:
        mask, counts = counts.is_null(), counts.fill_null(0)
    offsets = pa.concat_arrays([pa.array([0], counts.type), pc.cumulative_sum_checked(counts)])
    cls = pa.LargeListArray if offsets.type == 'int64' else pa.ListArray
    return cls.from_arrays(offsets, values, mask=mask)
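
A minimal usage sketch (hypothetical data): counts partition the flat values into list scalars.

import pyarrow as pa
from graphique.core import ListChunk

counts = pa.array([2, 0, 1])
values = pa.array(['a', 'b', 'c'])
ListChunk.from_counts(counts, values)  # expected: [['a', 'b'], [], ['c']]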

from_scalars(values)

Return list array from array scalars.

Source code in graphique/core.py
def from_scalars(values: Iterable) -> pa.LargeListArray:
    """Return list array from array scalars."""
    return ListChunk.from_counts(pa.array(map(len, values)), pa.concat_arrays(values))

index(**options)

index for first occurrence of each list scalar

Source code in graphique/core.py
def index(self, **options) -> pa.Array:
    """index for first occurrence of each list scalar"""
    return pa.array(pc.index(value, **options) for value in ListChunk.scalars(self))

inner_flatten()

Return flattened inner lists from a nested list array.

Source code in graphique/core.py
def inner_flatten(self) -> pa.lib.BaseListArray:
    """Return flattened inner lists from a nested list array."""
    offsets = self.values.offsets.take(self.offsets)
    return type(self).from_arrays(offsets, self.values.values)

last()

last value of each list scalar

Source code in graphique/core.py
def last(self) -> pa.Array:
    """last value of each list scalar"""
    return ListChunk.element(self, -1)

list_all(ctx, self)

Test whether all elements in a boolean array evaluate to true.

Source code in graphique/core.py
@register
def list_all(ctx, self: pa.list_(pa.bool_())) -> pa.bool_():  # type: ignore
    """Test whether all elements in a boolean array evaluate to true."""
    return ListChunk.aggregate(self, all=None)[0]

list_any(ctx, self)

Test whether any element in a boolean array evaluates to true.

Source code in graphique/core.py
@register
def list_any(ctx, self: pa.list_(pa.bool_())) -> pa.bool_():  # type: ignore
    """Test whether any element in a boolean array evaluates to true."""
    return ListChunk.aggregate(self, any=None)[0]

map_list(func, **kwargs)

Return list array by mapping function across scalars, with null handling.

Source code in graphique/core.py
def map_list(self, func: Callable, **kwargs) -> pa.lib.BaseListArray:
    """Return list array by mapping function across scalars, with null handling."""
    values = [func(value, **kwargs) for value in ListChunk.scalars(self)]
    return ListChunk.from_scalars(values)
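
A minimal sketch (hypothetical data): map a compute function over each list scalar and rebuild a list array from the results.

import pyarrow as pa
import pyarrow.compute as pc
from graphique.core import ListChunk

arr = pa.array([[3, 1], [2]])
ListChunk.map_list(arr, pc.sort_indices)  # expected: [[1, 0], [0]]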

max(**options)

max value of each list scalar

Source code in graphique/core.py
def max(self, **options) -> pa.Array:
    """max value of each list scalar"""
    return ListChunk.min_max(self, **options).field('max')

min(**options)

min value of each list scalar

Source code in graphique/core.py
def min(self, **options) -> pa.Array:
    """min value of each list scalar"""
    return ListChunk.min_max(self, **options).field('min')

mode(**options)

modes of each list scalar

Source code in graphique/core.py
def mode(self, **options) -> pa.Array:
    """modes of each list scalar"""
    return ListChunk.map_list(self, pc.mode, **options)

quantile(**options)

quantiles of each list scalar

Source code in graphique/core.py
def quantile(self, **options) -> pa.Array:
    """quantiles of each list scalar"""
    return ListChunk.map_list(self, pc.quantile, **options)

graphique.core.Column

Bases: ChunkedArray

Chunked array interface as a namespace of functions.

Source code in graphique/core.py
class Column(pa.ChunkedArray):
    """Chunked array interface as a namespace of functions."""

    def is_list_type(self):
        funcs = pa.types.is_list, pa.types.is_large_list, pa.types.is_fixed_size_list
        return any(func(self.type) for func in funcs)

    def call_indices(self, func: Callable) -> Array:
        if not pa.types.is_dictionary(self.type):
            return func(self)
        array = self.combine_chunks()
        return pa.DictionaryArray.from_arrays(func(array.indices), array.dictionary)

    def fill_null_backward(self) -> Array:
        """`fill_null_backward` with dictionary support."""
        return Column.call_indices(self, pc.fill_null_backward)

    def fill_null_forward(self) -> Array:
        """`fill_null_forward` with dictionary support."""
        return Column.call_indices(self, pc.fill_null_forward)

    def fill_null(self, value) -> pa.ChunkedArray:
        """Optimized `fill_null` to check `null_count`."""
        return self.fill_null(value) if self.null_count else self

    def sort_values(self) -> Array:
        if not pa.types.is_dictionary(self.type):
            return self
        array = self if isinstance(self, pa.Array) else self.combine_chunks()
        return pc.rank(array.dictionary, 'ascending').take(array.indices)

    def pairwise_diff(self, period: int = 1) -> Array:
        """`pairwise_diff` with chunked array support."""
        return pc.pairwise_diff(self.combine_chunks(), period)

    def diff(self, func: Callable = pc.subtract, period: int = 1) -> Array:
        """Compute first order difference of an array.

        Unlike `pairwise_diff`, does not return leading nulls.
        """
        return func(self[period:], self[:-period])

    def run_offsets(self, predicate: Callable = pc.not_equal, *args) -> pa.IntegerArray:
        """Run-end encode array with leading zero, suitable for list offsets.

        Args:
            predicate: binary function applied to adjacent values
            *args: apply binary function to scalar, using `subtract` as the difference function
        """
        ends = [pa.array([True])]
        mask = predicate(Column.diff(self), *args) if args else Column.diff(self, predicate)
        return pc.indices_nonzero(pa.chunked_array(ends + mask.chunks + ends))

    def index(self, value, start=0, end=None) -> int:
        """Return the first index of a value."""
        with contextlib.suppress(NotImplementedError):
            return self.index(value, start, end).as_py()  # type: ignore
        offset = start
        for chunk in self[start:end].iterchunks():
            index = chunk.dictionary.index(value).as_py()
            if index >= 0:
                index = chunk.indices.index(index).as_py()
            if index >= 0:
                return offset + index
            offset += len(chunk)
        return -1

    def range(self, lower=None, upper=None, include_lower=True, include_upper=False) -> slice:
        """Return slice within range from a sorted array, by default a half-open interval."""
        method = bisect.bisect_left if include_lower else bisect.bisect_right
        start = 0 if lower is None else method(self, Compare(lower))
        method = bisect.bisect_right if include_upper else bisect.bisect_left
        stop = None if upper is None else method(self, Compare(upper), start)
        return slice(start, stop)

    def find(self, *values) -> Iterator[slice]:
        """Generate slices of matching rows from a sorted array."""
        stop = 0
        for value in map(Compare, sorted(values)):
            start = bisect.bisect_left(self, value, stop)
            stop = bisect.bisect_right(self, value, start)
            yield slice(start, stop)

diff(func=pc.subtract, period=1)

Compute first order difference of an array.

Unlike pairwise_diff, does not return leading nulls.

Source code in graphique/core.py
def diff(self, func: Callable = pc.subtract, period: int = 1) -> Array:
    """Compute first order difference of an array.

    Unlike `pairwise_diff`, does not return leading nulls.
    """
    return func(self[period:], self[:-period])
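
A minimal sketch (hypothetical data): first-order differences, one element shorter than the input.

import pyarrow as pa
from graphique.core import Column

chunked = pa.chunked_array([[1, 3, 6]])
Column.diff(chunked)  # expected: [2, 3]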

fill_null(value)

Optimized fill_null to check null_count.

Source code in graphique/core.py
def fill_null(self, value) -> pa.ChunkedArray:
    """Optimized `fill_null` to check `null_count`."""
    return self.fill_null(value) if self.null_count else self

fill_null_backward()

fill_null_backward with dictionary support.

Source code in graphique/core.py
def fill_null_backward(self) -> Array:
    """`fill_null_backward` with dictionary support."""
    return Column.call_indices(self, pc.fill_null_backward)

fill_null_forward()

fill_null_forward with dictionary support.

Source code in graphique/core.py
def fill_null_forward(self) -> Array:
    """`fill_null_forward` with dictionary support."""
    return Column.call_indices(self, pc.fill_null_forward)

find(*values)

Generate slices of matching rows from a sorted array.

Source code in graphique/core.py
def find(self, *values) -> Iterator[slice]:
    """Generate slices of matching rows from a sorted array."""
    stop = 0
    for value in map(Compare, sorted(values)):
        start = bisect.bisect_left(self, value, stop)
        stop = bisect.bisect_right(self, value, start)
        yield slice(start, stop)
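
A minimal sketch, assuming the column is sorted (hypothetical data): each requested value maps to a slice of matching rows, located by binary search.

import pyarrow as pa
from graphique.core import Column

sorted_col = pa.chunked_array([[1, 2, 2, 3]])
list(Column.find(sorted_col, 2, 3))  # expected: [slice(1, 3), slice(3, 4)]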

index(value, start=0, end=None)

Return the first index of a value.

Source code in graphique/core.py
def index(self, value, start=0, end=None) -> int:
    """Return the first index of a value."""
    with contextlib.suppress(NotImplementedError):
        return self.index(value, start, end).as_py()  # type: ignore
    offset = start
    for chunk in self[start:end].iterchunks():
        index = chunk.dictionary.index(value).as_py()
        if index >= 0:
            index = chunk.indices.index(index).as_py()
        if index >= 0:
            return offset + index
        offset += len(chunk)
    return -1
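
A minimal sketch (hypothetical data): behaves like list.index across chunks, returning -1 when the value is absent.

import pyarrow as pa
from graphique.core import Column

chunked = pa.chunked_array([['a', 'b'], ['c']])
Column.index(chunked, 'c')  # expected: 2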

pairwise_diff(period=1)

pairwise_diff with chunked array support.

Source code in graphique/core.py
def pairwise_diff(self, period: int = 1) -> Array:
    """`pairwise_diff` with chunked array support."""
    return pc.pairwise_diff(self.combine_chunks(), period)

range(lower=None, upper=None, include_lower=True, include_upper=False)

Return slice within range from a sorted array, by default a half-open interval.

Source code in graphique/core.py
def range(self, lower=None, upper=None, include_lower=True, include_upper=False) -> slice:
    """Return slice within range from a sorted array, by default a half-open interval."""
    method = bisect.bisect_left if include_lower else bisect.bisect_right
    start = 0 if lower is None else method(self, Compare(lower))
    method = bisect.bisect_right if include_upper else bisect.bisect_left
    stop = None if upper is None else method(self, Compare(upper), start)
    return slice(start, stop)
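
A minimal sketch, assuming the column is sorted (hypothetical data): the returned slice bounds the half-open interval [2, 5).

import pyarrow as pa
from graphique.core import Column

sorted_col = pa.chunked_array([[1, 2, 2, 3, 5]])
Column.range(sorted_col, 2, 5)  # expected: slice(1, 4)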

run_offsets(predicate=pc.not_equal, *args)

Run-end encode array with leading zero, suitable for list offsets.

Parameters:

    predicate (Callable): binary function applied to adjacent values; defaults to not_equal
    *args: apply binary function to scalar, using subtract as the difference function

Source code in graphique/core.py
def run_offsets(self, predicate: Callable = pc.not_equal, *args) -> pa.IntegerArray:
    """Run-end encode array with leading zero, suitable for list offsets.

    Args:
        predicate: binary function applied to adjacent values
        *args: apply binary function to scalar, using `subtract` as the difference function
    """
    ends = [pa.array([True])]
    mask = predicate(Column.diff(self), *args) if args else Column.diff(self, predicate)
    return pc.indices_nonzero(pa.chunked_array(ends + mask.chunks + ends))
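
A minimal sketch (hypothetical data): offsets mark where adjacent values change, bracketed by 0 and the array length.

import pyarrow as pa
from graphique.core import Column

chunked = pa.chunked_array([[1, 1, 2, 2, 2, 3]])
Column.run_offsets(chunked)  # expected: [0, 2, 5, 6]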

graphique.core.Table

Bases: Table

Table interface as a namespace of functions.

Source code in graphique/core.py
class Table(pa.Table):
    """Table interface as a namespace of functions."""

    def map_batch(self, func: Callable, *args, **kwargs) -> pa.Table:
        return pa.Table.from_batches(func(batch, *args, **kwargs) for batch in self.to_batches())

    def columns(self) -> dict:
        """Return columns as a dictionary."""
        return dict(zip(self.schema.names, self))

    def union(*tables: Batch) -> Batch:
        """Return table with union of columns."""
        columns: dict = {}
        for table in tables:
            columns |= Table.columns(table)
        return type(tables[0]).from_pydict(columns)

    def range(self, name: str, lower=None, upper=None, **includes) -> pa.Table:
        """Return rows within range, by default a half-open interval.

        Assumes the table is sorted by the column name, i.e., indexed.
        """
        return self[Column.range(self[name], lower, upper, **includes)]

    def is_in(self, name: str, *values) -> pa.Table:
        """Return rows which matches one of the values.

        Assumes the table is sorted by the column name, i.e., indexed.
        """
        slices = list(Column.find(self[name], *values)) or [slice(0)]
        return pa.concat_tables(self[slc] for slc in slices)

    def not_equal(self, name: str, value) -> pa.Table:
        """Return rows which don't match the value.

        Assumes the table is sorted by the column name, i.e., indexed.
        """
        (slc,) = Column.find(self[name], value)
        return pa.concat_tables([self[: slc.start], self[slc.stop :]])

    def from_offsets(self, offsets: pa.IntegerArray, mask=None) -> pa.RecordBatch:
        """Return record batch with columns converted into list columns."""
        cls = pa.LargeListArray if offsets.type == 'int64' else pa.ListArray
        if isinstance(self, pa.Table):
            (self,) = self.combine_chunks().to_batches() or [pa.record_batch([], self.schema)]
        arrays = [cls.from_arrays(offsets, array, mask=mask) for array in self]
        return pa.RecordBatch.from_arrays(arrays, self.schema.names)

    def from_counts(self, counts: pa.IntegerArray) -> pa.RecordBatch:
        """Return record batch with columns converted into list columns."""
        mask = None
        if counts.null_count:
            mask, counts = counts.is_null(), counts.fill_null(0)
        offsets = pa.concat_arrays([pa.array([0], counts.type), pc.cumulative_sum_checked(counts)])
        return Table.from_offsets(self, offsets, mask=mask)

    def runs(self, *names: str, **predicates: tuple) -> tuple:
        """Return table grouped by pairwise differences, and corresponding counts.

        Args:
            *names: columns to partition by `not_equal` which will return scalars
            **predicates: pairwise predicates with optional args which will return list arrays;
                if the predicate has args, it will be called on the differences
        """
        offsets = pa.chunked_array(
            Column.run_offsets(self[name], *predicates.get(name, ()))
            for name in names + tuple(predicates)
        )
        offsets = offsets.unique().sort()
        scalars = self.select(names).take(offsets[:-1])
        lists = self.select(set(self.schema.names) - set(names))
        table = Table.union(scalars, Table.from_offsets(lists, offsets))
        return table, Column.diff(offsets)

    def list_fields(self) -> set:
        return {field.name for field in self.schema if Column.is_list_type(field)}

    def list_value_length(self) -> pa.Array:
        lists = Table.list_fields(self)
        if not lists:
            raise ValueError(f"no list columns available: {self.schema.names}")
        counts, *others = (pc.list_value_length(self[name]) for name in lists)
        if any(counts != other for other in others):
            raise ValueError(f"list columns have different value lengths: {lists}")
        return counts if isinstance(counts, pa.Array) else counts.chunk(0)

    def map_list(self, func: Callable, *args, **kwargs) -> Batch:
        """Return table with function mapped across list scalars."""
        batches: Iterable = Table.split(self.select(Table.list_fields(self)))
        batches = [None if batch is None else func(batch, *args, **kwargs) for batch in batches]
        counts = pa.array(None if batch is None else len(batch) for batch in batches)
        table = pa.Table.from_batches(batch for batch in batches if batch is not None)
        return Table.union(self, Table.from_counts(table, counts))

    def sort_indices(
        self, *names: str, length: Optional[int] = None, null_placement: str = 'at_end'
    ) -> pa.Array:
        """Return indices which would sort the table by columns, optimized for fixed length."""
        func = functools.partial(pc.sort_indices, null_placement=null_placement)
        if length is not None and length < len(self):
            func = functools.partial(pc.select_k_unstable, k=length)
        keys = dict(map(sort_key, names))
        table = pa.table({name: Column.sort_values(self[name]) for name in keys})
        return func(table, sort_keys=keys.items()) if table else pa.array([], 'int64')

    def sort(
        self,
        *names: str,
        length: Optional[int] = None,
        indices: str = '',
        null_placement: str = 'at_end',
    ) -> Batch:
        """Return table sorted by columns, optimized for fixed length.

        Args:
            *names: columns to sort by
            length: maximum number of rows to return
            indices: include original indices in the table
        """
        if length == 1 and not indices:
            return Table.min_max(self, *names)[:1]
        indices_ = Table.sort_indices(self, *names, length=length, null_placement=null_placement)
        table = self.take(indices_)
        if indices:
            table = table.append_column(indices, indices_)
        func = lambda name: not name.startswith('-') and not self[name].null_count  # noqa: E731
        metadata = {'index_columns': list(itertools.takewhile(func, names))}
        return table.replace_schema_metadata({'pandas': json.dumps(metadata)})

    def filter_list(self, expr: ds.Expression) -> Batch:
        """Return table with list columns filtered within scalars."""
        fields = Table.list_fields(self)
        tables = [
            None if batch is None else pa.Table.from_batches([batch]).filter(expr).select(fields)
            for batch in Table.split(self)
        ]
        counts = pa.array(None if table is None else len(table) for table in tables)
        table = pa.concat_tables(table for table in tables if table is not None)
        return Table.union(self, Table.from_counts(table, counts))

    def min_max(self, *names: str) -> Self:
        """Return table filtered by minimum or maximum values."""
        for key, order in map(sort_key, names):
            field, asc = pc.field(key), (order == 'ascending')
            ((value,),) = Nodes.group(self, _=(key, ('min' if asc else 'max'), None)).to_table()
            self = self.filter(field <= value if asc else field >= value)
        return self

    def rank(self, k: int, *names: str) -> Self:
        """Return table filtered by values within dense rank, similar to `select_k_unstable`."""
        if k == 1:
            return Table.min_max(self, *names)
        keys = dict(map(sort_key, names))
        table = Nodes.group(self, *keys).to_table()
        table = table.take(pc.select_k_unstable(table, k, keys.items()))
        exprs = []
        for key, order in keys.items():
            field, asc = pc.field(key), (order == 'ascending')
            exprs.append(field <= pc.max(table[key]) if asc else field >= pc.min(table[key]))
        return self.filter(bit_all(exprs))

    def fragments(self, *names, counts: str = '') -> pa.Table:
        """Return selected fragment keys in a table."""
        try:
            expr = self._scan_options.get('filter')
            if expr is not None:  # raise ValueError if filter references other fields
                ds.dataset([], schema=self.partitioning.schema).scanner(filter=expr)
        except (AttributeError, ValueError):
            return pa.table({})
        fragments = self._get_fragments(expr)
        parts = [ds.get_partition_keys(frag.partition_expression) for frag in fragments]
        names, table = set(names), pa.Table.from_pylist(parts)  # type: ignore
        keys = [name for name in table.schema.names if name in names]
        table = table.group_by(keys, use_threads=False).aggregate([])
        if not counts:
            return table
        if not table.schema:
            return table.append_column(counts, pa.array([self.count_rows()]))
        exprs = [bit_all(pc.field(key) == row[key] for key in row) for row in table.to_pylist()]
        column = [self.filter(expr).count_rows() for expr in exprs]
        return table.append_column(counts, pa.array(column))

    def rank_keys(self, k: int, *names: str, dense: bool = True) -> tuple:
        """Return expression and unmatched fields for partitioned dataset which filters by rank.

        Args:
            k: max dense rank or length
            *names: columns to rank by
            dense: use dense rank; false indicates sorting
        """
        keys = dict(map(sort_key, names))
        table = Table.fragments(self, *keys, counts='' if dense else '_')
        keys = {name: keys[name] for name in table.schema.names if name in keys}
        if not keys:
            return None, names
        if dense:
            table = table.take(pc.select_k_unstable(table, k, keys.items()))
        else:
            table = table.sort_by(keys.items())
            totals = itertools.accumulate(table['_'].to_pylist())
            counts = (count for count, total in enumerate(totals, 1) if total >= k)
            table = table[: next(counts, None)].remove_column(len(table) - 1)
        exprs = [bit_all(pc.field(key) == row[key] for key in row) for row in table.to_pylist()]
        remaining = names[len(keys) :]
        if remaining or not dense:  # fields with a single value are no longer needed
            selectors = [len(table[key].unique()) > 1 for key in keys]
            remaining = tuple(itertools.compress(names, selectors)) + remaining
        return bit_any(exprs[: len(table)]), remaining

    def flatten(self, indices: str = '') -> Iterator[pa.RecordBatch]:
        """Generate batches with list arrays flattened, optionally with parent indices."""
        offset = 0
        for batch in self.to_batches():
            _ = Table.list_value_length(batch)
            indices_ = pc.list_parent_indices(batch[Table.list_fields(batch).pop()])
            arrays = [
                pc.list_flatten(array) if Column.is_list_type(array) else array.take(indices_)
                for array in batch
            ]
            columns = dict(zip(batch.schema.names, arrays))
            if indices:
                columns[indices] = pc.add(indices_, offset)
            offset += len(batch)
            yield pa.RecordBatch.from_pydict(columns)

    def split(self) -> Iterator[Optional[pa.RecordBatch]]:
        """Generate tables from splitting list scalars."""
        lists = Table.list_fields(self)
        scalars = set(self.schema.names) - lists
        for index, count in enumerate(Table.list_value_length(self).to_pylist()):
            if count is None:
                yield None
            else:
                row = {name: pa.repeat(self[name][index], count) for name in scalars}
                row |= {name: self[name][index].values for name in lists}
                yield pa.RecordBatch.from_pydict(row)

    def size(self) -> str:
        """Return buffer size in readable units."""
        size, prefix = self.nbytes, ''
        for prefix in itertools.takewhile(lambda _: size >= 1e3, 'kMGT'):
            size /= 1e3
        return f'{size:n} {prefix}B'

columns()

Return columns as a dictionary.

Source code in graphique/core.py
def columns(self) -> dict:
    """Return columns as a dictionary."""
    return dict(zip(self.schema.names, self))

filter_list(expr)

Return table with list columns filtered within scalars.

Source code in graphique/core.py
def filter_list(self, expr: ds.Expression) -> Batch:
    """Return table with list columns filtered within scalars."""
    fields = Table.list_fields(self)
    tables = [
        None if batch is None else pa.Table.from_batches([batch]).filter(expr).select(fields)
        for batch in Table.split(self)
    ]
    counts = pa.array(None if table is None else len(table) for table in tables)
    table = pa.concat_tables(table for table in tables if table is not None)
    return Table.union(self, Table.from_counts(table, counts))

flatten(indices='')

Generate batches with list arrays flattened, optionally with parent indices.

Source code in graphique/core.py
def flatten(self, indices: str = '') -> Iterator[pa.RecordBatch]:
    """Generate batches with list arrays flattened, optionally with parent indices."""
    offset = 0
    for batch in self.to_batches():
        _ = Table.list_value_length(batch)
        indices_ = pc.list_parent_indices(batch[Table.list_fields(batch).pop()])
        arrays = [
            pc.list_flatten(array) if Column.is_list_type(array) else array.take(indices_)
            for array in batch
        ]
        columns = dict(zip(batch.schema.names, arrays))
        if indices:
            columns[indices] = pc.add(indices_, offset)
        offset += len(batch)
        yield pa.RecordBatch.from_pydict(columns)
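
A minimal sketch (hypothetical data): list columns are flattened and scalar columns are repeated to match.

import pyarrow as pa
from graphique.core import Table

table = pa.table({'id': [1, 2], 'vals': [[10, 20], [30]]})
next(Table.flatten(table))  # expected: id [1, 1, 2], vals [10, 20, 30]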

fragments(*names, counts='')

Return selected fragment keys in a table.

Source code in graphique/core.py
def fragments(self, *names, counts: str = '') -> pa.Table:
    """Return selected fragment keys in a table."""
    try:
        expr = self._scan_options.get('filter')
        if expr is not None:  # raise ValueError if filter references other fields
            ds.dataset([], schema=self.partitioning.schema).scanner(filter=expr)
    except (AttributeError, ValueError):
        return pa.table({})
    fragments = self._get_fragments(expr)
    parts = [ds.get_partition_keys(frag.partition_expression) for frag in fragments]
    names, table = set(names), pa.Table.from_pylist(parts)  # type: ignore
    keys = [name for name in table.schema.names if name in names]
    table = table.group_by(keys, use_threads=False).aggregate([])
    if not counts:
        return table
    if not table.schema:
        return table.append_column(counts, pa.array([self.count_rows()]))
    exprs = [bit_all(pc.field(key) == row[key] for key in row) for row in table.to_pylist()]
    column = [self.filter(expr).count_rows() for expr in exprs]
    return table.append_column(counts, pa.array(column))

from_counts(counts)

Return record batch with columns converted into list columns.

Source code in graphique/core.py
def from_counts(self, counts: pa.IntegerArray) -> pa.RecordBatch:
    """Return record batch with columns converted into list columns."""
    mask = None
    if counts.null_count:
        mask, counts = counts.is_null(), counts.fill_null(0)
    offsets = pa.concat_arrays([pa.array([0], counts.type), pc.cumulative_sum_checked(counts)])
    return Table.from_offsets(self, offsets, mask=mask)
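
A minimal sketch (hypothetical data): wrap every column into a list column using per-row counts.

import pyarrow as pa
from graphique.core import Table

table = pa.table({'x': [1, 2, 3]})
Table.from_counts(table, pa.array([2, 1]))  # expected: record batch with x as [[1, 2], [3]]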

from_offsets(offsets, mask=None)

Return record batch with columns converted into list columns.

Source code in graphique/core.py
def from_offsets(self, offsets: pa.IntegerArray, mask=None) -> pa.RecordBatch:
    """Return record batch with columns converted into list columns."""
    cls = pa.LargeListArray if offsets.type == 'int64' else pa.ListArray
    if isinstance(self, pa.Table):
        (self,) = self.combine_chunks().to_batches() or [pa.record_batch([], self.schema)]
    arrays = [cls.from_arrays(offsets, array, mask=mask) for array in self]
    return pa.RecordBatch.from_arrays(arrays, self.schema.names)

is_in(name, *values)

Return rows which match one of the values.

Assumes the table is sorted by the column name, i.e., indexed.

Source code in graphique/core.py
def is_in(self, name: str, *values) -> pa.Table:
    """Return rows which matches one of the values.

    Assumes the table is sorted by the column name, i.e., indexed.
    """
    slices = list(Column.find(self[name], *values)) or [slice(0)]
    return pa.concat_tables(self[slc] for slc in slices)
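
A minimal sketch, assuming the table is sorted by the column (hypothetical data): matching rows are located by binary search rather than a full scan.

import pyarrow as pa
from graphique.core import Table

table = pa.table({'id': [1, 2, 2, 3], 'x': ['a', 'b', 'c', 'd']})
Table.is_in(table, 'id', 2)  # expected: the two rows where id == 2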

map_list(func, *args, **kwargs)

Return table with function mapped across list scalars.

Source code in graphique/core.py
def map_list(self, func: Callable, *args, **kwargs) -> Batch:
    """Return table with function mapped across list scalars."""
    batches: Iterable = Table.split(self.select(Table.list_fields(self)))
    batches = [None if batch is None else func(batch, *args, **kwargs) for batch in batches]
    counts = pa.array(None if batch is None else len(batch) for batch in batches)
    table = pa.Table.from_batches(batch for batch in batches if batch is not None)
    return Table.union(self, Table.from_counts(table, counts))

min_max(*names)

Return table filtered by minimum or maximum values.

Source code in graphique/core.py
def min_max(self, *names: str) -> Self:
    """Return table filtered by minimum or maximum values."""
    for key, order in map(sort_key, names):
        field, asc = pc.field(key), (order == 'ascending')
        ((value,),) = Nodes.group(self, _=(key, ('min' if asc else 'max'), None)).to_table()
        self = self.filter(field <= value if asc else field >= value)
    return self

not_equal(name, value)

Return rows which don't match the value.

Assumes the table is sorted by the column name, i.e., indexed.

Source code in graphique/core.py
def not_equal(self, name: str, value) -> pa.Table:
    """Return rows which don't match the value.

    Assumes the table is sorted by the column name, i.e., indexed.
    """
    (slc,) = Column.find(self[name], value)
    return pa.concat_tables([self[: slc.start], self[slc.stop :]])

range(name, lower=None, upper=None, **includes)

Return rows within range, by default a half-open interval.

Assumes the table is sorted by the column name, i.e., indexed.

Source code in graphique/core.py
def range(self, name: str, lower=None, upper=None, **includes) -> pa.Table:
    """Return rows within range, by default a half-open interval.

    Assumes the table is sorted by the column name, i.e., indexed.
    """
    return self[Column.range(self[name], lower, upper, **includes)]

rank(k, *names)

Return table filtered by values within dense rank, similar to select_k_unstable.

Source code in graphique/core.py
def rank(self, k: int, *names: str) -> Self:
    """Return table filtered by values within dense rank, similar to `select_k_unstable`."""
    if k == 1:
        return Table.min_max(self, *names)
    keys = dict(map(sort_key, names))
    table = Nodes.group(self, *keys).to_table()
    table = table.take(pc.select_k_unstable(table, k, keys.items()))
    exprs = []
    for key, order in keys.items():
        field, asc = pc.field(key), (order == 'ascending')
        exprs.append(field <= pc.max(table[key]) if asc else field >= pc.min(table[key]))
    return self.filter(bit_all(exprs))

rank_keys(k, *names, dense=True)

Return expression and unmatched fields for partitioned dataset which filters by rank.

Parameters:

    k (int): max dense rank or length; required
    *names (str): columns to rank by
    dense (bool): use dense rank; false indicates sorting; defaults to True

Source code in graphique/core.py
def rank_keys(self, k: int, *names: str, dense: bool = True) -> tuple:
    """Return expression and unmatched fields for partitioned dataset which filters by rank.

    Args:
        k: max dense rank or length
        *names: columns to rank by
        dense: use dense rank; false indicates sorting
    """
    keys = dict(map(sort_key, names))
    table = Table.fragments(self, *keys, counts='' if dense else '_')
    keys = {name: keys[name] for name in table.schema.names if name in keys}
    if not keys:
        return None, names
    if dense:
        table = table.take(pc.select_k_unstable(table, k, keys.items()))
    else:
        table = table.sort_by(keys.items())
        totals = itertools.accumulate(table['_'].to_pylist())
        counts = (count for count, total in enumerate(totals, 1) if total >= k)
        table = table[: next(counts, None)].remove_column(len(table) - 1)
    exprs = [bit_all(pc.field(key) == row[key] for key in row) for row in table.to_pylist()]
    remaining = names[len(keys) :]
    if remaining or not dense:  # fields with a single value are no longer needed
        selectors = [len(table[key].unique()) > 1 for key in keys]
        remaining = tuple(itertools.compress(names, selectors)) + remaining
    return bit_any(exprs[: len(table)]), remaining

runs(*names, **predicates)

Return table grouped by pairwise differences, and corresponding counts.

Parameters:

    *names (str): columns to partition by not_equal, which will return scalars
    **predicates (tuple): pairwise predicates with optional args, which will return list arrays;
        if the predicate has args, it will be called on the differences

Source code in graphique/core.py
def runs(self, *names: str, **predicates: tuple) -> tuple:
    """Return table grouped by pairwise differences, and corresponding counts.

    Args:
        *names: columns to partition by `not_equal` which will return scalars
        **predicates: pairwise predicates with optional args which will return list arrays;
            if the predicate has args, it will be called on the differences
    """
    offsets = pa.chunked_array(
        Column.run_offsets(self[name], *predicates.get(name, ()))
        for name in names + tuple(predicates)
    )
    offsets = offsets.unique().sort()
    scalars = self.select(names).take(offsets[:-1])
    lists = self.select(set(self.schema.names) - set(names))
    table = Table.union(scalars, Table.from_offsets(lists, offsets))
    return table, Column.diff(offsets)

size()

Return buffer size in readable units.

Source code in graphique/core.py
def size(self) -> str:
    """Return buffer size in readable units."""
    size, prefix = self.nbytes, ''
    for prefix in itertools.takewhile(lambda _: size >= 1e3, 'kMGT'):
        size /= 1e3
    return f'{size:n} {prefix}B'

sort(*names, length=None, indices='', null_placement='at_end')

Return table sorted by columns, optimized for fixed length.

Parameters:

    *names (str): columns to sort by
    length (Optional[int]): maximum number of rows to return; defaults to None
    indices (str): include original indices in the table; defaults to ''

Source code in graphique/core.py
def sort(
    self,
    *names: str,
    length: Optional[int] = None,
    indices: str = '',
    null_placement: str = 'at_end',
) -> Batch:
    """Return table sorted by columns, optimized for fixed length.

    Args:
        *names: columns to sort by
        length: maximum number of rows to return
        indices: include original indices in the table
    """
    if length == 1 and not indices:
        return Table.min_max(self, *names)[:1]
    indices_ = Table.sort_indices(self, *names, length=length, null_placement=null_placement)
    table = self.take(indices_)
    if indices:
        table = table.append_column(indices, indices_)
    func = lambda name: not name.startswith('-') and not self[name].null_count  # noqa: E731
    metadata = {'index_columns': list(itertools.takewhile(func, names))}
    return table.replace_schema_metadata({'pandas': json.dumps(metadata)})
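
A hedged sketch (hypothetical data): passing length switches to select_k_unstable; a leading '-' on a column name appears to request descending order (handled by sort_key, which is not shown here).

import pyarrow as pa
from graphique.core import Table

table = pa.table({'x': [3, 1, 2]})
Table.sort(table, 'x', length=2)  # expected: the two smallest rows by x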

sort_indices(*names, length=None, null_placement='at_end')

Return indices which would sort the table by columns, optimized for fixed length.

Source code in graphique/core.py
def sort_indices(
    self, *names: str, length: Optional[int] = None, null_placement: str = 'at_end'
) -> pa.Array:
    """Return indices which would sort the table by columns, optimized for fixed length."""
    func = functools.partial(pc.sort_indices, null_placement=null_placement)
    if length is not None and length < len(self):
        func = functools.partial(pc.select_k_unstable, k=length)
    keys = dict(map(sort_key, names))
    table = pa.table({name: Column.sort_values(self[name]) for name in keys})
    return func(table, sort_keys=keys.items()) if table else pa.array([], 'int64')

split()

Generate tables from splitting list scalars.

Source code in graphique/core.py
def split(self) -> Iterator[Optional[pa.RecordBatch]]:
    """Generate tables from splitting list scalars."""
    lists = Table.list_fields(self)
    scalars = set(self.schema.names) - lists
    for index, count in enumerate(Table.list_value_length(self).to_pylist()):
        if count is None:
            yield None
        else:
            row = {name: pa.repeat(self[name][index], count) for name in scalars}
            row |= {name: self[name][index].values for name in lists}
            yield pa.RecordBatch.from_pydict(row)

union(*tables)

Return table with union of columns.

Source code in graphique/core.py
def union(*tables: Batch) -> Batch:
    """Return table with union of columns."""
    columns: dict = {}
    for table in tables:
        columns |= Table.columns(table)
    return type(tables[0]).from_pydict(columns)
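
A minimal sketch (hypothetical data): later tables override columns of the same name.

import pyarrow as pa
from graphique.core import Table

left = pa.table({'x': [1, 2]})
right = pa.table({'y': ['a', 'b']})
Table.union(left, right)  # expected: a table with columns x and y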

graphique.core.Nodes

Bases: Declaration

Acero engine declaration.

Provides a Scanner interface with no "oneshot" limitation.

Source code in graphique/core.py
class Nodes(ac.Declaration):
    """[Acero](https://arrow.apache.org/docs/python/api/acero.html) engine declaration.

    Provides a `Scanner` interface with no "oneshot" limitation.
    """

    option_map = {
        'table_source': ac.TableSourceNodeOptions,
        'scan': ac.ScanNodeOptions,
        'filter': ac.FilterNodeOptions,
        'project': ac.ProjectNodeOptions,
        'aggregate': ac.AggregateNodeOptions,
        'order_by': ac.OrderByNodeOptions,
        'hashjoin': ac.HashJoinNodeOptions,
    }
    to_batches = ac.Declaration.to_reader  # source compatibility

    def __init__(self, name, *args, inputs=None, **options):
        super().__init__(name, self.option_map[name](*args, **options), inputs)

    def scan(self, columns: Iterable[str]) -> Self:
        """Return projected source node, supporting datasets and tables."""
        if isinstance(self, ds.Dataset):
            expr = self._scan_options.get('filter')
            self = Nodes('scan', self, columns=columns)
            if expr is not None:
                self = self.apply('filter', expr)
        elif isinstance(self, pa.Table):
            self = Nodes('table_source', self)
        elif isinstance(self, pa.RecordBatch):
            self = Nodes('table_source', pa.table(self))
        if isinstance(columns, Mapping):
            return self.apply('project', columns.values(), columns)
        return self.apply('project', map(pc.field, columns))

    @property
    def schema(self) -> pa.Schema:
        """projected schema"""
        with self.to_reader() as reader:
            return reader.schema

    def scanner(self, **options) -> ds.Scanner:
        return ds.Scanner.from_batches(self.to_reader(**options))

    def count_rows(self) -> int:
        """Count matching rows."""
        return self.scanner().count_rows()

    def head(self, num_rows: int, **options) -> pa.Table:
        """Load the first N rows."""
        return self.scanner(**options).head(num_rows)

    def take(self, indices: Iterable[int], **options) -> pa.Table:
        """Select rows by index."""
        return self.scanner(**options).take(indices)

    def apply(self, name: str, *args, **options) -> Self:
        """Add a node by name."""
        return type(self)(name, *args, inputs=[self], **options)

    filter = functools.partialmethod(apply, 'filter')

    def group(self, *names, **aggs: tuple) -> Self:
        """Add `aggregate` node with dictionary support.

        Also supports datasets because aggregation determines the projection.
        """
        aggregates, targets = [], set(names)
        for name, (target, _, _) in aggs.items():
            aggregates.append(aggs[name] + (name,))
            targets.update([target] if isinstance(target, str) else target)
        columns = {name: pc.field(name) for name in targets}
        for name in columns:
            field = self.schema.field(name)
            if pa.types.is_dictionary(field.type):
                columns[name] = columns[name].cast(field.type.value_type)
        return Nodes.scan(self, columns).apply('aggregate', aggregates, names)

schema: pa.Schema property

projected schema

apply(name, *args, **options)

Add a node by name.

Source code in graphique/core.py
def apply(self, name: str, *args, **options) -> Self:
    """Add a node by name."""
    return type(self)(name, *args, inputs=[self], **options)

count_rows()

Count matching rows.

Source code in graphique/core.py
def count_rows(self) -> int:
    """Count matching rows."""
    return self.scanner().count_rows()

group(*names, **aggs)

Add aggregate node with dictionary support.

Also supports datasets because aggregation determines the projection.

Source code in graphique/core.py
def group(self, *names, **aggs: tuple) -> Self:
    """Add `aggregate` node with dictionary support.

    Also supports datasets because aggregation determines the projection.
    """
    aggregates, targets = [], set(names)
    for name, (target, _, _) in aggs.items():
        aggregates.append(aggs[name] + (name,))
        targets.update([target] if isinstance(target, str) else target)
    columns = {name: pc.field(name) for name in targets}
    for name in columns:
        field = self.schema.field(name)
        if pa.types.is_dictionary(field.type):
            columns[name] = columns[name].cast(field.type.value_type)
    return Nodes.scan(self, columns).apply('aggregate', aggregates, names)
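
A hedged sketch mirroring how Table.min_max calls it: each keyword maps an output name to a (target, function, options) tuple, here as a scalar aggregation with no grouping keys.

import pyarrow as pa
from graphique.core import Nodes

table = pa.table({'value': [3, 1, 2]})
Nodes.group(table, low=('value', 'min', None)).to_table()  # expected: one row with the minimum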

head(num_rows, **options)

Load the first N rows.

Source code in graphique/core.py
def head(self, num_rows: int, **options) -> pa.Table:
    """Load the first N rows."""
    return self.scanner(**options).head(num_rows)

scan(columns)

Return projected source node, supporting datasets and tables.

Source code in graphique/core.py
def scan(self, columns: Iterable[str]) -> Self:
    """Return projected source node, supporting datasets and tables."""
    if isinstance(self, ds.Dataset):
        expr = self._scan_options.get('filter')
        self = Nodes('scan', self, columns=columns)
        if expr is not None:
            self = self.apply('filter', expr)
    elif isinstance(self, pa.Table):
        self = Nodes('table_source', self)
    elif isinstance(self, pa.RecordBatch):
        self = Nodes('table_source', pa.table(self))
    if isinstance(columns, Mapping):
        return self.apply('project', columns.values(), columns)
    return self.apply('project', map(pc.field, columns))

take(indices, **options)

Select rows by index.

Source code in graphique/core.py
def take(self, indices: Iterable[int], **options) -> pa.Table:
    """Select rows by index."""
    return self.scanner(**options).take(indices)

graphique.interface.Dataset

Source code in graphique/interface.py
@strawberry.interface(description="an arrow dataset, scanner, or table")
class Dataset:
    def __init__(self, source: Source):
        self.source = source

    def references(self, info: Info, level: int = 0) -> set:
        """Return set of every possible future column reference."""
        fields = info.selected_fields
        for _ in range(level):
            fields = itertools.chain(*[field.selections for field in fields])
        return set(itertools.chain(*map(references, fields))) & set(self.schema().names)

    def select(self, info: Info) -> Source:
        """Return source with only the columns necessary to proceed."""
        names = list(self.references(info))
        if len(names) >= len(self.schema().names):
            return self.source
        if isinstance(self.source, ds.Scanner):
            schema = self.source.projected_schema
            return ds.Scanner.from_batches(self.source.to_batches(), schema=schema, columns=names)
        if isinstance(self.source, pa.Table):
            return self.source.select(names)
        return Nodes.scan(self.source, names)

    def to_table(self, info: Info, length: Optional[int] = None) -> pa.Table:
        """Return table with only the rows and columns necessary to proceed."""
        source = self.select(info)
        if isinstance(source, pa.Table):
            return source
        if length is None:
            return self.add_metric(info, source.to_table(), mode='read')
        return self.add_metric(info, source.head(length), mode='head')

    @classmethod
    @no_type_check
    def resolve_reference(cls, info: Info, **keys) -> Self:
        """Return table from federated keys."""
        self = getattr(info.root_value, cls.field)
        queries = {name: Filter(eq=[keys[name]]) for name in keys}
        return self.filter(info, **queries)

    def columns(self, info: Info) -> dict:
        """fields for each column"""
        table = self.to_table(info)
        return {name: Column.cast(table[name]) for name in table.schema.names}

    def row(self, info: Info, index: int = 0) -> dict:
        """Return scalar values at index."""
        table = self.to_table(info, index + 1 if index >= 0 else None)
        row = {}
        for name in table.schema.names:
            scalar = table[name][index]
            columnar = isinstance(scalar, pa.ListScalar)
            row[name] = Column.fromscalar(scalar) if columnar else scalar.as_py()
        return row

    def filter(self, info: Info, **queries: Filter) -> Self:
        """Return table with rows which match all queries.

        See `scan(filter: ...)` for more advanced queries. Additional feature: sorted tables
        support binary search
        """
        source = self.source
        prev = info.path.prev
        search = isinstance(source, pa.Table) and (prev is None or prev.typename == 'Query')
        for name in self.schema().index if search else []:
            assert not source[name].null_count, f"search requires non-null column: {name}"
            query = dict(queries.pop(name))
            if 'eq' in query:
                source = T.is_in(source, name, *query['eq'])
            if 'ne' in query:
                source = T.not_equal(source, name, query['ne'])
            lower, upper = query.get('gt'), query.get('lt')
            includes = {'include_lower': False, 'include_upper': False}
            if 'ge' in query and (lower is None or query['ge'] > lower):
                lower, includes['include_lower'] = query['ge'], True
            if 'le' in query and (upper is None or query['le'] > upper):
                upper, includes['include_upper'] = query['le'], True
            if {lower, upper} != {None}:
                source = T.range(source, name, lower, upper, **includes)
            if len(query.pop('eq', [])) != 1 or query:
                break
        return type(self)(source).scan(info, filter=Expression.from_query(**queries))

    @doc_field
    def type(self) -> str:
        """[arrow type](https://arrow.apache.org/docs/python/api/dataset.html#classes)"""
        return type(self.source).__name__

    @doc_field
    def schema(self) -> Schema:
        """dataset schema"""
        source = self.source
        schema = source.projected_schema if isinstance(source, ds.Scanner) else source.schema
        partitioning = getattr(source, 'partitioning', None)
        index = (schema.pandas_metadata or {}).get('index_columns', [])
        return Schema(
            names=schema.names,
            types=schema.types,
            partitioning=partitioning.schema.names if partitioning else [],
            index=[name for name in index if isinstance(name, str)],
        )  # type: ignore

    @doc_field
    def optional(self) -> Optional[Self]:
        """Nullable field to stop error propagation, enabling partial query results.

        Will be replaced by client controlled nullability.
        """
        return self

    @staticmethod
    def add_metric(info: Info, table: pa.Table, **data):
        """Add memory usage and other metrics to context with path info."""
        path = tuple(get_path_from_info(info))
        info.context.setdefault('metrics', {})[path] = dict(data, memory=T.size(table))
        return table

    @doc_field
    def length(self) -> Long:
        """number of rows"""
        return len(self.source) if isinstance(self.source, Sized) else self.source.count_rows()

    @doc_field
    def any(self, info: Info, length: Long = 1) -> bool:
        """Return whether there are at least `length` rows.

        May be significantly faster than `length` for out-of-core data.
        """
        table = self.to_table(info, length)
        return len(table) >= length

    @doc_field
    def size(self) -> Optional[Long]:
        """buffer size in bytes; null if table is not loaded"""
        return getattr(self.source, 'nbytes', None)

    @doc_field(
        name="column name(s); multiple names access nested struct fields",
        cast=f"cast array to {links.type}",
        safe="check for conversion errors on cast",
    )
    def column(
        self, info: Info, name: list[str], cast: str = '', safe: bool = True
    ) -> Optional[Column]:
        """Return column of any type by name.

        This is typically only needed for aliased or casted columns.
        If the column is in the schema, `columns` can be used instead.
        """
        if isinstance(self.source, pa.Table) and len(name) == 1:
            column = self.source.column(*name)
            return Column.cast(column.cast(cast, safe) if cast else column)
        column = Projection(alias='_', name=name, cast=cast, safe=safe)  # type: ignore
        source = self.scan(info, Expression(), [column]).source
        return Column.cast(*(source if isinstance(source, pa.Table) else source.to_table()))

    @doc_field(
        offset="number of rows to skip; negative value skips from the end",
        length="maximum number of rows to return",
        reverse="reverse order after slicing; forces a copy",
    )
    def slice(
        self, info: Info, offset: Long = 0, length: Optional[Long] = None, reverse: bool = False
    ) -> Self:
        """Return zero-copy slice of table.

        Can also be used to force loading a dataset.
        """
        table = self.to_table(info, length and (offset + length if offset >= 0 else None))
        table = table[offset:][:length]  # `slice` bug: ARROW-15412
        return type(self)(table[::-1] if reverse else table)

    @doc_field(
        by="column names; empty will aggregate into a single row table",
        counts="optionally include counts in an aliased column",
        ordered="optionally disable parallelization to maintain ordering",
        aggregate="aggregation functions applied to other columns",
    )
    def group(
        self,
        info: Info,
        by: list[str] = [],
        counts: str = '',
        ordered: bool = False,
        aggregate: HashAggregates = {},  # type: ignore
    ) -> Self:
        """Return table grouped by columns.

        See `column` for accessing any column which has changed type. See `tables` to split on any
        aggregated list columns.
        """
        if not any(aggregate.keys()):
            fragments = T.fragments(self.source, *by, counts=counts)
            if set(fragments.schema.names) >= set(by):
                return type(self)(fragments)
        prefix = 'hash_' if by else ''
        aggs: dict = {counts: ([], prefix + 'count_all', None)} if counts else {}
        for func, values in dict(aggregate).items():
            ordered = ordered or func in Agg.ordered
            for agg in values:
                aggs[agg.alias] = (agg.name, prefix + func, agg.func_options(func))
        source = self.to_table(info) if isinstance(self.source, ds.Scanner) else self.source
        source = Nodes.group(source, *by, **aggs)
        if ordered:
            source = self.add_metric(info, source.to_table(use_threads=False), mode='group')
        return type(self)(source)

    @doc_field(
        by="column names",
        split="optional predicates to split on; scalars are compared to pairwise difference",
        counts="optionally include counts in an aliased column",
    )
    @no_type_check
    def runs(
        self, info: Info, by: list[str] = [], split: list[Diff] = [], counts: str = ''
    ) -> Self:
        """Return table grouped by pairwise differences.

        Differs from `group` by relying on adjacency, and is typically faster. Other columns are
        transformed into list columns. See `column` and `tables` to further access lists.
        """
        table = self.to_table(info)
        predicates = {}
        for diff in map(dict, split):
            name = diff.pop('name')
            ((func, value),) = diff.items()
            if pa.types.is_timestamp(table.field(name).type):
                value = timedelta(seconds=value)
            predicates[name] = (getattr(pc, func), value)[: 1 if value is None else 2]
        table, counts_ = T.runs(table, *by, **predicates)
        return type(self)(table.append_column(counts, counts_) if counts else table)

    @doc_field(
        by="column names; prefix with `-` for descending order",
        length="maximum number of rows to return; may be significantly faster but is unstable",
        null_placement="where nulls in input should be sorted; incompatible with `length`",
    )
    def sort(
        self,
        info: Info,
        by: list[str],
        length: Optional[Long] = None,
        null_placement: str = 'at_end',
    ) -> Self:
        """Return table slice sorted by specified columns.

        Optimized for length == 1; matches min or max values.
        """
        kwargs = dict(length=length, null_placement=null_placement)
        if isinstance(self.source, pa.Table) or length is None:
            table = self.to_table(info)
        else:
            expr, by = T.rank_keys(self.source, length, *by, dense=False)
            if expr is not None:
                self = type(self)(self.source.filter(expr))
            source = self.select(info)
            if not by:
                return type(self)(self.add_metric(info, source.head(length), mode='head'))
            table = T.map_batch(source, T.sort, *by, **kwargs)
            self.add_metric(info, table, mode='batch')
        return type(self)(T.sort(table, *by, **kwargs))  # type: ignore

    @doc_field(
        by="column names; prefix with `-` for descending order",
        max="maximum dense rank to select; optimized for == 1 (min or max)",
    )
    def rank(self, info: Info, by: list[str], max: int = 1) -> Self:
        """Return table selected by maximum dense rank."""
        source = self.to_table(info) if isinstance(self.source, ds.Scanner) else self.source
        expr, by = T.rank_keys(source, max, *by)
        if expr is not None:
            source = source.filter(expr)
        return type(self)(T.rank(source, max, *by) if by else source)

    @staticmethod
    def apply_list(table: Batch, list_: ListFunction) -> Batch:
        expr = list_.filter.to_arrow() if list_.filter else None
        if expr is not None:
            table = T.filter_list(table, expr)
        if list_.rank:
            table = T.map_list(table, T.rank, list_.rank.max, *list_.rank.by)
        if list_.sort:
            table = T.map_list(table, T.sort, *list_.sort.by, length=list_.sort.length)
        columns = {}
        for func, field in dict(list_).items():
            columns[field.alias] = getattr(ListChunk, func)(table[field.name], **field.options)
        return T.union(table, pa.RecordBatch.from_pydict(columns))

    @doc_field
    @no_type_check
    def apply(
        self,
        info: Info,
        cumulative_max: doc_argument(list[Cumulative], func=pc.cumulative_max) = [],
        cumulative_mean: doc_argument(list[Cumulative], func=pc.cumulative_mean) = [],
        cumulative_min: doc_argument(list[Cumulative], func=pc.cumulative_min) = [],
        cumulative_prod: doc_argument(list[Cumulative], func=pc.cumulative_prod) = [],
        cumulative_sum: doc_argument(list[Cumulative], func=pc.cumulative_sum) = [],
        fill_null_backward: doc_argument(list[Field], func=pc.fill_null_backward) = [],
        fill_null_forward: doc_argument(list[Field], func=pc.fill_null_forward) = [],
        pairwise_diff: doc_argument(list[Pairwise], func=pc.pairwise_diff) = [],
        rank: doc_argument(list[Rank], func=pc.rank) = [],
        list_: Annotated[
            ListFunction,
            strawberry.argument(name='list', description="functions for list arrays."),
        ] = {},
    ) -> Self:
        """Return view of table with vector functions applied across columns.

        Applied functions load arrays into memory as needed. See `scan` for scalar functions,
        which do not require loading.
        """
        table = T.map_batch(self.select(info), self.apply_list, list_)
        self.add_metric(info, table, mode='batch')
        columns = {}
        funcs = pc.cumulative_max, pc.cumulative_mean, pc.cumulative_min, pc.cumulative_prod
        funcs += pc.cumulative_sum, C.fill_null_backward, C.fill_null_forward, C.pairwise_diff
        funcs += (pc.rank,)
        for func in funcs:
            for field in locals()[func.__name__]:
                callable = func
                if field.options.pop('checked', False):
                    callable = getattr(pc, func.__name__ + '_checked')
                columns[field.alias] = callable(table[field.name], **field.options)
        return type(self)(T.union(table, pa.table(columns)))

    @doc_field
    def flatten(self, info: Info, indices: str = '') -> Self:
        """Return table with list arrays flattened.

        At least one list column must be referenced, and all list columns must have the same lengths.
        """
        table = pa.Table.from_batches(T.flatten(self.select(info), indices))
        return type(self)(self.add_metric(info, table, mode='batch'))

    @doc_field
    def tables(self, info: Info) -> list[Optional[Self]]:  # type: ignore
        """Return a list of tables by splitting list columns.

        At least one list column must be referenced, and all list columns must have the same lengths.
        """
        for batch in self.select(info).to_batches():
            for row in T.split(batch):
                yield None if row is None else type(self)(pa.Table.from_batches([row]))

    @doc_field
    def aggregate(
        self,
        info: Info,
        approximate_median: doc_argument(list[ScalarAggregate], func=pc.approximate_median) = [],
        count: doc_argument(list[CountAggregate], func=pc.count) = [],
        count_distinct: doc_argument(list[CountAggregate], func=pc.count_distinct) = [],
        distinct: Annotated[
            list[CountAggregate],
            strawberry.argument(description="distinct values within each scalar"),
        ] = [],
        first: doc_argument(list[Field], func=ListChunk.first) = [],
        last: doc_argument(list[Field], func=ListChunk.last) = [],
        max: doc_argument(list[ScalarAggregate], func=pc.max) = [],
        mean: doc_argument(list[ScalarAggregate], func=pc.mean) = [],
        min: doc_argument(list[ScalarAggregate], func=pc.min) = [],
        product: doc_argument(list[ScalarAggregate], func=pc.product) = [],
        stddev: doc_argument(list[VarianceAggregate], func=pc.stddev) = [],
        sum: doc_argument(list[ScalarAggregate], func=pc.sum) = [],
        tdigest: doc_argument(list[TDigestAggregate], func=pc.tdigest) = [],
        variance: doc_argument(list[VarianceAggregate], func=pc.variance) = [],
    ) -> Self:
        """Return table with scalar aggregate functions applied to list columns."""
        table = self.to_table(info)
        columns = T.columns(table)
        agg_fields: dict = collections.defaultdict(dict)
        keys: tuple = 'approximate_median', 'count', 'count_distinct', 'distinct', 'first', 'last'
        keys += 'max', 'mean', 'min', 'product', 'stddev', 'sum', 'tdigest', 'variance'
        for key in keys:
            func = getattr(ListChunk, key, None)
            for agg in locals()[key]:
                if func is None or key == 'sum':  # `sum` is a method on `Array`
                    agg_fields[agg.name][key] = agg
                else:
                    columns[agg.alias] = func(table[agg.name], **agg.options)
        for name, aggs in agg_fields.items():
            funcs = {key: agg.func_options(key) for key, agg in aggs.items()}
            batch = ListChunk.aggregate(table[name], **funcs)
            columns.update(zip([agg.alias for agg in aggs.values()], batch))
        return type(self)(pa.table(columns))

    aggregate.deprecation_reason = ListFunction.deprecation

    @doc_field(filter="selected rows", columns="projected columns")
    def scan(self, info: Info, filter: Expression = {}, columns: list[Projection] = []) -> Self:  # type: ignore
        """Select rows and project columns without memory usage."""
        expr = filter.to_arrow()
        projection = {name: pc.field(name) for name in self.references(info, level=1)}
        projection |= {col.alias or '.'.join(col.name): col.to_arrow() for col in columns}
        if '' in projection:
            raise ValueError(f"projected columns need a name or alias: {projection['']}")
        if isinstance(self.source, ds.Scanner):
            options = dict(schema=self.source.projected_schema, filter=expr, columns=projection)
            scanner = ds.Scanner.from_batches(self.source.to_batches(), **options)
            return type(self)(self.add_metric(info, scanner.to_table(), mode='batch'))
        source = self.source if expr is None else self.source.filter(expr)
        return type(self)(Nodes.scan(source, projection) if columns else source)

    @doc_field(
        right="name of right table; must be on root Query type",
        keys="column names used as keys on the left side",
        right_keys="column names used as keys on the right side; defaults to left side.",
        join_type="the kind of join: 'left semi', 'right semi', 'left anti', 'right anti', 'inner', 'left outer', 'right outer', 'full outer'",
        left_suffix="add suffix to left column names; for preventing collisions",
        right_suffix="add suffix to right column names; for preventing collisions.",
        coalesce_keys="omit duplicate keys",
    )
    def join(
        self,
        info: Info,
        right: str,
        keys: list[str],
        right_keys: Optional[list[str]] = None,
        join_type: str = 'left outer',
        left_suffix: str = '',
        right_suffix: str = '',
        coalesce_keys: bool = True,
    ) -> Self:
        """Provisional: [join](https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Dataset.html#pyarrow.dataset.Dataset.join) this table with another table on the root Query type."""
        left, right = (
            root.source if isinstance(root.source, ds.Dataset) else root.to_table(info)
            for root in (self, getattr(info.root_value, right))
        )
        table = left.join(
            right,
            keys=keys,
            right_keys=right_keys,
            join_type=join_type,
            left_suffix=left_suffix,
            right_suffix=right_suffix,
            coalesce_keys=coalesce_keys,
        )
        return type(self)(table)

    join.directives = [provisional()]

    @doc_field
    def take(self, info: Info, indices: list[Long]) -> Self:
        """Select rows from indices."""
        table = self.select(info).take(indices)
        return type(self)(self.add_metric(info, table, mode='take'))

    @doc_field
    def drop_null(self, info: Info) -> Self:
        """Remove missing values from referenced columns in the table."""
        if isinstance(self.source, pa.Table):
            return type(self)(pc.drop_null(self.to_table(info)))
        table = T.map_batch(self.select(info), pc.drop_null)
        return type(self)(self.add_metric(info, table, mode='batch'))

add_metric(info, table, **data) staticmethod

Add memory usage and other metrics to context with path info.

Source code in graphique/interface.py
@staticmethod
def add_metric(info: Info, table: pa.Table, **data):
    """Add memory usage and other metrics to context with path info."""
    path = tuple(get_path_from_info(info))
    info.context.setdefault('metrics', {})[path] = dict(data, memory=T.size(table))
    return table

aggregate(info, approximate_median=[], count=[], count_distinct=[], distinct=[], first=[], last=[], max=[], mean=[], min=[], product=[], stddev=[], sum=[], tdigest=[], variance=[])

Return table with scalar aggregate functions applied to list columns.

Source code in graphique/interface.py
@doc_field
def aggregate(
    self,
    info: Info,
    approximate_median: doc_argument(list[ScalarAggregate], func=pc.approximate_median) = [],
    count: doc_argument(list[CountAggregate], func=pc.count) = [],
    count_distinct: doc_argument(list[CountAggregate], func=pc.count_distinct) = [],
    distinct: Annotated[
        list[CountAggregate],
        strawberry.argument(description="distinct values within each scalar"),
    ] = [],
    first: doc_argument(list[Field], func=ListChunk.first) = [],
    last: doc_argument(list[Field], func=ListChunk.last) = [],
    max: doc_argument(list[ScalarAggregate], func=pc.max) = [],
    mean: doc_argument(list[ScalarAggregate], func=pc.mean) = [],
    min: doc_argument(list[ScalarAggregate], func=pc.min) = [],
    product: doc_argument(list[ScalarAggregate], func=pc.product) = [],
    stddev: doc_argument(list[VarianceAggregate], func=pc.stddev) = [],
    sum: doc_argument(list[ScalarAggregate], func=pc.sum) = [],
    tdigest: doc_argument(list[TDigestAggregate], func=pc.tdigest) = [],
    variance: doc_argument(list[VarianceAggregate], func=pc.variance) = [],
) -> Self:
    """Return table with scalar aggregate functions applied to list columns."""
    table = self.to_table(info)
    columns = T.columns(table)
    agg_fields: dict = collections.defaultdict(dict)
    keys: tuple = 'approximate_median', 'count', 'count_distinct', 'distinct', 'first', 'last'
    keys += 'max', 'mean', 'min', 'product', 'stddev', 'sum', 'tdigest', 'variance'
    for key in keys:
        func = getattr(ListChunk, key, None)
        for agg in locals()[key]:
            if func is None or key == 'sum':  # `sum` is a method on `Array`
                agg_fields[agg.name][key] = agg
            else:
                columns[agg.alias] = func(table[agg.name], **agg.options)
    for name, aggs in agg_fields.items():
        funcs = {key: agg.func_options(key) for key, agg in aggs.items()}
        batch = ListChunk.aggregate(table[name], **funcs)
        columns.update(zip([agg.alias for agg in aggs.values()], batch))
    return type(self)(pa.table(columns))

any(info, length=1)

Return whether there are at least length rows.

May be significantly faster than length for out-of-core data.

Source code in graphique/interface.py
@doc_field
def any(self, info: Info, length: Long = 1) -> bool:
    """Return whether there are at least `length` rows.

    May be significantly faster than `length` for out-of-core data.
    """
    table = self.to_table(info, length)
    return len(table) >= length
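
A usage sketch: checking whether a dataset has at least 1,000 rows without counting them all, as the query might be posted to the GraphQL endpoint.

# Hedged example query against the generated schema.
query = """
{
  any(length: 1000)
}
"""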

apply(info, cumulative_max=[], cumulative_mean=[], cumulative_min=[], cumulative_prod=[], cumulative_sum=[], fill_null_backward=[], fill_null_forward=[], pairwise_diff=[], rank=[], list_={})

Return view of table with vector functions applied across columns.

Applied functions load arrays into memory as needed. See scan for scalar functions, which do not require loading.

Source code in graphique/interface.py
@doc_field
@no_type_check
def apply(
    self,
    info: Info,
    cumulative_max: doc_argument(list[Cumulative], func=pc.cumulative_max) = [],
    cumulative_mean: doc_argument(list[Cumulative], func=pc.cumulative_mean) = [],
    cumulative_min: doc_argument(list[Cumulative], func=pc.cumulative_min) = [],
    cumulative_prod: doc_argument(list[Cumulative], func=pc.cumulative_prod) = [],
    cumulative_sum: doc_argument(list[Cumulative], func=pc.cumulative_sum) = [],
    fill_null_backward: doc_argument(list[Field], func=pc.fill_null_backward) = [],
    fill_null_forward: doc_argument(list[Field], func=pc.fill_null_forward) = [],
    pairwise_diff: doc_argument(list[Pairwise], func=pc.pairwise_diff) = [],
    rank: doc_argument(list[Rank], func=pc.rank) = [],
    list_: Annotated[
        ListFunction,
        strawberry.argument(name='list', description="functions for list arrays."),
    ] = {},
) -> Self:
    """Return view of table with vector functions applied across columns.

    Applied functions load arrays into memory as needed. See `scan` for scalar functions,
    which do not require loading.
    """
    table = T.map_batch(self.select(info), self.apply_list, list_)
    self.add_metric(info, table, mode='batch')
    columns = {}
    funcs = pc.cumulative_max, pc.cumulative_mean, pc.cumulative_min, pc.cumulative_prod
    funcs += pc.cumulative_sum, C.fill_null_backward, C.fill_null_forward, C.pairwise_diff
    funcs += (pc.rank,)
    for func in funcs:
        for field in locals()[func.__name__]:
            callable = func
            if field.options.pop('checked', False):
                callable = getattr(pc, func.__name__ + '_checked')
            columns[field.alias] = callable(table[field.name], **field.options)
    return type(self)(T.union(table, pa.table(columns)))
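
A hedged sketch of an apply query; the column name `value` is an illustrative assumption, and the camel-cased argument name `cumulativeSum` assumes the usual GraphQL naming of the `cumulative_sum` argument.

# Assumed: a numeric column named `value`; the running sum is added as a new column.
query = """
{
  apply(cumulativeSum: [{name: "value"}]) {
    schema { names }
  }
}
"""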

column(info, name, cast='', safe=True)

Return column of any type by name.

This is typically only needed for aliased or casted columns. If the column is in the schema, columns can be used instead.

Source code in graphique/interface.py
@doc_field(
    name="column name(s); multiple names access nested struct fields",
    cast=f"cast array to {links.type}",
    safe="check for conversion errors on cast",
)
def column(
    self, info: Info, name: list[str], cast: str = '', safe: bool = True
) -> Optional[Column]:
    """Return column of any type by name.

    This is typically only needed for aliased or casted columns.
    If the column is in the schema, `columns` can be used instead.
    """
    if isinstance(self.source, pa.Table) and len(name) == 1:
        column = self.source.column(*name)
        return Column.cast(column.cast(cast, safe) if cast else column)
    column = Projection(alias='_', name=name, cast=cast, safe=safe)  # type: ignore
    source = self.scan(info, Expression(), [column]).source
    return Column.cast(*(source if isinstance(source, pa.Table) else source.to_table()))

columns(info)

fields for each column

Source code in graphique/interface.py
def columns(self, info: Info) -> dict:
    """fields for each column"""
    table = self.to_table(info)
    return {name: Column.cast(table[name]) for name in table.schema.names}

drop_null(info)

Remove missing values from referenced columns in the table.

Source code in graphique/interface.py
@doc_field
def drop_null(self, info: Info) -> Self:
    """Remove missing values from referenced columns in the table."""
    if isinstance(self.source, pa.Table):
        return type(self)(pc.drop_null(self.to_table(info)))
    table = T.map_batch(self.select(info), pc.drop_null)
    return type(self)(self.add_metric(info, table, mode='batch'))

filter(info, **queries)

Return table with rows which match all queries.

See scan(filter: ...) for more advanced queries. Additional feature: sorted tables support binary search

Source code in graphique/interface.py
def filter(self, info: Info, **queries: Filter) -> Self:
    """Return table with rows which match all queries.

    See `scan(filter: ...)` for more advanced queries. Additional feature: sorted tables
    support binary search
    """
    source = self.source
    prev = info.path.prev
    search = isinstance(source, pa.Table) and (prev is None or prev.typename == 'Query')
    for name in self.schema().index if search else []:
        assert not source[name].null_count, f"search requires non-null column: {name}"
        query = dict(queries.pop(name))
        if 'eq' in query:
            source = T.is_in(source, name, *query['eq'])
        if 'ne' in query:
            source = T.not_equal(source, name, query['ne'])
        lower, upper = query.get('gt'), query.get('lt')
        includes = {'include_lower': False, 'include_upper': False}
        if 'ge' in query and (lower is None or query['ge'] > lower):
            lower, includes['include_lower'] = query['ge'], True
        if 'le' in query and (upper is None or query['le'] > upper):
            upper, includes['include_upper'] = query['le'], True
        if {lower, upper} != {None}:
            source = T.range(source, name, lower, upper, **includes)
        if len(query.pop('eq', [])) != 1 or query:
            break
    return type(self)(source).scan(info, filter=Expression.from_query(**queries))
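
A usage sketch, assuming a dataset with a `state` column (each column of the dataset becomes a Filter argument):

# Hedged example: rows where `state` equals "CA", then count them.
query = """
{
  filter(state: {eq: "CA"}) {
    length
  }
}
"""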

flatten(info, indices='')

Return table with list arrays flattened.

At least one list column must be referenced, and all list columns must have the same lengths.

Source code in graphique/interface.py
@doc_field
def flatten(self, info: Info, indices: str = '') -> Self:
    """Return table with list arrays flattened.

    At least one list column must be referenced, and all list columns must have the same lengths.
    """
    table = pa.Table.from_batches(T.flatten(self.select(info), indices))
    return type(self)(self.add_metric(info, table, mode='batch'))

group(info, by=[], counts='', ordered=False, aggregate={})

Return table grouped by columns.

See column for accessing any column which has changed type. See tables to split on any aggregated list columns.

Source code in graphique/interface.py
@doc_field(
    by="column names; empty will aggregate into a single row table",
    counts="optionally include counts in an aliased column",
    ordered="optionally disable parallelization to maintain ordering",
    aggregate="aggregation functions applied to other columns",
)
def group(
    self,
    info: Info,
    by: list[str] = [],
    counts: str = '',
    ordered: bool = False,
    aggregate: HashAggregates = {},  # type: ignore
) -> Self:
    """Return table grouped by columns.

    See `column` for accessing any column which has changed type. See `tables` to split on any
    aggregated list columns.
    """
    if not any(aggregate.keys()):
        fragments = T.fragments(self.source, *by, counts=counts)
        if set(fragments.schema.names) >= set(by):
            return type(self)(fragments)
    prefix = 'hash_' if by else ''
    aggs: dict = {counts: ([], prefix + 'count_all', None)} if counts else {}
    for func, values in dict(aggregate).items():
        ordered = ordered or func in Agg.ordered
        for agg in values:
            aggs[agg.alias] = (agg.name, prefix + func, agg.func_options(func))
    source = self.to_table(info) if isinstance(self.source, ds.Scanner) else self.source
    source = Nodes.group(source, *by, **aggs)
    if ordered:
        source = self.add_metric(info, source.to_table(use_threads=False), mode='group')
    return type(self)(source)
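
A hedged sketch of a grouping query; the `state` and `population` columns are assumptions, and the `aggregate` input mirrors the HashAggregates type.

# Assumed columns: group by `state`, sum `population`, include per-group counts.
query = """
{
  group(by: ["state"], counts: "count", aggregate: {sum: [{name: "population"}]}) {
    schema { names }
  }
}
"""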

join(info, right, keys, right_keys=None, join_type='left outer', left_suffix='', right_suffix='', coalesce_keys=True)

Provisional: join this table with another table on the root Query type.

Source code in graphique/interface.py
@doc_field(
    right="name of right table; must be on root Query type",
    keys="column names used as keys on the left side",
    right_keys="column names used as keys on the right side; defaults to left side.",
    join_type="the kind of join: 'left semi', 'right semi', 'left anti', 'right anti', 'inner', 'left outer', 'right outer', 'full outer'",
    left_suffix="add suffix to left column names; for preventing collisions",
    right_suffix="add suffix to right column names; for preventing collisions.",
    coalesce_keys="omit duplicate keys",
)
def join(
    self,
    info: Info,
    right: str,
    keys: list[str],
    right_keys: Optional[list[str]] = None,
    join_type: str = 'left outer',
    left_suffix: str = '',
    right_suffix: str = '',
    coalesce_keys: bool = True,
) -> Self:
    """Provisional: [join](https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Dataset.html#pyarrow.dataset.Dataset.join) this table with another table on the root Query type."""
    left, right = (
        root.source if isinstance(root.source, ds.Dataset) else root.to_table(info)
        for root in (self, getattr(info.root_value, right))
    )
    table = left.join(
        right,
        keys=keys,
        right_keys=right_keys,
        join_type=join_type,
        left_suffix=left_suffix,
        right_suffix=right_suffix,
        coalesce_keys=coalesce_keys,
    )
    return type(self)(table)
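
A hedged sketch of a join query; `other` and `id` are assumed names of another root table and a shared key column.

# Assumed: a root field `other` sharing an `id` column with this table.
query = """
{
  join(right: "other", keys: ["id"]) {
    length
  }
}
"""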

length()

number of rows

Source code in graphique/interface.py
@doc_field
def length(self) -> Long:
    """number of rows"""
    return len(self.source) if isinstance(self.source, Sized) else self.source.count_rows()

optional()

Nullable field to stop error propagation, enabling partial query results.

Will be replaced by client controlled nullability.

Source code in graphique/interface.py
@doc_field
def optional(self) -> Optional[Self]:
    """Nullable field to stop error propagation, enabling partial query results.

    Will be replaced by client controlled nullability.
    """
    return self

rank(info, by, max=1)

Return table selected by maximum dense rank.

Source code in graphique/interface.py
@doc_field(
    by="column names; prefix with `-` for descending order",
    max="maximum dense rank to select; optimized for == 1 (min or max)",
)
def rank(self, info: Info, by: list[str], max: int = 1) -> Self:
    """Return table selected by maximum dense rank."""
    source = self.to_table(info) if isinstance(self.source, ds.Scanner) else self.source
    expr, by = T.rank_keys(source, max, *by)
    if expr is not None:
        source = source.filter(expr)
    return type(self)(T.rank(source, max, *by) if by else source)
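
A hedged sketch: selecting all rows tied for the top three dense ranks of an assumed `score` column.

query = """
{
  rank(by: ["-score"], max: 3) {
    length
  }
}
"""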

references(info, level=0)

Return set of every possible future column reference.

Source code in graphique/interface.py
def references(self, info: Info, level: int = 0) -> set:
    """Return set of every possible future column reference."""
    fields = info.selected_fields
    for _ in range(level):
        fields = itertools.chain(*[field.selections for field in fields])
    return set(itertools.chain(*map(references, fields))) & set(self.schema().names)

resolve_reference(info, **keys) classmethod

Return table from federated keys.

Source code in graphique/interface.py
@classmethod
@no_type_check
def resolve_reference(cls, info: Info, **keys) -> Self:
    """Return table from federated keys."""
    self = getattr(info.root_value, cls.field)
    queries = {name: Filter(eq=[keys[name]]) for name in keys}
    return self.filter(info, **queries)

row(info, index=0)

Return scalar values at index.

Source code in graphique/interface.py
def row(self, info: Info, index: int = 0) -> dict:
    """Return scalar values at index."""
    table = self.to_table(info, index + 1 if index >= 0 else None)
    row = {}
    for name in table.schema.names:
        scalar = table[name][index]
        columnar = isinstance(scalar, pa.ListScalar)
        row[name] = Column.fromscalar(scalar) if columnar else scalar.as_py()
    return row

runs(info, by=[], split=[], counts='')

Return table grouped by pairwise differences.

Differs from group by relying on adjacency, and is typically faster. Other columns are transformed into list columns. See column and tables to further access lists.

Source code in graphique/interface.py
@doc_field(
    by="column names",
    split="optional predicates to split on; scalars are compared to pairwise difference",
    counts="optionally include counts in an aliased column",
)
@no_type_check
def runs(
    self, info: Info, by: list[str] = [], split: list[Diff] = [], counts: str = ''
) -> Self:
    """Return table grouped by pairwise differences.

    Differs from `group` by relying on adjacency, and is typically faster. Other columns are
    transformed into list columns. See `column` and `tables` to further access lists.
    """
    table = self.to_table(info)
    predicates = {}
    for diff in map(dict, split):
        name = diff.pop('name')
        ((func, value),) = diff.items()
        if pa.types.is_timestamp(table.field(name).type):
            value = timedelta(seconds=value)
        predicates[name] = (getattr(pc, func), value)[: 1 if value is None else 2]
    table, counts_ = T.runs(table, *by, **predicates)
    return type(self)(table.append_column(counts, counts_) if counts else table)

scan(info, filter={}, columns=[])

Select rows and project columns without memory usage.

Source code in graphique/interface.py
@doc_field(filter="selected rows", columns="projected columns")
def scan(self, info: Info, filter: Expression = {}, columns: list[Projection] = []) -> Self:  # type: ignore
    """Select rows and project columns without memory usage."""
    expr = filter.to_arrow()
    projection = {name: pc.field(name) for name in self.references(info, level=1)}
    projection |= {col.alias or '.'.join(col.name): col.to_arrow() for col in columns}
    if '' in projection:
        raise ValueError(f"projected columns need a name or alias: {projection['']}")
    if isinstance(self.source, ds.Scanner):
        options = dict(schema=self.source.projected_schema, filter=expr, columns=projection)
        scanner = ds.Scanner.from_batches(self.source.to_batches(), **options)
        return type(self)(self.add_metric(info, scanner.to_table(), mode='batch'))
    source = self.source if expr is None else self.source.filter(expr)
    return type(self)(Nodes.scan(source, projection) if columns else source)

schema()

dataset schema

Source code in graphique/interface.py
@doc_field
def schema(self) -> Schema:
    """dataset schema"""
    source = self.source
    schema = source.projected_schema if isinstance(source, ds.Scanner) else source.schema
    partitioning = getattr(source, 'partitioning', None)
    index = (schema.pandas_metadata or {}).get('index_columns', [])
    return Schema(
        names=schema.names,
        types=schema.types,
        partitioning=partitioning.schema.names if partitioning else [],
        index=[name for name in index if isinstance(name, str)],
    )  # type: ignore
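
For example, the schema metadata can be introspected directly; the selected fields follow the Schema type above.

query = """
{
  schema {
    names
    types
    partitioning
    index
  }
}
"""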

select(info)

Return source with only the columns necessary to proceed.

Source code in graphique/interface.py
def select(self, info: Info) -> Source:
    """Return source with only the columns necessary to proceed."""
    names = list(self.references(info))
    if len(names) >= len(self.schema().names):
        return self.source
    if isinstance(self.source, ds.Scanner):
        schema = self.source.projected_schema
        return ds.Scanner.from_batches(self.source.to_batches(), schema=schema, columns=names)
    if isinstance(self.source, pa.Table):
        return self.source.select(names)
    return Nodes.scan(self.source, names)

size()

buffer size in bytes; null if table is not loaded

Source code in graphique/interface.py
@doc_field
def size(self) -> Optional[Long]:
    """buffer size in bytes; null if table is not loaded"""
    return getattr(self.source, 'nbytes', None)

slice(info, offset=0, length=None, reverse=False)

Return zero-copy slice of table.

Can also be used to force loading a dataset.

Source code in graphique/interface.py
@doc_field(
    offset="number of rows to skip; negative value skips from the end",
    length="maximum number of rows to return",
    reverse="reverse order after slicing; forces a copy",
)
def slice(
    self, info: Info, offset: Long = 0, length: Optional[Long] = None, reverse: bool = False
) -> Self:
    """Return zero-copy slice of table.

    Can also be used to force loading a dataset.
    """
    table = self.to_table(info, length and (offset + length if offset >= 0 else None))
    table = table[offset:][:length]  # `slice` bug: ARROW-15412
    return type(self)(table[::-1] if reverse else table)
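
A usage sketch: rows 100-109 in reverse order.

query = """
{
  slice(offset: 100, length: 10, reverse: true) {
    length
  }
}
"""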

sort(info, by, length=None, null_placement='at_end')

Return table slice sorted by specified columns.

Optimized for length == 1; matches min or max values.

Source code in graphique/interface.py
@doc_field(
    by="column names; prefix with `-` for descending order",
    length="maximum number of rows to return; may be significantly faster but is unstable",
    null_placement="where nulls in input should be sorted; incompatible with `length`",
)
def sort(
    self,
    info: Info,
    by: list[str],
    length: Optional[Long] = None,
    null_placement: str = 'at_end',
) -> Self:
    """Return table slice sorted by specified columns.

    Optimized for length == 1; matches min or max values.
    """
    kwargs = dict(length=length, null_placement=null_placement)
    if isinstance(self.source, pa.Table) or length is None:
        table = self.to_table(info)
    else:
        expr, by = T.rank_keys(self.source, length, *by, dense=False)
        if expr is not None:
            self = type(self)(self.source.filter(expr))
        source = self.select(info)
        if not by:
            return type(self)(self.add_metric(info, source.head(length), mode='head'))
        table = T.map_batch(source, T.sort, *by, **kwargs)
        self.add_metric(info, table, mode='batch')
    return type(self)(T.sort(table, *by, **kwargs))  # type: ignore
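
A hedged sketch: the ten most recent rows by an assumed `timestamp` column, using the descending-order prefix.

query = """
{
  sort(by: ["-timestamp"], length: 10) {
    length
  }
}
"""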

tables(info)

Return a list of tables by splitting list columns.

At least one list column must be referenced, and all list columns must have the same lengths.

Source code in graphique/interface.py
@doc_field
def tables(self, info: Info) -> list[Optional[Self]]:  # type: ignore
    """Return a list of tables by splitting list columns.

    At least one list column must be referenced, and all list columns must have the same lengths.
    """
    for batch in self.select(info).to_batches():
        for row in T.split(batch):
            yield None if row is None else type(self)(pa.Table.from_batches([row]))

take(info, indices)

Select rows from indices.

Source code in graphique/interface.py
@doc_field
def take(self, info: Info, indices: list[Long]) -> Self:
    """Select rows from indices."""
    table = self.select(info).take(indices)
    return type(self)(self.add_metric(info, table, mode='take'))
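
A usage sketch: selecting specific rows by index.

query = """
{
  take(indices: [0, 2, 5]) {
    length
  }
}
"""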

to_table(info, length=None)

Return table with only the rows and columns necessary to proceed.

Source code in graphique/interface.py
def to_table(self, info: Info, length: Optional[int] = None) -> pa.Table:
    """Return table with only the rows and columns necessary to proceed."""
    source = self.select(info)
    if isinstance(source, pa.Table):
        return source
    if length is None:
        return self.add_metric(info, source.to_table(), mode='read')
    return self.add_metric(info, source.head(length), mode='head')

type()

arrow type

Source code in graphique/interface.py
@doc_field
def type(self) -> str:
    """[arrow type](https://arrow.apache.org/docs/python/api/dataset.html#classes)"""
    return type(self.source).__name__

graphique.middleware.GraphQL

Bases: GraphQL

ASGI GraphQL app with root value(s).

Parameters:

root (Source): root dataset to attach as the Query type; required
debug (bool): enable timing extension; default False
**kwargs: additional asgi.GraphQL options; default {}
Source code in graphique/middleware.py
class GraphQL(strawberry.asgi.GraphQL):
    """ASGI GraphQL app with root value(s).

    Args:
        root: root dataset to attach as the Query type
        debug: enable timing extension
        **kwargs: additional `asgi.GraphQL` options
    """

    options = dict(types=Column.registry.values(), scalar_overrides=scalar_map)

    def __init__(self, root: Source, debug: bool = False, **kwargs):
        options: dict = dict(self.options, extensions=(MetricsExtension,) * bool(debug))
        if type(root).__name__ == 'Query':
            self.root_value = root
            options['enable_federation_2'] = True
            schema = strawberry.federation.Schema(type(self.root_value), **options)
        else:
            self.root_value = implemented(root)
            schema = strawberry.Schema(type(self.root_value), **options)
        super().__init__(schema, debug=debug, **kwargs)

    async def get_root_value(self, request):
        return self.root_value

    @classmethod
    def federated(cls, roots: Mapping[str, Source], keys: Mapping[str, Iterable] = {}, **kwargs):
        """Construct GraphQL app with multiple federated datasets.

        Args:
            roots: mapping of field names to root datasets
            keys: mapping of optional federation keys for each root
            **kwargs: additional `asgi.GraphQL` options
        """
        root_values = {name: implemented(roots[name], name, keys.get(name, ())) for name in roots}
        annotations = {name: type(root_values[name]) for name in root_values}
        Query = type('Query', (), {'__annotations__': annotations})
        return cls(strawberry.type(Query)(**root_values), **kwargs)
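
A minimal sketch of serving a dataset as an ASGI app; the parquet path and partitioning are illustrative assumptions.

# Hedged example: `data/` is a hypothetical partitioned parquet directory.
import pyarrow.dataset as ds
from graphique.middleware import GraphQL

app = GraphQL(ds.dataset('data/', partitioning='hive'), debug=True)
# run with, e.g.: uvicorn example:app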

federated(roots, keys={}, **kwargs) classmethod

Construct GraphQL app with multiple federated datasets.

Parameters:

roots (Mapping[str, Source]): mapping of field names to root datasets; required
keys (Mapping[str, Iterable]): mapping of optional federation keys for each root; default {}
**kwargs: additional asgi.GraphQL options; default {}
Source code in graphique/middleware.py
@classmethod
def federated(cls, roots: Mapping[str, Source], keys: Mapping[str, Iterable] = {}, **kwargs):
    """Construct GraphQL app with multiple federated datasets.

    Args:
        roots: mapping of field names to root datasets
        keys: mapping of optional federation keys for each root
        **kwargs: additional `asgi.GraphQL` options
    """
    root_values = {name: implemented(roots[name], name, keys.get(name, ())) for name in roots}
    annotations = {name: type(root_values[name]) for name in root_values}
    Query = type('Query', (), {'__annotations__': annotations})
    return cls(strawberry.type(Query)(**root_values), **kwargs)
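
A hedged sketch of federating two datasets under one Query type; the field names, file paths, and key column are assumptions.

import pyarrow.dataset as ds
from graphique.middleware import GraphQL

# Each mapping key becomes a root field; `users` is federated on its `id` column.
app = GraphQL.federated(
    {'events': ds.dataset('events.parquet'), 'users': ds.dataset('users.parquet')},
    keys={'users': ['id']},
)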