diff options
author | 最萌小汐 <sumneko@hotmail.com> | 2022-11-08 02:15:03 +0800 |
---|---|---|
committer | 最萌小汐 <sumneko@hotmail.com> | 2022-11-08 02:15:03 +0800 |
commit | 16332043f9f0a5db9dfc8ee4ae41e40da3227d3d (patch) | |
tree | 7cd3f66a974a24914cc60e78fa8f1a16215785fd /script | |
parent | 00dd1ed171beda2e892a2460d6e7eff321c775e8 (diff) | |
download | lua-language-server-16332043f9f0a5db9dfc8ee4ae41e40da3227d3d.zip |
`---@enum` supports runtime values
resolve #1619
Diffstat (limited to 'script')
-rw-r--r-- | script/core/diagnostics/assign-type-mismatch.lua | 6 | ||||
-rw-r--r-- | script/core/diagnostics/cast-local-type.lua | 6 | ||||
-rw-r--r-- | script/core/diagnostics/cast-type-mismatch.lua | 5 | ||||
-rw-r--r-- | script/core/diagnostics/param-type-mismatch.lua | 5 | ||||
-rw-r--r-- | script/vm/infer.lua | 4 | ||||
-rw-r--r-- | script/vm/type.lua | 289 |
6 files changed, 237 insertions, 78 deletions
diff --git a/script/core/diagnostics/assign-type-mismatch.lua b/script/core/diagnostics/assign-type-mismatch.lua index abe78ccd..2bbceebb 100644 --- a/script/core/diagnostics/assign-type-mismatch.lua +++ b/script/core/diagnostics/assign-type-mismatch.lua @@ -96,7 +96,8 @@ return function (uri, callback) end local varNode = vm.compileNode(source) - if vm.canCastType(uri, varNode, valueNode) then + local suc, errs = vm.canCastType(uri, varNode, valueNode) + if suc then return end @@ -107,12 +108,15 @@ return function (uri, callback) end end + assert(errs) + callback { start = source.start, finish = source.finish, message = lang.script('DIAG_ASSIGN_TYPE_MISMATCH', { def = vm.getInfer(varNode):view(uri), ref = vm.getInfer(valueNode):view(uri), + err = vm.viewTypeErrorMessage(uri, errs), }), } end) diff --git a/script/core/diagnostics/cast-local-type.lua b/script/core/diagnostics/cast-local-type.lua index c3d6e1bb..42271e91 100644 --- a/script/core/diagnostics/cast-local-type.lua +++ b/script/core/diagnostics/cast-local-type.lua @@ -34,13 +34,17 @@ return function (uri, callback) refNode = refNode:copy():setTruthy() end - if not vm.canCastType(uri, locNode, refNode) then + local suc, errs = vm.canCastType(uri, locNode, refNode) + + if not suc then + assert(errs) callback { start = ref.start, finish = ref.finish, message = lang.script('DIAG_CAST_LOCAL_TYPE', { def = vm.getInfer(locNode):view(uri), ref = vm.getInfer(refNode):view(uri), + err = vm.viewTypeErrorMessage(uri, errs), }), } end diff --git a/script/core/diagnostics/cast-type-mismatch.lua b/script/core/diagnostics/cast-type-mismatch.lua index a48e6cca..34b12559 100644 --- a/script/core/diagnostics/cast-type-mismatch.lua +++ b/script/core/diagnostics/cast-type-mismatch.lua @@ -26,13 +26,16 @@ return function (uri, callback) for _, cast in ipairs(doc.casts) do if not cast.mode and cast.extends then local refNode = vm.compileNode(cast.extends) - if not vm.canCastType(uri, defNode, refNode) then + local suc, errs = vm.canCastType(uri, defNode, refNode) + if not suc then + assert(errs) callback { start = cast.extends.start, finish = cast.extends.finish, message = lang.script('DIAG_CAST_TYPE_MISMATCH', { def = vm.getInfer(defNode):view(uri), ref = vm.getInfer(refNode):view(uri), + err = vm.viewTypeErrorMessage(uri, errs), }) } end diff --git a/script/core/diagnostics/param-type-mismatch.lua b/script/core/diagnostics/param-type-mismatch.lua index a2925e3e..9b2fbc6a 100644 --- a/script/core/diagnostics/param-type-mismatch.lua +++ b/script/core/diagnostics/param-type-mismatch.lua @@ -100,14 +100,17 @@ return function (uri, callback) -- 因此将假值移除再进行检查 refNode = refNode:copy():setTruthy() end - if not vm.canCastType(uri, defNode, refNode) then + local suc, errs = vm.canCastType(uri, defNode, refNode) + if not suc then local rawDefNode = getRawDefNode(funcNode, i) + assert(errs) callback { start = arg.start, finish = arg.finish, message = lang.script('DIAG_PARAM_TYPE_MISMATCH', { def = vm.getInfer(rawDefNode):view(uri), ref = vm.getInfer(refNode):view(uri), + err = vm.viewTypeErrorMessage(uri, errs), }) } end diff --git a/script/vm/infer.lua b/script/vm/infer.lua index 1e531c1e..1c83faaa 100644 --- a/script/vm/infer.lua +++ b/script/vm/infer.lua @@ -234,6 +234,10 @@ local viewNodeSwitch;viewNodeSwitch = util.switch() end return ('fun(%s)%s'):format(argView, regView) end) + : case 'doc.field.name' + : call(function (source, infer, uri) + return vm.viewKey(source, uri) + end) ---@class vm.node ---@field lastInfer? vm.infer diff --git a/script/vm/type.lua b/script/vm/type.lua index feaaca9e..78b6d248 100644 --- a/script/vm/type.lua +++ b/script/vm/type.lua @@ -3,6 +3,9 @@ local vm = require 'vm.vm' local guide = require 'parser.guide' local config = require 'config.config' local util = require 'utility' +local lang = require 'language' + +---@alias typecheck.err vm.node.object|string|vm.node ---@param object vm.node.object ---@return string? @@ -47,61 +50,93 @@ end ---@param parentName string ---@param child vm.node.object ---@param uri uri +---@param err typecheck.err[] ---@return boolean? -local function checkEnum(parentName, child, uri) +local function checkEnum(parentName, child, uri, err) local parentClass = vm.getGlobal('type', parentName) if not parentClass then return nil end - local hasEnum + local enums + for _, set in ipairs(parentClass:getSets(uri)) do + if set.type == 'doc.enum' then + enums = vm.getEnums(set) + break + end + end + if not enums then + return nil + end if child.type == 'global' then ---@cast child vm.global - for _, set in ipairs(parentClass:getSets(uri)) do - if set.type == 'doc.enum' then - hasEnum = true - local enums = vm.getEnums(set) - if enums then - for _, enum in ipairs(enums) do - if vm.isSubType(uri, child, vm.compileNode(enum)) then - return true - end - end - end + for _, enum in ipairs(enums) do + if vm.isSubType(uri, child, vm.compileNode(enum)) then + return true end end + err[#err+1] = 'TYPE_ERROR_ENUM_GLOBAL_DISMATCH' + err[#err+1] = child + err[#err+1] = parentClass + return false + elseif child.type == 'generic' then + ---@cast child vm.generic + err[#err+1] = 'TYPE_ERROR_ENUM_GENERIC_UNSUPPORTED' + err[#err+1] = child + return false else - ---@cast child -vm.global - for _, set in ipairs(parentClass:getSets(uri)) do - if set.type == 'doc.enum' then - hasEnum = true - local myLiteral = vm.getInfer(child):viewLiterals() - local enums = vm.getEnums(set) - if enums then - for _, enum in ipairs(enums) do - if myLiteral == vm.getInfer(enum):viewLiterals() then - return true - end + ---@cast child parser.object + local childName = getNodeName(child) + if childName == 'number' + or childName == 'integer' + or childName == 'boolean' + or childName == 'string' then + for _, enum in ipairs(enums) do + for nd in vm.compileNode(enum):eachObject() do + if childName == getNodeName(nd) and nd[1] == child[1] then + return true end end end + err[#err+1] = 'TYPE_ERROR_ENUM_LITERAL_DISMATCH' + err[#err+1] = child[1] + err[#err+1] = parentClass + return false + elseif childName == 'function' + or childName == 'table' then + for _, enum in ipairs(enums) do + for nd in vm.compileNode(enum):eachObject() do + if child == nd then + return true + end + end + end + err[#err+1] = 'TYPE_ERROR_ENUM_OBJECT_DISMATCH' + err[#err+1] = child + err[#err+1] = parentClass + return false end - end - - if hasEnum then + err[#err+1] = 'TYPE_ERROR_ENUM_NO_OBJECT' + err[#err+1] = child return false end - return nil end ---@param parent vm.node.object ---@param child vm.node.object +---@param mark table +---@param err typecheck.err[] ---@return boolean -local function checkValue(parent, child) +local function checkValue(parent, child, mark, err) if parent.type == 'doc.type.integer' then if child.type == 'integer' or child.type == 'doc.type.integer' or child.type == 'number' then - return parent[1] == child[1] + if parent[1] ~= child[1] then + err[#err+1] = 'TYPE_ERROR_INTEGER_DISMATCH' + err[#err+1] = child + err[#err+1] = parent + return false + end end return true end @@ -111,7 +146,12 @@ local function checkValue(parent, child) if child.type == 'string' or child.type == 'doc.type.string' or child.type == 'doc.field.name' then - return parent[1] == child[1] + if parent[1] ~= child[1] then + err[#err+1] = 'TYPE_ERROR_STRING_DISMATCH' + err[#err+1] = child + err[#err+1] = parent + return false + end end return true end @@ -119,7 +159,12 @@ local function checkValue(parent, child) if parent.type == 'doc.type.boolean' then if child.type == 'boolean' or child.type == 'doc.type.boolean' then - return parent[1] == child[1] + if parent[1] ~= child[1] then + err[#err+1] = 'TYPE_ERROR_BOOLEAN_DISMATCH' + err[#err+1] = child + err[#err+1] = parent + return false + end end return true end @@ -132,12 +177,18 @@ local function checkValue(parent, child) local tnode = vm.compileNode(child) for _, pfield in ipairs(parent.fields) do local knode = vm.compileNode(pfield.name) - local pvalues = vm.compileNode(pfield.extends) local cvalues = vm.getTableValue(uri, tnode, knode, true) if not cvalues then + err[#err+1] = 'TYPE_ERROR_TABLE_NO_FIELD' + err[#err+1] = pfield.name return false end - if vm.isSubType(uri, cvalues, pvalues) == false then + local pvalues = vm.compileNode(pfield.extends) + if vm.isSubType(uri, cvalues, pvalues, mark, err) == false then + err[#err+1] = 'TYPE_ERROR_TABLE_FIELD_DISMATCH' + err[#err+1] = pfield.name + err[#err+1] = cvalues + err[#err+1] = pvalues return false end end @@ -168,14 +219,17 @@ end ---@param child vm.node|string|vm.node.object ---@param parent vm.node|string|vm.node.object ---@param mark? table ----@return boolean? -function vm.isSubType(uri, child, parent, mark) +---@param err? typecheck.err[] +---@return boolean|nil +---@return typecheck.err[] # errors +function vm.isSubType(uri, child, parent, mark, err) mark = mark or {} + err = err or {} if type(child) == 'string' then local global = vm.getGlobal('type', child) if not global then - return false + return nil, err end child = global elseif child.type == 'vm.node' then @@ -184,28 +238,39 @@ function vm.isSubType(uri, child, parent, mark) for n in child:eachObject() do if getNodeName(n) then hasKnownType = true - if vm.isSubType(uri, n, parent, mark) == true then - return true + if vm.isSubType(uri, n, parent, mark, err) == true then + return true, err end end end - return not hasKnownType + if hasKnownType then + err[#err+1] = 'TYPE_ERROR_UNION_ALL_DISMATCH' + err[#err+1] = child + err[#err+1] = parent + return false, err + end + return true, err else local weakNil = config.get(uri, 'Lua.type.weakNilCheck') for n in child:eachObject() do local nodeName = getNodeName(n) if nodeName and not (nodeName == 'nil' and weakNil) - and vm.isSubType(uri, n, parent, mark) == false then - return false + and vm.isSubType(uri, n, parent, mark, err) == false then + err[#err+1] = 'TYPE_ERROR_UNION_DISMATCH' + err[#err+1] = n + err[#err+1] = parent + return false, err end end if not weakNil and child:isOptional() then - if vm.isSubType(uri, 'nil', parent, mark) == false then - return false + if vm.isSubType(uri, 'nil', parent, mark, err) == false then + err[#err+1] = 'TYPE_ERROR_OPTIONAL_DISMATCH' + err[#err+1] = parent + return false, err end end - return true + return true, err end end @@ -213,18 +278,18 @@ function vm.isSubType(uri, child, parent, mark) local childName = getNodeName(child) if childName == 'any' or childName == 'unknown' then - return true + return true, err end if not childName or isAlias(childName, uri) then - return nil + return nil, err end if type(parent) == 'string' then local global = vm.getGlobal('type', parent) if not global then - return false + return false, err end parent = global elseif parent.type == 'vm.node' then @@ -232,17 +297,17 @@ function vm.isSubType(uri, child, parent, mark) for n in parent:eachObject() do if getNodeName(n) then hasKnownType = true - if vm.isSubType(uri, child, n, mark) == true then - return true + if vm.isSubType(uri, child, n, mark, err) == true then + return true, err end end end if parent:isOptional() then - if vm.isSubType(uri, child, 'nil', mark) == true then - return true + if vm.isSubType(uri, child, 'nil', mark, err) == true then + return true, err end end - return not hasKnownType + return not hasKnownType, err end ---@cast parent vm.node.object @@ -250,51 +315,54 @@ function vm.isSubType(uri, child, parent, mark) local parentName = getNodeName(parent) if parentName == 'any' or parentName == 'unknown' then - return true + return true, err end if not parentName or isAlias(parentName, uri) then - return nil + return nil, err end if childName == parentName then - if not checkValue(parent, child) then - return false + if not checkValue(parent, child, mark, err) then + return false, err end - return true + return true, err end if parentName == 'number' and childName == 'integer' then - return true + return true, err end if parentName == 'integer' and childName == 'number' then if config.get(uri, 'Lua.type.castNumberToInteger') then - return true + return true, err end if child.type == 'number' and child[1] and not math.tointeger(child[1]) then - return false + err[#err+1] = 'TYPE_ERROR_NUMBER_LITERAL_TO_INTEGER' + err[#err+1] = child + return false, err end if child.type == 'global' and child.cate == 'type' then - return false + err[#err+1] = 'TYPE_ERROR_NUMBER_TYPE_TO_INTEGER' + return false, err end - return true + return true, err end - local isEnum = checkEnum(parentName, child, uri) + local isEnum = checkEnum(parentName, child, uri, err) if isEnum ~= nil then - return isEnum + return isEnum, err end if parentName == 'table' and not guide.isBasicType(childName) then - return true + return true, err end if childName == 'table' and not guide.isBasicType(parentName) then - return true + return true, err end -- check class parent @@ -308,8 +376,8 @@ function vm.isSubType(uri, child, parent, mark) for _, ext in ipairs(set.extends) do if ext.type == 'doc.extends.name' and (not isBasicType or guide.isBasicType(ext[1])) - and vm.isSubType(uri, ext[1], parent, mark) == true then - return true + and vm.isSubType(uri, ext[1], parent, mark, err) == true then + return true, err end end end @@ -326,11 +394,14 @@ function vm.isSubType(uri, child, parent, mark) ]] if guide.isBasicType(childName) and guide.isLiteral(child) - and vm.isSubType(uri, parentName, childName) then - return true + and vm.isSubType(uri, parentName, childName, mark, err) then + return true, err end - return false + err[#err+1] = 'TYPE_ERROR_DISMATCH' + err[#err+1] = child + err[#err+1] = parent + return false, err end ---@param node string|vm.node|vm.object @@ -474,6 +545,7 @@ end ---@param defNode vm.node ---@param refNode vm.node ---@return boolean +---@return typecheck.err[]? function vm.canCastType(uri, defNode, refNode) local defInfer = vm.getInfer(defNode) local refInfer = vm.getInfer(refNode) @@ -512,9 +584,78 @@ function vm.canCastType(uri, defNode, refNode) end end - if vm.isSubType(uri, refNode, defNode) then + local suc, err = vm.isSubType(uri, refNode, defNode) + + if suc then return true end - return false + return false, err +end + +local ErrorMessageMap = { + TYPE_ERROR_ENUM_GLOBAL_DISMATCH = {'child', 'parent'}, + TYPE_ERROR_ENUM_GENERIC_UNSUPPORTED = {'child'}, + TYPE_ERROR_ENUM_LITERAL_DISMATCH = {'child', 'parent'}, + TYPE_ERROR_ENUM_OBJECT_DISMATCH = {'child', 'parent'}, + TYPE_ERROR_ENUM_NO_OBJECT = {'child'}, + TYPE_ERROR_INTEGER_DISMATCH = {'child', 'parent'}, + TYPE_ERROR_STRING_DISMATCH = {'child', 'parent'}, + TYPE_ERROR_BOOLEAN_DISMATCH = {'child', 'parent'}, + TYPE_ERROR_TABLE_NO_FIELD = {'key'}, + TYPE_ERROR_TABLE_FIELD_DISMATCH = {'key', 'child', 'parent'}, + TYPE_ERROR_UNION_ALL_DISMATCH = {'child', 'parent'}, + TYPE_ERROR_UNION_DISMATCH = {'child', 'parent'}, + TYPE_ERROR_OPTIONAL_DISMATCH = {'parent'}, + TYPE_ERROR_NUMBER_LITERAL_TO_INTEGER = {'child'}, + TYPE_ERROR_NUMBER_TYPE_TO_INTEGER = {}, + TYPE_ERROR_DISMATCH = {'child', 'parent'}, +} + +---@param uri uri +---@param errs typecheck.err[] +---@return string +function vm.viewTypeErrorMessage(uri, errs) + local lines = {} + local index = 1 + while true do + local name = errs[index] + if not name then + break + end + index = index + 1 + local params = ErrorMessageMap[name] + local lparams = {} + for _, paramName in ipairs(params) do + local value = errs[index] + if type(value) == 'string' + or type(value) == 'number' + or type(value) == 'boolean' then + lparams[paramName] = util.viewLiteral(value) + elseif value.type == 'global' then + lparams[paramName] = value.name + elseif value.type == 'vm.node' then + ---@cast value vm.node + lparams[paramName] = vm.getInfer(value):view(uri) + elseif value.type == 'table' then + lparams[paramName] = 'table' + elseif value.type == 'generic' then + ---@cast value vm.generic + lparams[paramName] = vm.viewObject(value, uri) + else + ---@cast value -string, -vm.global, -vm.node, -vm.generic + if paramName == 'key' then + lparams[paramName] = vm.viewKey(value, uri) + else + lparams[paramName] = vm.viewObject(value, uri) + or vm.getInfer(value):view(uri) + end + end + index = index + 1 + end + local line = lang.script(name, lparams) + lines[#lines+1] = '- ' .. line + end + util.revertTable(lines) + return table.concat(lines, '\n') end |