clc: Declare LLVMContexts on the stack

This prevents more use-after-free errors.  Passing them around using
std::unique_ptr ensures that the LLVMContext gets destroyed but doesn't
ensure destruction order.  Declaring it on the stack ensures that the
context doesn't get destroyed until right before the the function
returns which is after any other LLVM stuff is destroyed.

Reviewed-by: Jesse Natalie <jenatali@microsoft.com>
Reviewed-by: Icecream95 <ixn@disroot.org>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/15937>
This commit is contained in:
Jason Ekstrand 2022-04-13 16:52:17 -05:00 committed by Marge Bot
parent 6099e6ce9a
commit 46d9b0e431
1 changed files with 30 additions and 20 deletions

View File

@ -747,16 +747,11 @@ clc_free_kernels_info(const struct clc_kernel_info *kernels,
free((void *)kernels);
}
static std::pair<std::unique_ptr<::llvm::Module>, std::unique_ptr<LLVMContext>>
clc_compile_to_llvm_module(const struct clc_compile_args *args,
static std::unique_ptr<::llvm::Module>
clc_compile_to_llvm_module(LLVMContext &llvm_ctx,
const struct clc_compile_args *args,
const struct clc_logger *logger)
{
clc_initialize_llvm();
std::unique_ptr<LLVMContext> llvm_ctx { new LLVMContext };
llvm_ctx->setDiagnosticHandlerCallBack(llvm_log_handler,
const_cast<clc_logger *>(logger));
std::string diag_log_str;
raw_string_ostream diag_log_stream { diag_log_str };
@ -878,14 +873,14 @@ clc_compile_to_llvm_module(const struct clc_compile_args *args,
::llvm::MemoryBuffer::getMemBufferCopy(std::string(args->source.value)).release());
// Compile the code
clang::EmitLLVMOnlyAction act(llvm_ctx.get());
clang::EmitLLVMOnlyAction act(&llvm_ctx);
if (!c->ExecuteAction(act)) {
clc_error(logger, "%sError executing LLVM compilation action.\n",
diag_log_str.c_str());
return {};
}
return { act.takeModule(), std::move(llvm_ctx) };
return act.takeModule();
}
static SPIRV::VersionNumber
@ -906,7 +901,7 @@ spirv_version_to_llvm_spirv_translator_version(enum clc_spirv_version version)
static int
llvm_mod_to_spirv(std::unique_ptr<::llvm::Module> mod,
std::unique_ptr<LLVMContext> context,
LLVMContext &context,
const struct clc_compile_args *args,
const struct clc_logger *logger,
struct clc_binary *out_spirv)
@ -969,13 +964,19 @@ clc_c_to_spir(const struct clc_compile_args *args,
const struct clc_logger *logger,
struct clc_binary *out_spir)
{
auto pair = clc_compile_to_llvm_module(args, logger);
if (!pair.first)
clc_initialize_llvm();
LLVMContext llvm_ctx;
llvm_ctx.setDiagnosticHandlerCallBack(llvm_log_handler,
const_cast<clc_logger *>(logger));
auto mod = clc_compile_to_llvm_module(llvm_ctx, args, logger);
if (!mod)
return -1;
::llvm::SmallVector<char, 0> buffer;
::llvm::BitcodeWriter writer(buffer);
writer.writeModule(*pair.first);
writer.writeModule(*mod);
out_spir->size = buffer.size_in_bytes();
out_spir->data = malloc(out_spir->size);
@ -989,10 +990,16 @@ clc_c_to_spirv(const struct clc_compile_args *args,
const struct clc_logger *logger,
struct clc_binary *out_spirv)
{
auto pair = clc_compile_to_llvm_module(args, logger);
if (!pair.first)
clc_initialize_llvm();
LLVMContext llvm_ctx;
llvm_ctx.setDiagnosticHandlerCallBack(llvm_log_handler,
const_cast<clc_logger *>(logger));
auto mod = clc_compile_to_llvm_module(llvm_ctx, args, logger);
if (!mod)
return -1;
return llvm_mod_to_spirv(std::move(pair.first), std::move(pair.second), args, logger, out_spirv);
return llvm_mod_to_spirv(std::move(mod), llvm_ctx, args, logger, out_spirv);
}
int
@ -1002,13 +1009,16 @@ clc_spir_to_spirv(const struct clc_binary *in_spir,
{
clc_initialize_llvm();
std::unique_ptr<LLVMContext> llvm_ctx{ new LLVMContext };
LLVMContext llvm_ctx;
llvm_ctx.setDiagnosticHandlerCallBack(llvm_log_handler,
const_cast<clc_logger *>(logger));
::llvm::StringRef spir_ref(static_cast<const char*>(in_spir->data), in_spir->size);
auto mod = ::llvm::parseBitcodeFile(::llvm::MemoryBufferRef(spir_ref, "<spir>"), *llvm_ctx);
auto mod = ::llvm::parseBitcodeFile(::llvm::MemoryBufferRef(spir_ref, "<spir>"), llvm_ctx);
if (!mod)
return -1;
return llvm_mod_to_spirv(std::move(mod.get()), std::move(llvm_ctx), NULL, logger, out_spirv);
return llvm_mod_to_spirv(std::move(mod.get()), llvm_ctx, NULL, logger, out_spirv);
}
class SPIRVMessageConsumer {