Skip to content

[PROTOTYPE] Inner heuristics refactor#1695

Open
csarofeen wants to merge 3 commits intodevelfrom
timm_nhwc
Open

[PROTOTYPE] Inner heuristics refactor#1695
csarofeen wants to merge 3 commits intodevelfrom
timm_nhwc

Conversation

@csarofeen
Copy link
Copy Markdown
Owner

Fixes #ISSUE_NUMBER

@csarofeen
Copy link
Copy Markdown
Owner Author

@jjsjann123 we may want this in for 1.12.

Comment on lines +26 to +59
int64_t roundUpPow2OrMultipleOf(const int64_t x, const int64_t multiple) {
auto round_up_pow2 = scheduler_utils::lastPow2(x);
if (round_up_pow2 < x) {
round_up_pow2 *= 2;
}
constexpr int64_t kEight = 8; // clang tidy
auto round_up_8 = x % kEight == 0 ? x : x + (kEight - x % kEight);
return std::min(round_up_8, round_up_pow2);
auto round_up_multiple =
x % multiple == 0 ? x : x + (multiple - x % multiple);
return std::min(round_up_multiple, round_up_pow2);
}

int64_t safeDiv(const int64_t x, const int64_t y) {
return x / y == 0 ? 1 : x / y;
}

int64_t clamp(const int64_t val, const int64_t min_val, const int64_t max_val) {
return std::min(std::max(val, min_val), max_val);
}

// Reduce x, y, z until it's product is less than max value, reduce round robin
// starting with x
void reduceProductTo(int64_t& x, int64_t& y, int64_t& z, const int64_t max) {
TORCH_INTERNAL_ASSERT(max > 1);
if (x * y * z > max) {
x = safeDiv(x, 2);
}
if (x * y * z > max) {
y = safeDiv(y, 2);
}
if (x * y * z > max) {
z = safeDiv(z, 2);
}
if (x * y * z > max) {
reduceProductTo(x, y, z, max);
}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should merge them with those in the reduction scheduler.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants