From ce37fd67f83cd1e8793b988d2e4126bbf72b19dd Mon Sep 17 00:00:00 2001 From: alelenv <40001162+alelenv@users.noreply.github.com> Date: Fri, 31 Jan 2025 15:20:51 -0800 Subject: [PATCH] Add validation for SPV_NV_linear_swept_spheres. (#5975) --- DEPS | 2 +- source/val/validate_extensions.cpp | 3 +- source/val/validate_ray_query.cpp | 82 ++++++++++++++ source/val/validate_ray_tracing_reorder.cpp | 95 ++++++++++++++++ source/val/validation_state.cpp | 13 +++ source/val/validation_state.h | 1 + test/val/val_ray_query_test.cpp | 60 +++++++++- test/val/val_ray_tracing_reorder_test.cpp | 164 ++++++++++++++++++++++++++++ 8 files changed, 417 insertions(+), 3 deletions(-) diff --git a/DEPS b/DEPS index a1c347e8..0de5e4a8 100644 --- a/DEPS +++ b/DEPS @@ -14,7 +14,7 @@ vars = { 're2_revision': '6dcd83d60f7944926bfd308cc13979fc53dd69ca', - 'spirv_headers_revision': '003bcf4e0d1922fb45e9b07656ee3db7c156a675', + 'spirv_headers_revision': 'e7294a8ebed84f8c5bd3686c68dbe12a4e65b644', } deps = { diff --git a/source/val/validate_extensions.cpp b/source/val/validate_extensions.cpp index 64bb780c..af64e6a9 100644 --- a/source/val/validate_extensions.cpp +++ b/source/val/validate_extensions.cpp @@ -1059,7 +1059,8 @@ spv_result_t ValidateExtension(ValidationState_t& _, const Instruction* inst) { extension == ExtensionToString(kSPV_EXT_mesh_shader) || extension == ExtensionToString(kSPV_NV_shader_invocation_reorder) || extension == - ExtensionToString(kSPV_NV_cluster_acceleration_structure)) { + ExtensionToString(kSPV_NV_cluster_acceleration_structure) || + extension == ExtensionToString(kSPV_NV_linear_swept_spheres)) { return _.diag(SPV_ERROR_WRONG_VERSION, inst) << extension << " extension requires SPIR-V version 1.4 or later."; } diff --git a/source/val/validate_ray_query.cpp b/source/val/validate_ray_query.cpp index d7c75123..bd790ac3 100644 --- a/source/val/validate_ray_query.cpp +++ b/source/val/validate_ray_query.cpp @@ -23,6 +23,17 @@ namespace spvtools { namespace val { namespace { +uint32_t GetArrayLength(ValidationState_t& _, const Instruction* array_type) { + assert(array_type->opcode() == spv::Op::OpTypeArray); + uint32_t const_int_id = array_type->GetOperandAs(2U); + Instruction* array_length_inst = _.FindDef(const_int_id); + uint32_t array_length = 0; + if (array_length_inst->opcode() == spv::Op::OpConstant) { + array_length = array_length_inst->GetOperandAs(2); + } + return array_length; +} + spv_result_t ValidateRayQueryPointer(ValidationState_t& _, const Instruction* inst, uint32_t ray_query_index) { @@ -271,10 +282,81 @@ spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "expected Result Type to be 32-bit int scalar type"; } + break; + } + + case spv::Op::OpRayQueryGetIntersectionSpherePositionNV: { + if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error; + if (auto error = ValidateIntersectionId(_, inst, 3)) return error; + + if (!_.IsFloatVectorType(result_type) || + _.GetDimension(result_type) != 3 || + _.GetBitWidth(result_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "expected Result Type to be 32-bit float 3-component " + "vector type"; + } + break; + } + + case spv::Op::OpRayQueryGetIntersectionLSSPositionsNV: { + if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error; + if (auto error = ValidateIntersectionId(_, inst, 3)) return error; + + auto result_id = _.FindDef(result_type); + if ((result_id->opcode() != spv::Op::OpTypeArray) || + (GetArrayLength(_, result_id) != 2) || + !_.IsFloatVectorType(_.GetComponentType(result_type)) || + _.GetDimension(_.GetComponentType(result_type)) != 3) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected 2 element array of 32-bit 3 component float point " + "vector as Result Type: " + << spvOpcodeString(opcode); + } + break; + } + + case spv::Op::OpRayQueryGetIntersectionLSSRadiiNV: { + if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error; + if (auto error = ValidateIntersectionId(_, inst, 3)) return error; + + if (!_.IsFloatArrayType(result_type) || + (GetArrayLength(_, _.FindDef(result_type)) != 2) || + !_.IsFloatScalarType(_.GetComponentType(result_type))) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected 32-bit floating point scalar as Result Type: " + << spvOpcodeString(opcode); + } + break; + } + + case spv::Op::OpRayQueryGetIntersectionSphereRadiusNV: + case spv::Op::OpRayQueryGetIntersectionLSSHitValueNV: { + if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error; + if (auto error = ValidateIntersectionId(_, inst, 3)) return error; + if (!_.IsFloatScalarType(result_type) || + _.GetBitWidth(result_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "expected Result Type to be 32-bit floating point " + "scalar type"; + } break; } + case spv::Op::OpRayQueryIsSphereHitNV: + case spv::Op::OpRayQueryIsLSSHitNV: { + if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error; + if (auto error = ValidateIntersectionId(_, inst, 3)) return error; + + if (!_.IsBoolScalarType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "expected Result Type to be Boolean " + "scalar type"; + } + + break; + } default: break; } diff --git a/source/val/validate_ray_tracing_reorder.cpp b/source/val/validate_ray_tracing_reorder.cpp index fd31cad0..3685a765 100644 --- a/source/val/validate_ray_tracing_reorder.cpp +++ b/source/val/validate_ray_tracing_reorder.cpp @@ -26,6 +26,17 @@ namespace val { static const uint32_t KRayParamInvalidId = std::numeric_limits::max(); +uint32_t GetArrayLength(ValidationState_t& _, const Instruction* array_type) { + assert(array_type->opcode() == spv::Op::OpTypeArray); + uint32_t const_int_id = array_type->GetOperandAs(2U); + Instruction* array_length_inst = _.FindDef(const_int_id); + uint32_t array_length = 0; + if (array_length_inst->opcode() == spv::Op::OpConstant) { + array_length = array_length_inst->GetOperandAs(2); + } + return array_length; +} + spv_result_t ValidateHitObjectPointer(ValidationState_t& _, const Instruction* inst, uint32_t hit_object_index) { @@ -628,6 +639,90 @@ spv_result_t RayReorderNVPass(ValidationState_t& _, const Instruction* inst) { break; } + case spv::Op::OpHitObjectGetSpherePositionNV: { + RegisterOpcodeForValidModel(_, inst); + if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error; + + if (!_.IsFloatVectorType(result_type) || + _.GetDimension(result_type) != 3 || + _.GetBitWidth(result_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected 32-bit floating point 2 component vector type as " + "Result Type: " + << spvOpcodeString(opcode); + } + break; + } + + case spv::Op::OpHitObjectGetSphereRadiusNV: { + RegisterOpcodeForValidModel(_, inst); + if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error; + + if (!_.IsFloatScalarType(result_type) || + _.GetBitWidth(result_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected 32-bit floating point scalar as Result Type: " + << spvOpcodeString(opcode); + } + break; + } + + case spv::Op::OpHitObjectGetLSSPositionsNV: { + RegisterOpcodeForValidModel(_, inst); + if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error; + + auto result_id = _.FindDef(result_type); + if ((result_id->opcode() != spv::Op::OpTypeArray) || + (GetArrayLength(_, result_id) != 2) || + !_.IsFloatVectorType(_.GetComponentType(result_type)) || + _.GetDimension(_.GetComponentType(result_type)) != 3) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected 2 element array of 32-bit 3 component float point " + "vector as Result Type: " + << spvOpcodeString(opcode); + } + break; + } + + case spv::Op::OpHitObjectGetLSSRadiiNV: { + RegisterOpcodeForValidModel(_, inst); + if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error; + + if (!_.IsFloatArrayType(result_type) || + (GetArrayLength(_, _.FindDef(result_type)) != 2) || + !_.IsFloatScalarType(_.GetComponentType(result_type))) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected 2 element array of 32-bit floating point scalar as " + "Result Type: " + << spvOpcodeString(opcode); + } + break; + } + + case spv::Op::OpHitObjectIsSphereHitNV: { + RegisterOpcodeForValidModel(_, inst); + if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error; + + if (!_.IsBoolScalarType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Boolean scalar as Result Type: " + << spvOpcodeString(opcode); + } + break; + } + + case spv::Op::OpHitObjectIsLSSHitNV: { + RegisterOpcodeForValidModel(_, inst); + if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error; + + if (!_.IsBoolScalarType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Boolean scalar as Result Type: " + << spvOpcodeString(opcode); + } + break; + } + default: break; } diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp index 398f9b5a..4604d6da 100644 --- a/source/val/validation_state.cpp +++ b/source/val/validation_state.cpp @@ -951,6 +951,19 @@ bool ValidationState_t::IsFloatScalarType(uint32_t id) const { return inst && inst->opcode() == spv::Op::OpTypeFloat; } +bool ValidationState_t::IsFloatArrayType(uint32_t id) const { + const Instruction* inst = FindDef(id); + if (!inst) { + return false; + } + + if (inst->opcode() == spv::Op::OpTypeArray) { + return IsFloatScalarType(GetComponentType(id)); + } + + return false; +} + bool ValidationState_t::IsFloatVectorType(uint32_t id) const { const Instruction* inst = FindDef(id); if (!inst) { diff --git a/source/val/validation_state.h b/source/val/validation_state.h index cee3d9b2..e97d3d32 100644 --- a/source/val/validation_state.h +++ b/source/val/validation_state.h @@ -634,6 +634,7 @@ class ValidationState_t { // Only works for types not for objects. bool IsVoidType(uint32_t id) const; bool IsFloatScalarType(uint32_t id) const; + bool IsFloatArrayType(uint32_t id) const; bool IsFloatVectorType(uint32_t id) const; bool IsFloat16Vector2Or4Type(uint32_t id) const; bool IsFloatScalarOrVectorType(uint32_t id) const; diff --git a/test/val/val_ray_query_test.cpp b/test/val/val_ray_query_test.cpp index 52b9e9cd..ed6cce7d 100644 --- a/test/val/val_ray_query_test.cpp +++ b/test/val/val_ray_query_test.cpp @@ -86,6 +86,10 @@ OpDecorate %top_level_as Binding 0 %u32_0 = OpConstant %u32 0 %u64_0 = OpConstant %u64 0 +%u32_2 = OpConstant %u32 2 +%arr2v3 = OpTypeArray %f32vec3 %u32_2 +%arr2f3 = OpTypeArray %f32 %u32_2 + %u32vec3_0 = OpConstantComposite %u32vec3 %u32_0 %u32_0 %u32_0 %f32vec3_0 = OpConstantComposite %f32vec3 %f32_0 %f32_0 %f32_0 %f32vec4_0 = OpConstantComposite %f32vec4 %f32_0 %f32_0 %f32_0 %f32_0 @@ -105,7 +109,6 @@ OpDecorate %top_level_as Binding 0 ss << R"( %main = OpFunction %void None %func %main_entry = OpLabel - %ray_query = OpVariable %ptr_rq Function )"; @@ -647,6 +650,61 @@ TEST_F(ValidateRayQuery, ClusterASNV) { SPV_ENV_VULKAN_1_2); EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2)); } + +using RayQueryLSSNVCommon = spvtest::ValidateBase; + +std::string RayQueryLSSNVResultType(std::string opcode, bool valid) { + if (opcode.compare("OpRayQueryGetIntersectionLSSPositionsNV") == 0) + return valid ? "%arr2v3" : "%f64"; + + if (opcode.compare("OpRayQueryGetIntersectionLSSRadiiNV") == 0) + return valid ? "%arr2f3" : "%f64"; + + if (opcode.compare("OpRayQueryGetIntersectionSphereRadiusNV") == 0 || + opcode.compare("OpRayQueryGetIntersectionLSSHitValueNV") == 0) { + return valid ? "%f32" : "%f64"; + } + + if (opcode.compare("OpRayQueryGetIntersectionSpherePositionNV") == 0) { + return valid ? "%f32vec3" : "%f64"; + } + + if (opcode.compare("OpRayQueryIsSphereHitNV") == 0 || + opcode.compare("OpRayQueryIsLSSHitNV") == 0) { + return valid ? "%bool" : "%f64"; + } + + return ""; +} + +TEST_P(RayQueryLSSNVCommon, Success) { + const std::string cap = R"( + OpCapability RayTracingSpheresGeometryNV + OpCapability RayTracingLinearSweptSpheresGeometryNV + )"; + const std::string ext = R"( + OpExtension "SPV_NV_linear_swept_spheres" + )"; + std::string opcode = GetParam(); + std::ostringstream ss; + ss << "%result = "; + ss << " " << opcode << " "; + ss << RayQueryLSSNVResultType(opcode, true); + ss << " %ray_query "; + ss << " %s32_0 "; + CompileSuccessfully(GenerateShaderCode(ss.str(), cap, ext).c_str(), + SPV_ENV_VULKAN_1_2); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2)); +} + +INSTANTIATE_TEST_SUITE_P(ValidateRayQueryLSSNVCommon, RayQueryLSSNVCommon, + Values("OpRayQueryGetIntersectionSpherePositionNV", + "OpRayQueryGetIntersectionLSSPositionsNV", + "OpRayQueryGetIntersectionSphereRadiusNV", + "OpRayQueryGetIntersectionLSSRadiiNV", + "OpRayQueryGetIntersectionLSSHitValueNV", + "OpRayQueryIsSphereHitNV", + "OpRayQueryIsLSSHitNV")); } // namespace } // namespace val } // namespace spvtools diff --git a/test/val/val_ray_tracing_reorder_test.cpp b/test/val/val_ray_tracing_reorder_test.cpp index 91c39152..a41af80c 100644 --- a/test/val/val_ray_tracing_reorder_test.cpp +++ b/test/val/val_ray_tracing_reorder_test.cpp @@ -627,6 +627,170 @@ TEST_F(ValidateRayTracingReorderNV, ClusterASNV) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2)); } +TEST_F(ValidateRayTracingReorderNV, LSSGetSpherePositionNV) { + const std::string cap = R"( + OpCapability RayTracingSpheresGeometryNV + )"; + + const std::string ext = R"( + OpExtension "SPV_NV_linear_swept_spheres" + )"; + + const std::string declarations = R"( + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %_ptr_Function_v3float = OpTypePointer Function %v3float + )"; + + const std::string body = R"( + %pos = OpVariable %_ptr_Function_v3float Function + %result = OpHitObjectGetSpherePositionNV %v3float %hObj + OpStore %pos %result + )"; + + CompileSuccessfully( + GenerateReorderThreadCode(body, declarations, ext, cap).c_str(), + SPV_ENV_VULKAN_1_2); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2)); +} + +TEST_F(ValidateRayTracingReorderNV, LSSGetLSSPositionsNV) { + const std::string cap = R"( + OpCapability RayTracingSpheresGeometryNV + OpCapability RayTracingLinearSweptSpheresGeometryNV + )"; + + const std::string ext = R"( + OpExtension "SPV_NV_linear_swept_spheres" + )"; + + const std::string declarations = R"( + %float = OpTypeFloat 32 + %uint = OpTypeInt 32 0 + %v3float = OpTypeVector %float 3 + %uint_2 = OpConstant %uint 2 + %arr = OpTypeArray %v3float %uint_2 + %_ptr_Function_v3float = OpTypePointer Function %arr + )"; + + const std::string body = R"( + %lsspos = OpVariable %_ptr_Function_v3float Function + %result = OpHitObjectGetLSSPositionsNV %arr %hObj + OpStore %lsspos %result + )"; + + CompileSuccessfully( + GenerateReorderThreadCode(body, declarations, ext, cap).c_str(), + SPV_ENV_VULKAN_1_2); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2)); +} + +TEST_F(ValidateRayTracingReorderNV, LSSGetSphereRadiusNV) { + const std::string cap = R"( + OpCapability RayTracingSpheresGeometryNV + )"; + + const std::string ext = R"( + OpExtension "SPV_NV_linear_swept_spheres" + )"; + + const std::string declarations = R"( + %float = OpTypeFloat 32 + %_ptr_Function_float = OpTypePointer Function %float + )"; + + const std::string body = R"( + %rad = OpVariable %_ptr_Function_float Function + %result = OpHitObjectGetSphereRadiusNV %float %hObj + OpStore %rad %result + )"; + + CompileSuccessfully( + GenerateReorderThreadCode(body, declarations, ext, cap).c_str(), + SPV_ENV_VULKAN_1_2); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2)); +} + +TEST_F(ValidateRayTracingReorderNV, LSSGetLSSRadiiNV) { + const std::string cap = R"( + OpCapability RayTracingLinearSweptSpheresGeometryNV + )"; + + const std::string ext = R"( + OpExtension "SPV_NV_linear_swept_spheres" + )"; + + const std::string declarations = R"( + %float = OpTypeFloat 32 + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 + %arr = OpTypeArray %float %uint_2 + %_ptr_Function_float = OpTypePointer Function %arr + )"; + + const std::string body = R"( + %rad = OpVariable %_ptr_Function_float Function + %result = OpHitObjectGetLSSRadiiNV %arr %hObj + OpStore %rad %result + )"; + + CompileSuccessfully( + GenerateReorderThreadCode(body, declarations, ext, cap).c_str(), + SPV_ENV_VULKAN_1_2); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2)); +} + +TEST_F(ValidateRayTracingReorderNV, LSSIsSphereHitNV) { + const std::string cap = R"( + OpCapability RayTracingSpheresGeometryNV + )"; + + const std::string ext = R"( + OpExtension "SPV_NV_linear_swept_spheres" + )"; + + const std::string declarations = R"( + %bool = OpTypeBool + %_ptr_Function_bool = OpTypePointer Function %bool + )"; + + const std::string body = R"( + %ishit = OpVariable %_ptr_Function_bool Function + %result = OpHitObjectIsSphereHitNV %bool %hObj + OpStore %ishit %result + )"; + + CompileSuccessfully( + GenerateReorderThreadCode(body, declarations, ext, cap).c_str(), + SPV_ENV_VULKAN_1_2); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2)); +} + +TEST_F(ValidateRayTracingReorderNV, LSSIsLSSHitNV) { + const std::string cap = R"( + OpCapability RayTracingLinearSweptSpheresGeometryNV + )"; + + const std::string ext = R"( + OpExtension "SPV_NV_linear_swept_spheres" + )"; + + const std::string declarations = R"( + %bool = OpTypeBool + %_ptr_Function_bool = OpTypePointer Function %bool + )"; + + const std::string body = R"( + %ishit = OpVariable %_ptr_Function_bool Function + %result = OpHitObjectIsLSSHitNV %bool %hObj + OpStore %ishit %result + )"; + + CompileSuccessfully( + GenerateReorderThreadCode(body, declarations, ext, cap).c_str(), + SPV_ENV_VULKAN_1_2); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2)); +} } // namespace } // namespace val } // namespace spvtools -- 2.11.4.GIT