Some checks are pending
Copilot Setup Steps / copilot-setup-steps (push) Waiting to run
Check Pre-Tokenizer Hashes / pre-tokenizer-hashes (push) Waiting to run
Python check requirements.txt / check-requirements (push) Waiting to run
Python Type-Check / python type-check (push) Waiting to run
Update Operations Documentation / update-ops-docs (push) Waiting to run
63 lines
2.3 KiB
C++
63 lines
2.3 KiB
C++
#include "../node_context.h"
|
|
#include "../op_table.h"
|
|
#include "../utils.h"
|
|
|
|
#include <cstdint>
|
|
#include <memory>
|
|
#include <openvino/core/node_output.hpp>
|
|
#include <openvino/op/constant.hpp>
|
|
#include <openvino/op/multiply.hpp>
|
|
#include <openvino/op/sigmoid.hpp>
|
|
#include <openvino/op/slice.hpp>
|
|
|
|
namespace ov {
|
|
namespace frontend {
|
|
namespace ggml {
|
|
namespace op {
|
|
|
|
OutputVector translate_glu_swiglu(const NodeContext & context) {
|
|
num_inputs_check(context, 1, 2);
|
|
|
|
ov::Output<ov::Node> src0;
|
|
ov::Output<ov::Node> src1;
|
|
if (context.get_input_size() == 2) {
|
|
src0 = context.get_input(0);
|
|
src1 = context.get_input(1);
|
|
} else {
|
|
// GGML splits along ne[0] (OV last axis) using floor division: nc = ne[0] / 2.
|
|
// Both halves are nc elements; if the dimension is odd, the last element is dropped.
|
|
// Use Slice instead of Split to handle odd dimensions correctly.
|
|
auto combined = context.get_input(0);
|
|
auto combined_shape = combined.get_partial_shape();
|
|
int64_t last_dim_val = combined_shape[combined_shape.rank().get_length() - 1].get_length();
|
|
int64_t nc = last_dim_val / 2;
|
|
|
|
auto axis = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});
|
|
auto step = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
|
auto start0 = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
|
auto stop0 = ov::op::v0::Constant::create(ov::element::i64, {1}, {nc});
|
|
auto start1 = ov::op::v0::Constant::create(ov::element::i64, {1}, {nc});
|
|
auto stop1 = ov::op::v0::Constant::create(ov::element::i64, {1}, {2 * nc});
|
|
|
|
src0 = std::make_shared<ov::op::v8::Slice>(combined, start0, stop0, step, axis);
|
|
src1 = std::make_shared<ov::op::v8::Slice>(combined, start1, stop1, step, axis);
|
|
}
|
|
|
|
int32_t * params = context.get_output_op_params();
|
|
const int32_t swapped = params[1];
|
|
if (swapped) {
|
|
std::swap(src0, src1);
|
|
}
|
|
|
|
auto sigmoid = std::make_shared<ov::op::v0::Sigmoid>(src0);
|
|
auto silu = std::make_shared<ov::op::v1::Multiply>(src0, sigmoid);
|
|
auto res = std::make_shared<ov::op::v1::Multiply>(silu, src1);
|
|
|
|
return rename_outputs_with_suffix({res}, context.get_name());
|
|
}
|
|
|
|
} // namespace op
|
|
} // namespace ggml
|
|
} // namespace frontend
|
|
} // namespace ov
|