diff --git a/src/coreclr/vm/wasm/helpers.cpp b/src/coreclr/vm/wasm/helpers.cpp index b2a994e0ca0ef1..0359d5bc19c12b 100644 --- a/src/coreclr/vm/wasm/helpers.cpp +++ b/src/coreclr/vm/wasm/helpers.cpp @@ -3,6 +3,7 @@ // #include +#include "shash.h" extern "C" void STDCALL CallCountingStubCode() { @@ -436,7 +437,8 @@ namespace { // Arguments are passed on the stack with each argument aligned to INTERP_STACK_SLOT_SIZE. #define ARG_IND(i) ((int32_t)((int32_t*)(pArgs + (i * INTERP_STACK_SLOT_SIZE)))) -#define ARG(i) (*(int32_t*)ARG_IND(i)) +#define ARG_I32(i) (*(int32_t*)ARG_IND(i)) +#define ARG_F64(i) (*(double*)ARG_IND(i)) void CallFunc_Void_RetVoid(PCODE pcode, int8_t *pArgs, int8_t *pRet) { @@ -447,37 +449,37 @@ namespace void CallFunc_I32_RetVoid(PCODE pcode, int8_t *pArgs, int8_t *pRet) { void (*fptr)(int32_t) = (void (*)(int32_t))pcode; - (*fptr)(ARG(0)); + (*fptr)(ARG_I32(0)); } void CallFunc_I32_I32_RetVoid(PCODE pcode, int8_t *pArgs, int8_t *pRet) { void (*fptr)(int32_t, int32_t) = (void (*)(int32_t, int32_t))pcode; - (*fptr)(ARG(0), ARG(1)); + (*fptr)(ARG_I32(0), ARG_I32(1)); } void CallFunc_I32_I32_I32_RetVoid(PCODE pcode, int8_t *pArgs, int8_t *pRet) { void (*fptr)(int32_t, int32_t, int32_t) = (void (*)(int32_t, int32_t, int32_t))pcode; - (*fptr)(ARG(0), ARG(1), ARG(2)); + (*fptr)(ARG_I32(0), ARG_I32(1), ARG_I32(2)); } void CallFunc_I32_I32_I32_I32_RetVoid(PCODE pcode, int8_t *pArgs, int8_t *pRet) { void (*fptr)(int32_t, int32_t, int32_t, int32_t) = (void (*)(int32_t, int32_t, int32_t, int32_t))pcode; - (*fptr)(ARG(0), ARG(1), ARG(2), ARG(3)); + (*fptr)(ARG_I32(0), ARG_I32(1), ARG_I32(2), ARG_I32(3)); } void CallFunc_I32_I32_I32_I32_I32_RetVoid(PCODE pcode, int8_t *pArgs, int8_t *pRet) { void (*fptr)(int32_t, int32_t, int32_t, int32_t, int32_t) = (void (*)(int32_t, int32_t, int32_t, int32_t, int32_t))pcode; - (*fptr)(ARG(0), ARG(1), ARG(2), ARG(3), ARG(4)); + (*fptr)(ARG_I32(0), ARG_I32(1), ARG_I32(2), ARG_I32(3), ARG_I32(4)); } void CallFunc_I32_I32_I32_I32_I32_I32_RetVoid(PCODE pcode, int8_t *pArgs, int8_t *pRet) { void (*fptr)(int32_t, int32_t, int32_t, int32_t, int32_t, int32_t) = (void (*)(int32_t, int32_t, int32_t, int32_t, int32_t, int32_t))pcode; - (*fptr)(ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5)); + (*fptr)(ARG_I32(0), ARG_I32(1), ARG_I32(2), ARG_I32(3), ARG_I32(4), ARG_I32(5)); } void CallFunc_Void_RetI32(PCODE pcode, int8_t *pArgs, int8_t *pRet) @@ -489,105 +491,93 @@ namespace void CallFunc_I32_RetI32(PCODE pcode, int8_t *pArgs, int8_t *pRet) { int32_t (*fptr)(int32_t) = (int32_t (*)(int32_t))pcode; - *(int32_t*)pRet = (*fptr)(ARG(0)); + *(int32_t*)pRet = (*fptr)(ARG_I32(0)); } void CallFunc_I32_I32_RetI32(PCODE pcode, int8_t *pArgs, int8_t *pRet) { int32_t (*fptr)(int32_t, int32_t) = (int32_t (*)(int32_t, int32_t))pcode; - *(int32_t*)pRet = (*fptr)(ARG(0), ARG(1)); + *(int32_t*)pRet = (*fptr)(ARG_I32(0), ARG_I32(1)); } void CallFunc_I32_I32_I32_RetI32(PCODE pcode, int8_t *pArgs, int8_t *pRet) { int32_t (*fptr)(int32_t, int32_t, int32_t) = (int32_t (*)(int32_t, int32_t, int32_t))pcode; - *(int32_t*)pRet = (*fptr)(ARG(0), ARG(1), ARG(2)); + *(int32_t*)pRet = (*fptr)(ARG_I32(0), ARG_I32(1), ARG_I32(2)); } void CallFunc_I32_I32_I32_I32_RetI32(PCODE pcode, int8_t *pArgs, int8_t *pRet) { int32_t (*fptr)(int32_t, int32_t, int32_t, int32_t) = (int32_t (*)(int32_t, int32_t, int32_t, int32_t))pcode; - *(int32_t*)pRet = (*fptr)(ARG(0), ARG(1), ARG(2), ARG(3)); + *(int32_t*)pRet = (*fptr)(ARG_I32(0), ARG_I32(1), ARG_I32(2), ARG_I32(3)); } - // Special thunks for signatures with indirect arguments. - void CallFunc_I32IND_RetVoid(PCODE pcode, int8_t *pArgs, int8_t *pRet) { void (*fptr)(int32_t) = (void (*)(int32_t))pcode; (*fptr)(ARG_IND(0)); } - void CallFunc_I32IND_I32_RetVoid(PCODE pcode, int8_t *pArgs, int8_t *pRet) { void (*fptr)(int32_t, int32_t) = (void (*)(int32_t, int32_t))pcode; - (*fptr)(ARG_IND(0), ARG(1)); + (*fptr)(ARG_IND(0), ARG_I32(1)); } void CallFunc_I32IND_I32_I32_RetVoid(PCODE pcode, int8_t *pArgs, int8_t *pRet) { void (*fptr)(int32_t, int32_t, int32_t) = (void (*)(int32_t, int32_t, int32_t))pcode; - (*fptr)(ARG_IND(0), ARG(1), ARG(2)); + (*fptr)(ARG_IND(0), ARG_I32(1), ARG_I32(2)); } void CallFunc_I32IND_I32_I32_I32_RetVoid(PCODE pcode, int8_t *pArgs, int8_t *pRet) { void (*fptr)(int32_t, int32_t, int32_t, int32_t) = (void (*)(int32_t, int32_t, int32_t, int32_t))pcode; - (*fptr)(ARG_IND(0), ARG(1), ARG(2), ARG(3)); + (*fptr)(ARG_IND(0), ARG_I32(1), ARG_I32(2), ARG_I32(3)); } void CallFunc_I32IND_I32_I32_I32_I32_I32_I32_RetVoid(PCODE pcode, int8_t *pArgs, int8_t *pRet) { void (*fptr)(int32_t, int32_t, int32_t, int32_t, int32_t, int32_t, int32_t) = (void (*)(int32_t, int32_t, int32_t, int32_t, int32_t, int32_t, int32_t))pcode; - (*fptr)(ARG_IND(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6)); + (*fptr)(ARG_IND(0), ARG_I32(1), ARG_I32(2), ARG_I32(3), ARG_I32(4), ARG_I32(5), ARG_I32(6)); } void CallFunc_I32IND_I32_RetI32(PCODE pcode, int8_t *pArgs, int8_t *pRet) { int32_t (*fptr)(int32_t, int32_t) = (int32_t (*)(int32_t, int32_t))pcode; - *(int32_t*)pRet = (*fptr)(ARG_IND(0), ARG(1)); + *(int32_t*)pRet = (*fptr)(ARG_IND(0), ARG_I32(1)); } void CallFunc_I32_I32IND_I32_I32IND_I32_RetI32(PCODE pcode, int8_t *pArgs, int8_t *pRet) { int32_t (*fptr)(int32_t, int32_t, int32_t, int32_t, int32_t) = (int32_t (*)(int32_t, int32_t, int32_t, int32_t, int32_t))pcode; - *(int32_t*)pRet = (*fptr)(ARG(0), ARG_IND(1), ARG(2), ARG_IND(3), ARG(4)); + *(int32_t*)pRet = (*fptr)(ARG_I32(0), ARG_IND(1), ARG_I32(2), ARG_IND(3), ARG_I32(4)); } void CallFunc_I32IND_I32_I32_I32_I32_I32_RetI32(PCODE pcode, int8_t *pArgs, int8_t *pRet) { int32_t (*fptr)(int32_t, int32_t, int32_t, int32_t, int32_t, int32_t) = (int32_t (*)(int32_t, int32_t, int32_t, int32_t, int32_t, int32_t))pcode; - *(int32_t*)pRet = (*fptr)(ARG_IND(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5)); + *(int32_t*)pRet = (*fptr)(ARG_IND(0), ARG_I32(1), ARG_I32(2), ARG_I32(3), ARG_I32(4), ARG_I32(5)); } -#undef ARG - - void* const RetVoidThunks[] = + void CallFunc_F64_RetF64(PCODE pcode, int8_t *pArgs, int8_t *pRet) { - (void*)&CallFunc_Void_RetVoid, - (void*)&CallFunc_I32_RetVoid, - (void*)&CallFunc_I32_I32_RetVoid, - (void*)&CallFunc_I32_I32_I32_RetVoid, - (void*)&CallFunc_I32_I32_I32_I32_RetVoid, - (void*)&CallFunc_I32_I32_I32_I32_I32_RetVoid, - (void*)&CallFunc_I32_I32_I32_I32_I32_I32_RetVoid, - }; + double (*fptr)(double) = (double (*)(double))pcode; + *(double*)pRet = (*fptr)(ARG_F64(0)); + } - void* const RetI32Thunks[] = - { - (void*)&CallFunc_Void_RetI32, - (void*)&CallFunc_I32_RetI32, - (void*)&CallFunc_I32_I32_RetI32, - (void*)&CallFunc_I32_I32_I32_RetI32, - (void*)&CallFunc_I32_I32_I32_I32_RetI32, - }; +#undef ARG_IND +#undef ARG_I32 +#undef ARG_F64 enum class ConvertType { NotConvertible, ToI32, - ToI32Indirect + ToI64, + ToI32Indirect, + ToF32, + ToF64 }; ConvertType ConvertibleTo(CorElementType argType, MetaSig& sig, bool isReturn) @@ -613,6 +603,13 @@ namespace case ELEMENT_TYPE_FNPTR: case ELEMENT_TYPE_SZARRAY: return ConvertType::ToI32; + case ELEMENT_TYPE_I8: + case ELEMENT_TYPE_U8: + return ConvertType::ToI64; + case ELEMENT_TYPE_R4: + return ConvertType::ToF32; + case ELEMENT_TYPE_R8: + return ConvertType::ToF64; case ELEMENT_TYPE_TYPEDBYREF: // Typed references are passed indirectly in WASM since they are larger than pointer size. return ConvertType::ToI32Indirect; @@ -639,97 +636,129 @@ namespace } } - void* ComputeCalliSigThunkSpecial(bool isVoidReturn, uint32_t numArgs, ConvertType* args) + char GetTypeCode(ConvertType type) + { + switch (type) + { + case ConvertType::ToI32: + return 'i'; + case ConvertType::ToI64: + return 'l'; + case ConvertType::ToF32: + return 'f'; + case ConvertType::ToF64: + return 'd'; + case ConvertType::ToI32Indirect: + return 'n'; + default: + PORTABILITY_ASSERT("Unknown type"); + return '?'; + } + } + + bool GetSignatureKey(MetaSig& sig, char* keyBuffer, uint32_t maxSize) { STANDARD_VM_CONTRACT; - if (isVoidReturn) + uint32_t pos = 0; + + if (sig.IsReturnTypeVoid()) { - switch(numArgs) - { - case 1: - if (args[0] == ConvertType::ToI32Indirect) - { - return (void*)&CallFunc_I32IND_RetVoid; - } - break; - case 2: - if (args[0] == ConvertType::ToI32Indirect && - args[1] == ConvertType::ToI32) - { - return (void*)&CallFunc_I32IND_I32_RetVoid; - } - break; - case 3: - if (args[0] == ConvertType::ToI32Indirect && - args[1] == ConvertType::ToI32 && - args[2] == ConvertType::ToI32) - { - return (void*)&CallFunc_I32IND_I32_I32_RetVoid; - } - break; - case 4: - if (args[0] == ConvertType::ToI32Indirect && - args[1] == ConvertType::ToI32 && - args[2] == ConvertType::ToI32 && - args[3] == ConvertType::ToI32) - { - return (void*)&CallFunc_I32IND_I32_I32_I32_RetVoid; - } - break; - case 7: - if (args[0] == ConvertType::ToI32Indirect && - args[1] == ConvertType::ToI32 && - args[2] == ConvertType::ToI32 && - args[3] == ConvertType::ToI32 && - args[4] == ConvertType::ToI32 && - args[5] == ConvertType::ToI32 && - args[6] == ConvertType::ToI32) - { - return (void*)&CallFunc_I32IND_I32_I32_I32_I32_I32_I32_RetVoid; - } - break; - } + keyBuffer[pos++] = 'v'; } else { - switch (numArgs) { - case 2: - if (args[0] == ConvertType::ToI32Indirect && - args[1] == ConvertType::ToI32) - { - return (void*)&CallFunc_I32IND_I32_RetI32; - } - break; - case 5: - if (args[0] == ConvertType::ToI32 && - args[1] == ConvertType::ToI32Indirect && - args[2] == ConvertType::ToI32 && - args[3] == ConvertType::ToI32Indirect && - args[4] == ConvertType::ToI32) - { - return (void*)&CallFunc_I32_I32IND_I32_I32IND_I32_RetI32; - } - break; - case 6: - if (args[0] == ConvertType::ToI32Indirect && - args[1] == ConvertType::ToI32 && - args[2] == ConvertType::ToI32 && - args[3] == ConvertType::ToI32 && - args[4] == ConvertType::ToI32 && - args[5] == ConvertType::ToI32) - { - return (void*)&CallFunc_I32IND_I32_I32_I32_I32_I32_RetI32; - } - break; + keyBuffer[pos++] = GetTypeCode(ConvertibleTo(sig.GetReturnType(), sig, true /* isReturn */)); + } + + if (sig.HasThis()) + keyBuffer[pos++] = 'i'; + + for (CorElementType argType = sig.NextArg(); + argType != ELEMENT_TYPE_END; + argType = sig.NextArg()) + { + if (pos >= maxSize) + return false; + + keyBuffer[pos++] = GetTypeCode(ConvertibleTo(argType, sig, false /* isReturn */)); + } + + if (pos >= maxSize) + return false; + + keyBuffer[pos] = 0; + + return true; + } + + struct StringToWasmSigThunk + { + const char* key; + void* value; + }; + + StringToWasmSigThunk wasmThunks[] = { + { "v", (void*)&CallFunc_Void_RetVoid }, + { "vi", (void*)&CallFunc_I32_RetVoid }, + { "vii", (void*)&CallFunc_I32_I32_RetVoid }, + { "viii", (void*)&CallFunc_I32_I32_I32_RetVoid }, + { "viiii", (void*)&CallFunc_I32_I32_I32_I32_RetVoid }, + { "viiiii", (void*)&CallFunc_I32_I32_I32_I32_I32_RetVoid }, + { "viiiiii", (void*)&CallFunc_I32_I32_I32_I32_I32_I32_RetVoid }, + + { "vn", (void*)&CallFunc_I32IND_RetVoid }, + { "vni", (void*)&CallFunc_I32IND_I32_RetVoid }, + { "vnii", (void*)&CallFunc_I32IND_I32_I32_RetVoid }, + { "vniii", (void*)&CallFunc_I32IND_I32_I32_I32_RetVoid }, + { "vniiiiii", (void*)&CallFunc_I32IND_I32_I32_I32_I32_I32_I32_RetVoid }, + + { "i", (void*)&CallFunc_Void_RetI32 }, + { "ii", (void*)&CallFunc_I32_RetI32 }, + { "iii", (void*)&CallFunc_I32_I32_RetI32 }, + { "iiii", (void*)&CallFunc_I32_I32_I32_RetI32 }, + { "iiiii", (void*)&CallFunc_I32_I32_I32_I32_RetI32 }, + + { "ini", (void*)&CallFunc_I32IND_I32_RetI32 }, + { "iinini", (void*)&CallFunc_I32_I32IND_I32_I32IND_I32_RetI32 }, + { "iniiiii", (void*)&CallFunc_I32IND_I32_I32_I32_I32_I32_RetI32 }, + + { "dd", (void*)&CallFunc_F64_RetF64 }, + }; + + class StringWasmThunkSHashTraits : public MapSHashTraits + { + public: + static BOOL Equals(const char* s1, const char* s2) { return strcmp(s1, s2) == 0; } + static count_t Hash(const char* key) { return HashStringA(key); } + }; + + typedef MapSHash> StringToWasmSigThunkHash; + static StringToWasmSigThunkHash* thunkCache = nullptr; + + void* LookupThunk(const char* key) + { + StringToWasmSigThunkHash* table = VolatileLoad(&thunkCache); + if (table == nullptr) + { + StringToWasmSigThunkHash* newTable = new StringToWasmSigThunkHash(); + for (const StringToWasmSigThunk& thunk : wasmThunks) + newTable->Add(thunk.key, thunk.value); + + if (InterlockedCompareExchangeT(&thunkCache, newTable, nullptr) != nullptr) + { + // Another thread won the race, discard ours + delete newTable; } + table = thunkCache; } - return NULL; + void* thunk; + bool success = table->Lookup(key, &thunk); + return success ? thunk : nullptr; } // This is a simple signature computation routine for signatures currently supported in the wasm environment. - // Note: Currently only validates void return type and i32 wasm convertible arguments. void* ComputeCalliSigThunk(MetaSig& sig) { STANDARD_VM_CONTRACT; @@ -749,57 +778,12 @@ namespace return NULL; } - // Check return value. We only support void or i32 return types for now. - bool returnsVoid = sig.IsReturnTypeVoid(); - if (!returnsVoid && ConvertibleTo(sig.GetReturnType(), sig, true /* isReturn */) != ConvertType::ToI32) + uint32_t keyBufferLen = sig.NumFixedArgs() + (sig.HasThis() ? 1 : 0) + 2; + char* keyBuffer = (char*)alloca(keyBufferLen); + if (!GetSignatureKey(sig, keyBuffer, keyBufferLen)) return NULL; - uint32_t numArgs = sig.NumFixedArgs() + (sig.HasThis() ? 1 : 0); - ConvertType args[16]; - _ASSERTE(numArgs < ARRAY_SIZE(args)); - - uint32_t i = 0; - - if (sig.HasThis()) - { - args[i++] = ConvertType::ToI32; - } - - // Ensure all arguments are wasm i32 compatible types. - for (CorElementType argType = sig.NextArg(); - argType != ELEMENT_TYPE_END; - argType = sig.NextArg()) - { - // If we have no conversion, immediately return. - ConvertType type = ConvertibleTo(argType, sig, false /* isReturn */); - if (type == ConvertType::NotConvertible) - return NULL; - - args[i++] = type; - } - - // Check for homogeneous i32 argument types. - for (uint32_t j = 0; j < numArgs; j++) - { - if (args[j] != ConvertType::ToI32) - return ComputeCalliSigThunkSpecial(returnsVoid, numArgs, args); - } - - void* const * thunks; - if (returnsVoid) - { - thunks = RetVoidThunks; - if (numArgs >= ARRAY_SIZE(RetVoidThunks)) - return NULL; - } - else - { - thunks = RetI32Thunks; - if (numArgs >= ARRAY_SIZE(RetI32Thunks)) - return NULL; - } - - return thunks[numArgs]; + return LookupThunk(keyBuffer); } }