136 lines
4.4 KiB
C++
136 lines
4.4 KiB
C++
// Licensed to the Apache Software Foundation (ASF) under one
|
|
// or more contributor license agreements. See the NOTICE file
|
|
// distributed with this work for additional information
|
|
// regarding copyright ownership. The ASF licenses this file
|
|
// to you under the Apache License, Version 2.0 (the
|
|
// "License"); you may not use this file except in compliance
|
|
// with the License. You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing,
|
|
// software distributed under the License is distributed on an
|
|
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
// KIND, either express or implied. See the License for the
|
|
// specific language governing permissions and limitations
|
|
// under the License.
|
|
|
|
#include "ipc.h"
|
|
|
|
#include <memory>
|
|
|
|
#include "arrow/compute/cast.h"
|
|
#include "arrow/python/pyarrow.h"
|
|
|
|
namespace arrow {
|
|
namespace py {
|
|
|
|
// Constructs an empty reader; schema_ and iterator_ are populated by Init(),
// which Make() invokes before handing the reader out.
PyRecordBatchReader::PyRecordBatchReader() = default;
|
|
|
|
// Stores the declared schema and obtains an iterator over `iterable`.
// Returns the pending Python error (if any) raised by PyObject_GetIter, e.g.
// when `iterable` is not iterable.
Status PyRecordBatchReader::Init(std::shared_ptr<Schema> schema, PyObject* iterable) {
  // PyObject_GetIter yields a new reference, or NULL with a Python error set;
  // iterator_ (presumably an OwnedRef) takes over that reference.
  PyObject* py_iter = PyObject_GetIter(iterable);
  iterator_.reset(py_iter);
  schema_ = std::move(schema);
  return CheckPyError();
}
|
|
|
|
// Returns the schema supplied to Init().
std::shared_ptr<Schema> PyRecordBatchReader::schema() const {
  return schema_;
}
|
|
|
|
// Pulls the next object from the wrapped Python iterator and unwraps it into
// a RecordBatch. A null *batch signals end of stream. Once the iterator is
// exhausted it is released, so later calls report end of stream immediately.
Status PyRecordBatchReader::ReadNext(std::shared_ptr<RecordBatch>* batch) {
  // Every CPython call below, including the decref performed by
  // iterator_.reset(), must run with the GIL held.
  PyAcquireGIL lock;

  if (iterator_) {
    OwnedRef next_item(PyIter_Next(iterator_.obj()));
    if (next_item) {
      return unwrap_batch(next_item.obj()).Value(batch);
    }
    // PyIter_Next returns NULL both on exhaustion and on error; surface the
    // error if one is pending, otherwise treat it as end of stream.
    RETURN_IF_PYERROR();
    iterator_.reset();
  }

  // End of stream
  batch->reset();
  return Status::OK();
}
|
|
|
|
// Factory: builds a PyRecordBatchReader over `iterable` with the declared
// `schema`, propagating any error raised while obtaining the iterator.
Result<std::shared_ptr<RecordBatchReader>> PyRecordBatchReader::Make(
    std::shared_ptr<Schema> schema, PyObject* iterable) {
  // Plain `new` rather than std::make_shared, presumably because the
  // constructor is not public — confirm against the header.
  std::shared_ptr<PyRecordBatchReader> result(new PyRecordBatchReader());
  ARROW_RETURN_NOT_OK(result->Init(std::move(schema), iterable));
  return result;
}
|
|
|
|
// Constructs an uninitialized reader; parent_ and schema_ are set by Init(),
// which Make() calls before returning the reader.
CastingRecordBatchReader::CastingRecordBatchReader() = default;
|
|
|
|
// Validates that `parent`'s schema can be converted, field by field, into
// `schema`, and on success stores both for use by ReadNext(). Fields are
// matched by position: the field counts must agree and each source type must
// either equal or be castable to the target type. On any failure parent_ and
// schema_ are left untouched.
//
// The check for names has already been done in Python where it's easier to
// generate a nice error message.
Status CastingRecordBatchReader::Init(std::shared_ptr<RecordBatchReader> parent,
                                      std::shared_ptr<Schema> schema) {
  // Reject null inputs up front rather than dereferencing them below.
  if (!parent) {
    return Status::Invalid("Parent RecordBatchReader must not be null");
  }
  if (!schema) {
    return Status::Invalid("Target schema must not be null");
  }

  const std::shared_ptr<Schema> src = parent->schema();

  const int num_fields = schema->num_fields();
  if (src->num_fields() != num_fields) {
    return Status::Invalid("Number of fields not equal");
  }

  // Ensure all columns can be cast before succeeding
  for (int i = 0; i < num_fields; i++) {
    const auto& src_type = src->field(i)->type();
    const auto& schema_type = schema->field(i)->type();
    if (!src_type->Equals(schema_type) && !compute::CanCast(*src_type, *schema_type)) {
      return Status::TypeError("Field ", i, " cannot be cast from ",
                               src_type->ToString(), " to ", schema_type->ToString());
    }
  }

  parent_ = std::move(parent);
  schema_ = std::move(schema);

  return Status::OK();
}
|
|
|
|
// Returns the target schema every produced batch is cast to.
std::shared_ptr<Schema> CastingRecordBatchReader::schema() const {
  return schema_;
}
|
|
|
|
// Reads the next batch from the parent reader and casts each column to the
// type of the corresponding (same-position) field of schema_, using safe cast
// options. A null *batch signals end of stream.
//
// Fails with Status::Invalid if:
// - the parent yields a batch whose column count differs from schema_, or
// - a column containing nulls targets a non-nullable field.
// Cast failures are propagated as returned by compute::Cast.
Status CastingRecordBatchReader::ReadNext(std::shared_ptr<RecordBatch>* batch) {
  std::shared_ptr<RecordBatch> out;
  ARROW_RETURN_NOT_OK(parent_->ReadNext(&out));
  if (!out) {
    batch->reset();
    return Status::OK();
  }

  const int num_columns = out->num_columns();
  // Init() only validated the parent's *declared* schema. A parent such as a
  // Python iterator can still yield batches of a different width, and
  // indexing schema_->field(i) past its end would be out of bounds.
  if (num_columns != schema_->num_fields()) {
    return Status::Invalid("Expected a batch with ", schema_->num_fields(),
                           " columns but the parent reader produced ", num_columns);
  }

  const auto options = compute::CastOptions::Safe();
  ArrayVector columns(num_columns);
  for (int i = 0; i < num_columns; i++) {
    const auto& target = schema_->field(i);
    // Hold the column by shared_ptr so its lifetime does not depend on the
    // batch caching boxed columns.
    const std::shared_ptr<Array> src = out->column(i);
    if (!target->nullable() && src->null_count() > 0) {
      return Status::Invalid(
          "Can't cast array that contains nulls to non-nullable field at index ", i);
    }

    ARROW_ASSIGN_OR_RAISE(columns[i], compute::Cast(*src, target->type(), options));
  }

  *batch = RecordBatch::Make(schema_, out->num_rows(), std::move(columns));
  return Status::OK();
}
|
|
|
|
// Factory: wraps `parent` in a reader that casts each batch to `schema`.
// Returns the error from Init() if the schemas are incompatible.
Result<std::shared_ptr<RecordBatchReader>> CastingRecordBatchReader::Make(
    std::shared_ptr<RecordBatchReader> parent, std::shared_ptr<Schema> schema) {
  // Plain `new` rather than std::make_shared, presumably because the
  // constructor is not public — confirm against the header.
  auto reader = std::shared_ptr<CastingRecordBatchReader>(new CastingRecordBatchReader());
  // Both parameters are by-value sinks; move them into Init() instead of
  // paying for extra atomic refcount increments/decrements.
  ARROW_RETURN_NOT_OK(reader->Init(std::move(parent), std::move(schema)));
  return reader;
}
|
|
|
|
// Closes the wrapped parent reader and returns its status unchanged.
Status CastingRecordBatchReader::Close() {
  return parent_->Close();
}
|
|
|
|
} // namespace py
|
|
} // namespace arrow
|