diff options
-rw-r--r-- | changelog.md | 2 | ||||
-rw-r--r-- | locale/en-us/setting.lua | 4 | ||||
-rw-r--r-- | locale/pt-br/setting.lua | 4 | ||||
-rw-r--r-- | locale/zh-cn/setting.lua | 4 | ||||
-rw-r--r-- | locale/zh-tw/setting.lua | 4 | ||||
-rw-r--r-- | script/config/template.lua | 8 | ||||
-rw-r--r-- | script/vm/type.lua | 130 | ||||
-rw-r--r-- | test/diagnostics/cast-local-type.lua | 5 | ||||
-rw-r--r-- | test/diagnostics/param-type-mismatch.lua | 7 |
9 files changed, 102 insertions, 66 deletions
diff --git a/changelog.md b/changelog.md index 6f872d20..0f1c2d10 100644 --- a/changelog.md +++ b/changelog.md @@ -2,7 +2,7 @@ ## Unreleased <!-- Add all new changes here. They will be moved under a version at release --> -* `NEW` Add matching checks between the shape of tables and classes, during type checking. [#2768](https://github.com/LuaLS/lua-language-server/pull/2768) +* `NEW` Setting: `Lua.type.checkTableShape`: Add matching checks between the shape of tables and classes, during type checking. [#2768](https://github.com/LuaLS/lua-language-server/pull/2768) * `FIX` Error `attempt to index a nil value` when `Lua.hint.semicolon == 'All'` [#2788](https://github.com/LuaLS/lua-language-server/issues/2788) * `FIX` Incorrect LuaCats parsing for `"'"` * `FIX` Incorrect indent fixings diff --git a/locale/en-us/setting.lua b/locale/en-us/setting.lua index eae37eae..645053ed 100644 --- a/locale/en-us/setting.lua +++ b/locale/en-us/setting.lua @@ -301,6 +301,10 @@ When a parameter type is not annotated, it is inferred from the function's call When this setting is `false`, the type of the parameter is `any` when it is not annotated. ]] +config.type.checkTableShape = +[[ +Strictly check the shape of the table. +]] config.doc.privateName = 'Treat specific field names as private, e.g. `m_*` means `XXX.m_id` and `XXX.m_type` are private, witch can only be accessed in the class where the definition is located.' config.doc.protectedName = diff --git a/locale/pt-br/setting.lua b/locale/pt-br/setting.lua index 3f7ae657..504fcf54 100644 --- a/locale/pt-br/setting.lua +++ b/locale/pt-br/setting.lua @@ -301,6 +301,10 @@ When the parameter type is not annotated, the parameter type is inferred from th When this setting is `false`, the type of the parameter is `any` when it is not annotated. ]] +config.type.checkTableShape = -- TODO: need translate! +[[ +对表的形状进行严格检查。 +]] config.doc.privateName = -- TODO: need translate! 'Treat specific field names as private, e.g. `m_*` means `XXX.m_id` and `XXX.m_type` are private, witch can only be accessed in the class where the definition is located.' config.doc.protectedName = -- TODO: need translate! diff --git a/locale/zh-cn/setting.lua b/locale/zh-cn/setting.lua index 9c6e9a25..0f65d857 100644 --- a/locale/zh-cn/setting.lua +++ b/locale/zh-cn/setting.lua @@ -300,6 +300,10 @@ config.type.inferParamType = 如果设置为 "false",则在未注释时,参数类型为 "any"。 ]] +config.type.checkTableShape = +[[ +对表的形状进行严格检查。 +]] config.doc.privateName = '将特定名称的字段视为私有,例如 `m_*` 意味着 `XXX.m_id` 与 `XXX.m_type` 是私有字段,只能在定义所在的类中访问。' config.doc.protectedName = diff --git a/locale/zh-tw/setting.lua b/locale/zh-tw/setting.lua index f15e2b4f..1442b6bc 100644 --- a/locale/zh-tw/setting.lua +++ b/locale/zh-tw/setting.lua @@ -300,6 +300,10 @@ config.type.inferParamType = -- TODO: need translate! 如果设置为 "false",则在未注释时,参数类型为 "any"。 ]] +config.type.checkTableShape = -- TODO: need translate! +[[ +对表的形状进行严格检查。 +]] config.doc.privateName = -- TODO: need translate! 'Treat specific field names as private, e.g. `m_*` means `XXX.m_id` and `XXX.m_type` are private, witch can only be accessed in the class where the definition is located.' config.doc.protectedName = -- TODO: need translate! diff --git a/script/config/template.lua b/script/config/template.lua index 7b044d7a..ee7dde37 100644 --- a/script/config/template.lua +++ b/script/config/template.lua @@ -4,9 +4,9 @@ local diag = require 'proto.diagnostic' ---@class config.unit ---@field caller function +---@field loader function ---@field _checker fun(self: config.unit, value: any): boolean ---@field name string ----@field [string] config.unit ---@operator shl: config.unit ---@operator shr: config.unit ---@operator call: config.unit @@ -57,7 +57,8 @@ local function register(name, default, checker, loader, caller) } end ----@type config.unit +---@class config.master +---@field [string] config.unit local Type = setmetatable({}, { __index = function (_, name) local unit = {} for k, v in pairs(units[name]) do @@ -398,7 +399,8 @@ local template = { ['Lua.type.castNumberToInteger'] = Type.Boolean >> true, ['Lua.type.weakUnionCheck'] = Type.Boolean >> false, ['Lua.type.weakNilCheck'] = Type.Boolean >> false, - ['Lua.type.inferParamType'] = Type.Boolean >> false, + ['Lua.type.inferParamType'] = Type.Boolean >> false, + ['Lua.type.checkTableShape'] = Type.Boolean >> false, ['Lua.doc.privateName'] = Type.Array(Type.String), ['Lua.doc.protectedName'] = Type.Array(Type.String), ['Lua.doc.packageName'] = Type.Array(Type.String), diff --git a/script/vm/type.lua b/script/vm/type.lua index 4835065a..3bc51cd8 100644 --- a/script/vm/type.lua +++ b/script/vm/type.lua @@ -284,6 +284,71 @@ local function isAlias(name, suri) return false end +local function checkTableShape(parent, child, uri, mark, errs) + local set = parent:getSets(uri) + local missedKeys = {} + local failedCheck + local myKeys + for _, def in ipairs(set) do + if not def.fields or #def.fields == 0 then + goto continue + end + if not myKeys then + myKeys = {} + for _, field in ipairs(child) do + local key = vm.getKeyName(field) or field.tindex + if key then + myKeys[key] = vm.compileNode(field) + end + end + end + + for _, field in ipairs(def.fields) do + local key = vm.getKeyName(field) + if not key then + local fieldnode = vm.compileNode(field.field)[1] + if fieldnode and fieldnode.type == 'doc.type.integer' then + ---@cast fieldnode parser.object + key = vm.getKeyName(fieldnode) + end + end + if not key then + goto continue + end + + local ok + local nodeField = vm.compileNode(field) + if myKeys[key] then + ok = vm.isSubType(uri, myKeys[key], nodeField, mark, errs) + if ok == false then + errs[#errs+1] = 'TYPE_ERROR_PARENT_ALL_DISMATCH' -- error display can be greatly improved + errs[#errs+1] = myKeys[key] + errs[#errs+1] = nodeField + failedCheck = true + end + elseif not nodeField:isNullable() then + if type(key) == "number" then + missedKeys[#missedKeys+1] = ('`[%s]`'):format(key) + else + missedKeys[#missedKeys+1] = ('`%s`'):format(key) + end + failedCheck = true + end + end + ::continue:: + end + if #missedKeys > 0 then + errs[#errs+1] = 'DIAG_MISSING_FIELDS' + errs[#errs+1] = parent + errs[#errs+1] = table.concat(missedKeys, ', ') + end + if failedCheck then + return false + end + + return true +end + ---@param uri uri ---@param child vm.node|string|vm.node.object ---@param parent vm.node|string|vm.node.object @@ -483,68 +548,11 @@ function vm.isSubType(uri, child, parent, mark, errs) return true end if childName == 'table' and not guide.isBasicType(parentName) then - local set = parent:getSets(uri) - local missedKeys = {} - local failedCheck - local myKeys - for _, def in ipairs(set) do - if not def.fields or #def.fields == 0 then - goto continue - end - if not myKeys then - myKeys = {} - for _, field in ipairs(child) do - local key = vm.getKeyName(field) or field.tindex - if key then - myKeys[key] = vm.compileNode(field) - end - end - end - - for _, field in ipairs(def.fields) do - local key = vm.getKeyName(field) - if not key then - local fieldnode = vm.compileNode(field.field)[1] - if fieldnode and fieldnode.type == 'doc.type.integer' then - ---@cast fieldnode parser.object - key = vm.getKeyName(fieldnode) - end - end - if not key then - goto continue - end - - local ok - local nodeField = vm.compileNode(field) - if myKeys[key] then - ok = vm.isSubType(uri, myKeys[key], nodeField, mark, errs) - if ok == false then - errs[#errs+1] = 'TYPE_ERROR_PARENT_ALL_DISMATCH' -- error display can be greatly improved - errs[#errs+1] = myKeys[key] - errs[#errs+1] = nodeField - failedCheck = true - end - elseif not nodeField:isNullable() then - if type(key) == "number" then - missedKeys[#missedKeys+1] = ('`[%s]`'):format(key) - else - missedKeys[#missedKeys+1] = ('`%s`'):format(key) - end - failedCheck = true - end - end - ::continue:: - end - if #missedKeys > 0 then - errs[#errs+1] = 'DIAG_MISSING_FIELDS' - errs[#errs+1] = parent - errs[#errs+1] = table.concat(missedKeys, ', ') - end - if failedCheck then - return false + if config.get(uri, 'Lua.type.checkTableShape') then + return checkTableShape(parent, child, uri, mark, errs) + else + return true end - - return true end -- check class parent diff --git a/test/diagnostics/cast-local-type.lua b/test/diagnostics/cast-local-type.lua index 93452a92..d702adf3 100644 --- a/test/diagnostics/cast-local-type.lua +++ b/test/diagnostics/cast-local-type.lua @@ -1,3 +1,4 @@ +local config = require "config.config" TEST [[ local x = 0 @@ -345,6 +346,8 @@ local v v = a ]] +config.set(nil, 'Lua.type.checkTableShape', true) + TEST [[ ---@class A ---@field x string @@ -398,3 +401,5 @@ local a = {x = "b", y = {}} local v <!v!> = a ]] + +config.set(nil, 'Lua.type.checkTableShape', false) diff --git a/test/diagnostics/param-type-mismatch.lua b/test/diagnostics/param-type-mismatch.lua index bb602cab..b75e9307 100644 --- a/test/diagnostics/param-type-mismatch.lua +++ b/test/diagnostics/param-type-mismatch.lua @@ -1,3 +1,4 @@ +local config = require "config.config" TEST [[ ---@param x number local function f(x) end @@ -278,6 +279,8 @@ function f(a) end f(a) ]] +config.set(nil, 'Lua.type.checkTableShape', true) + TEST [[ ---@class A ---@field x string @@ -347,4 +350,6 @@ local a = {} function f(a) end f(a) -]]
\ No newline at end of file +]] + +config.set(nil, 'Lua.type.checkTableShape', false) |