68 changes: 68 additions & 0 deletions exla/c_src/exla/exla.cc
@@ -14,6 +14,8 @@
#include "exla_nif_util.h"
#include "ipc.h"
#include "mlir/IR/MLIRContext.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/ir/register.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/pjrt/pjrt_api.h"
@@ -67,6 +69,9 @@ mlir_new_context(ErlNifEnv *env,
context->getOrLoadDialect<mlir::func::FuncDialect>();
context->getOrLoadDialect<mlir::stablehlo::StablehloDialect>();
context->getOrLoadDialect<mlir::chlo::ChloDialect>();
context->getOrLoadDialect<mlir::sdy::SdyDialect>();
mlir::sdy::registerAllDialects(
const_cast<mlir::DialectRegistry &>(context->getDialectRegistry()));
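// Note: sdy::registerAllDialects expects a mutable DialectRegistry, while
// MLIRContext::getDialectRegistry() returns a const reference, hence the
// const_cast above.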

return context;
}
@@ -171,6 +176,69 @@ fine::Ok<> mlir_pop_region(ErlNifEnv *env,

FINE_NIF(mlir_pop_region, 0);

fine::Ok<> mlir_add_mesh(ErlNifEnv *env, fine::ResourcePtr<MLIRModule> module,
std::string mesh_name,
std::vector<std::tuple<std::string, int64_t>> axes) {
auto builder = module->builder();
auto context = module->module()->getContext();

llvm::SmallVector<mlir::sdy::MeshAxisAttr> axis_attrs;
for (auto [name, size] : axes) {
axis_attrs.push_back(mlir::sdy::MeshAxisAttr::get(context, name, size));
}

auto mesh_attr = mlir::sdy::MeshAttr::get(context, axis_attrs);

// Create the mesh op at the beginning of the module
auto module_op = module->module();
auto &body_region = module_op.getBodyRegion();
mlir::OpBuilder::InsertionGuard guard(*builder);
builder->setInsertionPointToStart(&body_region.front());

mlir::OperationState state(builder->getUnknownLoc(), "sdy.mesh");
mlir::sdy::MeshOp::build(*builder, state, mesh_name, mesh_attr);
builder->create(state);

return fine::Ok();
}

FINE_NIF(mlir_add_mesh, 0);
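
A minimal usage sketch, not part of this diff; the mesh name and axis names below are hypothetical:

```cpp
// Declare a 2x4 mesh named "mesh" with axes "data" and "model" on an
// existing module resource (names are made up for illustration).
std::vector<std::tuple<std::string, int64_t>> axes = {{"data", 2},
                                                      {"model", 4}};
mlir_add_mesh(env, module, "mesh", axes);
// The op lands at the start of the module body and prints roughly as:
//   sdy.mesh @mesh = <["data"=2, "model"=4]>
```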

mlir::sdy::TensorShardingAttr mlir_create_tensor_sharding_attr(
mlir::MLIRContext *context, std::string mesh_name,
std::vector<std::vector<std::string>> dim_shardings) {
llvm::SmallVector<mlir::sdy::DimensionShardingAttr> dim_sharding_attrs;
for (const auto &dim : dim_shardings) {
llvm::SmallVector<mlir::sdy::AxisRefAttr> axis_refs;
for (const auto &axis : dim) {
axis_refs.push_back(mlir::sdy::AxisRefAttr::get(context, axis));
}
dim_sharding_attrs.push_back(mlir::sdy::DimensionShardingAttr::get(
context, axis_refs, /*is_closed=*/false, /*priority=*/0));
}

return mlir::sdy::TensorShardingAttr::get(
context, mesh_name, dim_sharding_attrs,
/*replicated_axes=*/llvm::ArrayRef<mlir::sdy::AxisRefAttr>(),
/*unreduced_axes=*/llvm::ArrayRef<mlir::sdy::AxisRefAttr>());
}

fine::Ok<>
mlir_set_arg_sharding(ErlNifEnv *env, fine::ResourcePtr<MLIRFunction> function,
int64_t arg_index, std::string mesh_name,
std::vector<std::vector<std::string>> dim_shardings) {

auto context = function->module()->module()->getContext();
auto sharding_attr =
mlir_create_tensor_sharding_attr(context, mesh_name, dim_shardings);

function->function().setArgAttr(arg_index, "sdy.sharding", sharding_attr);

return fine::Ok();
}

FINE_NIF(mlir_set_arg_sharding, 0);
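
A matching sketch for the sharding NIF, again with hypothetical names: shard the first dimension of argument 0 along "data" and leave the second dimension unsharded.

```cpp
// dim_shardings holds one inner list per tensor dimension; an empty list
// means the dimension is unsharded (and open, since is_closed=false above).
std::vector<std::vector<std::string>> dim_shardings = {{"data"}, {}};
mlir_set_arg_sharding(env, function, /*arg_index=*/0, "mesh", dim_shardings);
// The resulting argument attribute prints roughly as:
//   #sdy.sharding<@mesh, [{"data", ?}, {?}]>
```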

mlir::Type mlir_get_typespec(ErlNifEnv *env,
fine::ResourcePtr<mlir::Value> value) {
return value->getType();
80 changes: 56 additions & 24 deletions exla/c_src/exla/exla_client.cc
@@ -152,15 +152,24 @@ PjRtBufferFromBinary(xla::PjRtClient *client, ERL_NIF_TERM source_term,
tsl::StatusOr<std::vector<std::vector<xla::PjRtBuffer *>>> UnpackRunArguments(
ErlNifEnv *env, ExlaExecutable::RunArguments arguments,
std::vector<std::unique_ptr<xla::PjRtBuffer>> &transient_buffers,
-    ExlaClient *client, xla::DeviceAssignment device_assignment,
-    int device_id) {
+    ExlaClient *client, xla::DeviceAssignment device_assignment, int device_id,
+    int num_partitions) {
std::vector<std::vector<xla::PjRtBuffer *>> arg_buffers;
arg_buffers.reserve(arguments.size());

- int replica = 0;
+ int index = 0;

for (const auto &replica_arguments : arguments) {
- auto device = device_id >= 0 ? device_id : device_assignment(replica, 0);
+ // For automatic SPMD: each input list goes to a different partition device.
+ // device_assignment maps (replica, partition) -> device.
+ // With num_partitions > 1, we iterate through partitions
+ // (replica=0, partition=0..N-1); for replication, we iterate through
+ // replicas (replica=0..N-1, partition=0).
+ int replica = (num_partitions > 1) ? 0 : index;
+ int partition = (num_partitions > 1) ? index : 0;
+
+ auto device =
+     device_id >= 0 ? device_id : device_assignment(replica, partition);

auto replica_buffers = std::vector<xla::PjRtBuffer *>();
replica_buffers.reserve(replica_arguments.size());
@@ -200,7 +209,7 @@ tsl::StatusOr<std::vector<std::vector<xla::PjRtBuffer *>>> UnpackRunArguments(

arg_buffers.push_back(std::move(replica_buffers));

- replica++;
+ index++;
}

return arg_buffers;
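
The index-to-device mapping above reduces to a small pure function; a standalone sketch under the same assumptions (SPMD walks partitions on replica 0, replication walks replicas on partition 0):

```cpp
// Illustrative restatement of the selection logic in UnpackRunArguments.
int select_device(const xla::DeviceAssignment &assignment, int index,
                  int num_partitions, int device_id) {
  if (device_id >= 0) return device_id;  // an explicit device always wins
  int replica = (num_partitions > 1) ? 0 : index;
  int partition = (num_partitions > 1) ? index : 0;
  return assignment(replica, partition);
}
```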
@@ -216,7 +225,17 @@ UnpackResult(ErlNifEnv *env,
for (int i = 0; i < result.size(); i++) {
auto replica_results = std::vector<fine::ResourcePtr<ExlaBuffer>>();
- int64_t device = device_id >= 0 ? device_id : device_assignment(i, 0);
+ int64_t device;
+ if (device_id >= 0) {
+   device = device_id;
+ } else if (device_assignment.computation_count() > 1) {
+   // SPMD: results correspond to partitions (replica 0, partition i)
+   device = device_assignment(0, i);
+ } else {
+   // Replication: results correspond to replicas (replica i, partition 0)
+   device = device_assignment(i, 0);
+ }
for (auto &pjrt_buf : result.at(i)) {
pjrt_buf->GetReadyFuture().Await();
@@ -266,20 +285,23 @@ ExlaExecutable::Run(ErlNifEnv *env, ExlaExecutable::RunArguments arguments,
// a pmap, but in all other cases it will be equal to 1
int num_replicas = executable_->num_replicas();

// the number of partitions is used for SPMD partitioning
int num_partitions = executable_->num_partitions();

// input buffers are a list of lists, where each list maps to the args
// to pass to one of the replicas in a computation, e.g. [replica_args1,
// replica_args2, ...]
std::vector<std::vector<xla::PjRtBuffer *>> input_buffers;

// the device assignment is a 2d array which maps coordinates (replica,
- // partition) to a device; or in this case just maps a replica to a device
+ // partition) to a device
xla::DeviceAssignment device_assignment;
if (client_->client()->platform_name() == "METAL") {
device_assignment = xla::DeviceAssignment(1, 1);
} else {
- EXLA_ASSIGN_OR_RETURN(
-     device_assignment,
-     client_->client()->GetDefaultDeviceAssignment(num_replicas, 1));
+ EXLA_ASSIGN_OR_RETURN(device_assignment,
+                       client_->client()->GetDefaultDeviceAssignment(
+                           num_replicas, num_partitions));
}
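
For intuition, the assignment is a num_replicas x num_partitions grid; a sketch that dumps it (illustrative only; the concrete device ids come from the PjRt client):

```cpp
// Walk the grid produced by GetDefaultDeviceAssignment; with num_replicas=1
// and num_partitions=4 this typically prints devices 0..3 on replica 0.
for (int r = 0; r < device_assignment.replica_count(); ++r) {
  for (int p = 0; p < device_assignment.computation_count(); ++p) {
    printf("replica %d, partition %d -> device %d\n", r, p,
           device_assignment(r, p));
  }
}
```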

// Buffers allocated from binaries for this specific run need to be
@@ -300,15 +322,20 @@ ExlaExecutable::Run(ErlNifEnv *env, ExlaExecutable::RunArguments arguments,
EXLA_ASSIGN_OR_RETURN(input_buffers,
UnpackRunArguments(env, arguments, transient_buffers,
client_, device_assignment,
- device_id));
+ device_id, num_partitions));
}

- // at this point input buffers is a vector of arguments per replica
- // and the size of that vector should equal the number of replicas in the
- // executable, otherwise it is invalid
- if (num_replicas != input_buffers.size()) {
-   return xla::InvalidArgument("Got %d replica arguments for %d replicas",
-                               input_buffers.size(), num_replicas);
+ // at this point input buffers is a vector of argument lists, one per device.
+ // For automatic SPMD: one input list per partition (num_partitions lists).
+ // For standard replication: one input list per replica (num_replicas lists).
+ // Each input list contains full unreplicated tensors; XLA slices them
+ // according to the sharding annotations.
+ int expected_lists = num_partitions > 1 ? num_partitions : num_replicas;
+ if (input_buffers.size() != expected_lists) {
+   return xla::InvalidArgument("Got %d argument lists, expected %d "
+                               "(num_replicas=%d, num_partitions=%d)",
+                               input_buffers.size(), expected_lists,
+                               num_replicas, num_partitions);
}

std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>>
@@ -333,20 +360,25 @@ ExlaExecutable::Run(ErlNifEnv *env, ExlaExecutable::RunArguments arguments,
// result buffers to unpack
per_replica_results.push_back(std::move(portable_result));
} else {
- // no device ID is present, so it may be a replicated executable which means
- // we need to use the replica execution path
- // TODO: This now exposes a `returned_futures` API, does this make sense for
- // us?
+ // no device ID is present, so it may be a replicated or SPMD executable.
+ // For SPMD with num_partitions > 1, Execute handles partitioned execution
+ // using the sharding annotations.
EXLA_ASSIGN_OR_RETURN(per_replica_results,
executable_->Execute(input_buffers, options));
}

- // sanity check
- if (per_replica_results.size() != num_replicas) {
-   return xla::FailedPrecondition("Invalid execution.");
+ // sanity check: for SPMD we get results per partition, for replication
+ // per replica
+ int expected_results = num_partitions > 1 ? num_partitions : num_replicas;
+ if (per_replica_results.size() != expected_results) {
+   return xla::FailedPrecondition(
+       "Invalid execution: got %d results, expected %d (num_replicas=%d, "
+       "num_partitions=%d)",
+       per_replica_results.size(), expected_results, num_replicas,
+       num_partitions);
}

// we need to unpack the results into Erlang terms, the result is a vector
2 changes: 2 additions & 0 deletions exla/c_src/exla/exla_mlir.h
@@ -29,6 +29,8 @@ class MLIRFunction {

llvm::MutableArrayRef<mlir::BlockArgument> GetArguments() { return func_->getBody().front().getArguments(); }

mlir::func::FuncOp function() { return *func_; }
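// The accessor above exposes the underlying func op so NIFs can attach
// argument attributes directly, e.g. setArgAttr(0, "sdy.sharding", attr).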

fine::ResourcePtr<MLIRModule> module() { return module_; }

private: