#pragma once
#include "core/framework/op_kernel.h"
#include "core/framework/func_api.h"
#include "core/framework/op_kernel_context_internal.h"
#include "core/graph/function.h"

namespace onnxruntime {

// Allocation callback that FunctionKernel::Create places in the ComputeContext
// it hands to a fused node's create_state_func. Defined elsewhere.
// NOTE(review): presumably allocates `size` bytes with the requested
// `alignment` from the allocator behind the opaque `allocator` pointer --
// confirm against the definition.
void* allocate_helper_func(void* allocator, size_t alignment, size_t size);

// Release callback paired with allocate_helper_func in the same ComputeContext.
// Defined elsewhere.
// NOTE(review): presumably returns `p` to the allocator behind the opaque
// `allocator` pointer -- confirm against the definition.
void release_helper_func(void* allocator, void* p);

// A kernel that wraps the compute function generated by an execution provider
// when it fuses a sub-graph into a single node.
//
// The kernel never frees the NodeComputeInfo it points at. It only manages the
// optional function state produced by create_state_func, which is handed back
// to release_state_func in the destructor.
class FunctionKernel : public OpKernel {
 public:
  // Stores `compute` as-is. Prefer Create(), which additionally records the
  // node's input/output counts and builds the function state.
  explicit FunctionKernel(const OpKernelInfo& info, const NodeComputeInfo* compute)
      : OpKernel(info), compute_info_(compute) {}

  // Builds a FunctionKernel for the node described by `info`.
  //
  // The original design was to load the DLL, find the entry point and wrap it.
  // For quick prototyping the entry points are instead looked up by node name
  // in `func_mgr`.
  //
  // On success `out` receives the new kernel and Status::OK() is returned.
  // On failure `out` is left untouched and the returned Status describes the
  // problem: the lookup failed, no compute info was provided, or
  // create_state_func returned a non-zero value.
  static Status Create(FuncManager& func_mgr, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) {
    const auto& node = info.node();

    // Start from nullptr so a lookup that "succeeds" without filling in the
    // pointer is detected instead of dereferencing an indeterminate value.
    const NodeComputeInfo* compute = nullptr;
    ORT_RETURN_IF_ERROR(func_mgr.GetFuncs(node.Name(), compute));
    if (compute == nullptr) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "No compute info found for node: ", node.Name());
    }

    std::unique_ptr<FunctionKernel> funckernel = std::make_unique<FunctionKernel>(info, compute);
    funckernel->num_inputs_ = node.InputDefs().size();
    funckernel->num_outputs_ = node.OutputDefs().size();

    if (compute->create_state_func) {
      //TODO: only a host allocation method is provided in the compute context.
      //Do we need to hold the ref-counting here?
      funckernel->host_allocator_ = info.GetAllocator(0, OrtMemType::OrtMemTypeDefault);
      ComputeContext context = {allocate_helper_func, release_helper_func, funckernel->host_allocator_.get(),
                                node.Name().c_str()};
      const int ret = compute->create_state_func(&context, &funckernel->func_state_);
      if (ret != 0) {
        // funckernel is destroyed on return; its destructor releases any state
        // that create_state_func managed to set before failing.
        return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Create state function failed. Return value:", ret);
      }
    }

    out = std::move(funckernel);
    return Status::OK();
  }

  // Releases the function state, if one was created and a release function exists.
  ~FunctionKernel() override {
    // Guard compute_info_ as well: the public constructor accepts any pointer,
    // and a destructor must not crash on a null one.
    if (compute_info_ != nullptr && compute_info_->release_state_func && func_state_) {
      compute_info_->release_state_func(func_state_);
    }
  }

  // Runs the fused node's compute function with this kernel's function state.
  // Returns an error Status if no compute info is attached or the ORT API for
  // ORT_API_VERSION is unavailable; otherwise returns whatever compute_func returns.
  Status Compute(OpKernelContext* context) const override {
    if (compute_info_ == nullptr) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "FunctionKernel has no compute info.");
    }
    auto* context_internal = static_cast<OpKernelContextInternal*>(context);
    const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
    if (api == nullptr) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "API VERSION ", ORT_API_VERSION, " is invalid.");
    }
    return compute_info_->compute_func(func_state_, api,
                                       reinterpret_cast<OrtKernelContext*>(context_internal));
  }

  // Forwards to the node's custom_func with this kernel's function state,
  // passing `node_name` and `node_data` through unchanged.
  // Returns a FAIL Status when no custom_func is available.
  virtual Status Custom(char** node_name, void** node_data) const {
    if (compute_info_ != nullptr && compute_info_->custom_func) {
      return compute_info_->custom_func(func_state_, node_name, node_data);
    }
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "No Custon Data \n");
  }

 private:
  // Function table for the fused node; never freed by this class.
  const NodeComputeInfo* const compute_info_;
  // State returned by create_state_func; nullptr when none was created.
  FunctionState func_state_{nullptr};
  // Input/output definition counts of the node; filled in by Create().
  size_t num_inputs_{0};
  size_t num_outputs_{0};
  // Allocator whose raw pointer is placed in the ComputeContext; the reference
  // is held here for as long as the kernel exists.
  AllocatorPtr host_allocator_;
};
}  // namespace onnxruntime
