summaryrefslogtreecommitdiff
path: root/script/vm
diff options
context:
space:
mode:
Diffstat (limited to 'script/vm')
-rw-r--r--script/vm/infer.lua4
-rw-r--r--script/vm/type.lua289
2 files changed, 219 insertions, 74 deletions
diff --git a/script/vm/infer.lua b/script/vm/infer.lua
index 1e531c1e..1c83faaa 100644
--- a/script/vm/infer.lua
+++ b/script/vm/infer.lua
@@ -234,6 +234,10 @@ local viewNodeSwitch;viewNodeSwitch = util.switch()
end
return ('fun(%s)%s'):format(argView, regView)
end)
+ : case 'doc.field.name'
+ : call(function (source, infer, uri)
+ return vm.viewKey(source, uri)
+ end)
---@class vm.node
---@field lastInfer? vm.infer
diff --git a/script/vm/type.lua b/script/vm/type.lua
index feaaca9e..78b6d248 100644
--- a/script/vm/type.lua
+++ b/script/vm/type.lua
@@ -3,6 +3,9 @@ local vm = require 'vm.vm'
local guide = require 'parser.guide'
local config = require 'config.config'
local util = require 'utility'
+local lang = require 'language'
+
+---@alias typecheck.err vm.node.object|string|vm.node
---@param object vm.node.object
---@return string?
@@ -47,61 +50,93 @@ end
---@param parentName string
---@param child vm.node.object
---@param uri uri
+---@param err typecheck.err[]
---@return boolean?
-local function checkEnum(parentName, child, uri)
+local function checkEnum(parentName, child, uri, err)
local parentClass = vm.getGlobal('type', parentName)
if not parentClass then
return nil
end
- local hasEnum
+ local enums
+ for _, set in ipairs(parentClass:getSets(uri)) do
+ if set.type == 'doc.enum' then
+ enums = vm.getEnums(set)
+ break
+ end
+ end
+ if not enums then
+ return nil
+ end
if child.type == 'global' then
---@cast child vm.global
- for _, set in ipairs(parentClass:getSets(uri)) do
- if set.type == 'doc.enum' then
- hasEnum = true
- local enums = vm.getEnums(set)
- if enums then
- for _, enum in ipairs(enums) do
- if vm.isSubType(uri, child, vm.compileNode(enum)) then
- return true
- end
- end
- end
+ for _, enum in ipairs(enums) do
+ if vm.isSubType(uri, child, vm.compileNode(enum)) then
+ return true
end
end
+ err[#err+1] = 'TYPE_ERROR_ENUM_GLOBAL_DISMATCH'
+ err[#err+1] = child
+ err[#err+1] = parentClass
+ return false
+ elseif child.type == 'generic' then
+ ---@cast child vm.generic
+ err[#err+1] = 'TYPE_ERROR_ENUM_GENERIC_UNSUPPORTED'
+ err[#err+1] = child
+ return false
else
- ---@cast child -vm.global
- for _, set in ipairs(parentClass:getSets(uri)) do
- if set.type == 'doc.enum' then
- hasEnum = true
- local myLiteral = vm.getInfer(child):viewLiterals()
- local enums = vm.getEnums(set)
- if enums then
- for _, enum in ipairs(enums) do
- if myLiteral == vm.getInfer(enum):viewLiterals() then
- return true
- end
+ ---@cast child parser.object
+ local childName = getNodeName(child)
+ if childName == 'number'
+ or childName == 'integer'
+ or childName == 'boolean'
+ or childName == 'string' then
+ for _, enum in ipairs(enums) do
+ for nd in vm.compileNode(enum):eachObject() do
+ if childName == getNodeName(nd) and nd[1] == child[1] then
+ return true
end
end
end
+ err[#err+1] = 'TYPE_ERROR_ENUM_LITERAL_DISMATCH'
+ err[#err+1] = child[1]
+ err[#err+1] = parentClass
+ return false
+ elseif childName == 'function'
+ or childName == 'table' then
+ for _, enum in ipairs(enums) do
+ for nd in vm.compileNode(enum):eachObject() do
+ if child == nd then
+ return true
+ end
+ end
+ end
+ err[#err+1] = 'TYPE_ERROR_ENUM_OBJECT_DISMATCH'
+ err[#err+1] = child
+ err[#err+1] = parentClass
+ return false
end
- end
-
- if hasEnum then
+ err[#err+1] = 'TYPE_ERROR_ENUM_NO_OBJECT'
+ err[#err+1] = child
return false
end
- return nil
end
---@param parent vm.node.object
---@param child vm.node.object
+---@param mark table
+---@param err typecheck.err[]
---@return boolean
-local function checkValue(parent, child)
+local function checkValue(parent, child, mark, err)
if parent.type == 'doc.type.integer' then
if child.type == 'integer'
or child.type == 'doc.type.integer'
or child.type == 'number' then
- return parent[1] == child[1]
+ if parent[1] ~= child[1] then
+ err[#err+1] = 'TYPE_ERROR_INTEGER_DISMATCH'
+ err[#err+1] = child
+ err[#err+1] = parent
+ return false
+ end
end
return true
end
@@ -111,7 +146,12 @@ local function checkValue(parent, child)
if child.type == 'string'
or child.type == 'doc.type.string'
or child.type == 'doc.field.name' then
- return parent[1] == child[1]
+ if parent[1] ~= child[1] then
+ err[#err+1] = 'TYPE_ERROR_STRING_DISMATCH'
+ err[#err+1] = child
+ err[#err+1] = parent
+ return false
+ end
end
return true
end
@@ -119,7 +159,12 @@ local function checkValue(parent, child)
if parent.type == 'doc.type.boolean' then
if child.type == 'boolean'
or child.type == 'doc.type.boolean' then
- return parent[1] == child[1]
+ if parent[1] ~= child[1] then
+ err[#err+1] = 'TYPE_ERROR_BOOLEAN_DISMATCH'
+ err[#err+1] = child
+ err[#err+1] = parent
+ return false
+ end
end
return true
end
@@ -132,12 +177,18 @@ local function checkValue(parent, child)
local tnode = vm.compileNode(child)
for _, pfield in ipairs(parent.fields) do
local knode = vm.compileNode(pfield.name)
- local pvalues = vm.compileNode(pfield.extends)
local cvalues = vm.getTableValue(uri, tnode, knode, true)
if not cvalues then
+ err[#err+1] = 'TYPE_ERROR_TABLE_NO_FIELD'
+ err[#err+1] = pfield.name
return false
end
- if vm.isSubType(uri, cvalues, pvalues) == false then
+ local pvalues = vm.compileNode(pfield.extends)
+ if vm.isSubType(uri, cvalues, pvalues, mark, err) == false then
+ err[#err+1] = 'TYPE_ERROR_TABLE_FIELD_DISMATCH'
+ err[#err+1] = pfield.name
+ err[#err+1] = cvalues
+ err[#err+1] = pvalues
return false
end
end
@@ -168,14 +219,17 @@ end
---@param child vm.node|string|vm.node.object
---@param parent vm.node|string|vm.node.object
---@param mark? table
----@return boolean?
-function vm.isSubType(uri, child, parent, mark)
+---@param err? typecheck.err[]
+---@return boolean|nil
+---@return typecheck.err[] # errors
+function vm.isSubType(uri, child, parent, mark, err)
mark = mark or {}
+ err = err or {}
if type(child) == 'string' then
local global = vm.getGlobal('type', child)
if not global then
- return false
+ return nil, err
end
child = global
elseif child.type == 'vm.node' then
@@ -184,28 +238,39 @@ function vm.isSubType(uri, child, parent, mark)
for n in child:eachObject() do
if getNodeName(n) then
hasKnownType = true
- if vm.isSubType(uri, n, parent, mark) == true then
- return true
+ if vm.isSubType(uri, n, parent, mark, err) == true then
+ return true, err
end
end
end
- return not hasKnownType
+ if hasKnownType then
+ err[#err+1] = 'TYPE_ERROR_UNION_ALL_DISMATCH'
+ err[#err+1] = child
+ err[#err+1] = parent
+ return false, err
+ end
+ return true, err
else
local weakNil = config.get(uri, 'Lua.type.weakNilCheck')
for n in child:eachObject() do
local nodeName = getNodeName(n)
if nodeName
and not (nodeName == 'nil' and weakNil)
- and vm.isSubType(uri, n, parent, mark) == false then
- return false
+ and vm.isSubType(uri, n, parent, mark, err) == false then
+ err[#err+1] = 'TYPE_ERROR_UNION_DISMATCH'
+ err[#err+1] = n
+ err[#err+1] = parent
+ return false, err
end
end
if not weakNil and child:isOptional() then
- if vm.isSubType(uri, 'nil', parent, mark) == false then
- return false
+ if vm.isSubType(uri, 'nil', parent, mark, err) == false then
+ err[#err+1] = 'TYPE_ERROR_OPTIONAL_DISMATCH'
+ err[#err+1] = parent
+ return false, err
end
end
- return true
+ return true, err
end
end
@@ -213,18 +278,18 @@ function vm.isSubType(uri, child, parent, mark)
local childName = getNodeName(child)
if childName == 'any'
or childName == 'unknown' then
- return true
+ return true, err
end
if not childName
or isAlias(childName, uri) then
- return nil
+ return nil, err
end
if type(parent) == 'string' then
local global = vm.getGlobal('type', parent)
if not global then
- return false
+ return false, err
end
parent = global
elseif parent.type == 'vm.node' then
@@ -232,17 +297,17 @@ function vm.isSubType(uri, child, parent, mark)
for n in parent:eachObject() do
if getNodeName(n) then
hasKnownType = true
- if vm.isSubType(uri, child, n, mark) == true then
- return true
+ if vm.isSubType(uri, child, n, mark, err) == true then
+ return true, err
end
end
end
if parent:isOptional() then
- if vm.isSubType(uri, child, 'nil', mark) == true then
- return true
+ if vm.isSubType(uri, child, 'nil', mark, err) == true then
+ return true, err
end
end
- return not hasKnownType
+ return not hasKnownType, err
end
---@cast parent vm.node.object
@@ -250,51 +315,54 @@ function vm.isSubType(uri, child, parent, mark)
local parentName = getNodeName(parent)
if parentName == 'any'
or parentName == 'unknown' then
- return true
+ return true, err
end
if not parentName
or isAlias(parentName, uri) then
- return nil
+ return nil, err
end
if childName == parentName then
- if not checkValue(parent, child) then
- return false
+ if not checkValue(parent, child, mark, err) then
+ return false, err
end
- return true
+ return true, err
end
if parentName == 'number' and childName == 'integer' then
- return true
+ return true, err
end
if parentName == 'integer' and childName == 'number' then
if config.get(uri, 'Lua.type.castNumberToInteger') then
- return true
+ return true, err
end
if child.type == 'number'
and child[1]
and not math.tointeger(child[1]) then
- return false
+ err[#err+1] = 'TYPE_ERROR_NUMBER_LITERAL_TO_INTEGER'
+ err[#err+1] = child
+ return false, err
end
if child.type == 'global'
and child.cate == 'type' then
- return false
+ err[#err+1] = 'TYPE_ERROR_NUMBER_TYPE_TO_INTEGER'
+ return false, err
end
- return true
+ return true, err
end
- local isEnum = checkEnum(parentName, child, uri)
+ local isEnum = checkEnum(parentName, child, uri, err)
if isEnum ~= nil then
- return isEnum
+ return isEnum, err
end
if parentName == 'table' and not guide.isBasicType(childName) then
- return true
+ return true, err
end
if childName == 'table' and not guide.isBasicType(parentName) then
- return true
+ return true, err
end
-- check class parent
@@ -308,8 +376,8 @@ function vm.isSubType(uri, child, parent, mark)
for _, ext in ipairs(set.extends) do
if ext.type == 'doc.extends.name'
and (not isBasicType or guide.isBasicType(ext[1]))
- and vm.isSubType(uri, ext[1], parent, mark) == true then
- return true
+ and vm.isSubType(uri, ext[1], parent, mark, err) == true then
+ return true, err
end
end
end
@@ -326,11 +394,14 @@ function vm.isSubType(uri, child, parent, mark)
]]
if guide.isBasicType(childName)
and guide.isLiteral(child)
- and vm.isSubType(uri, parentName, childName) then
- return true
+ and vm.isSubType(uri, parentName, childName, mark, err) then
+ return true, err
end
- return false
+ err[#err+1] = 'TYPE_ERROR_DISMATCH'
+ err[#err+1] = child
+ err[#err+1] = parent
+ return false, err
end
---@param node string|vm.node|vm.object
@@ -474,6 +545,7 @@ end
---@param defNode vm.node
---@param refNode vm.node
---@return boolean
+---@return typecheck.err[]?
function vm.canCastType(uri, defNode, refNode)
local defInfer = vm.getInfer(defNode)
local refInfer = vm.getInfer(refNode)
@@ -512,9 +584,78 @@ function vm.canCastType(uri, defNode, refNode)
end
end
- if vm.isSubType(uri, refNode, defNode) then
+ local suc, err = vm.isSubType(uri, refNode, defNode)
+
+ if suc then
return true
end
- return false
+ return false, err
+end
+
+local ErrorMessageMap = {
+ TYPE_ERROR_ENUM_GLOBAL_DISMATCH = {'child', 'parent'},
+ TYPE_ERROR_ENUM_GENERIC_UNSUPPORTED = {'child'},
+ TYPE_ERROR_ENUM_LITERAL_DISMATCH = {'child', 'parent'},
+ TYPE_ERROR_ENUM_OBJECT_DISMATCH = {'child', 'parent'},
+ TYPE_ERROR_ENUM_NO_OBJECT = {'child'},
+ TYPE_ERROR_INTEGER_DISMATCH = {'child', 'parent'},
+ TYPE_ERROR_STRING_DISMATCH = {'child', 'parent'},
+ TYPE_ERROR_BOOLEAN_DISMATCH = {'child', 'parent'},
+ TYPE_ERROR_TABLE_NO_FIELD = {'key'},
+ TYPE_ERROR_TABLE_FIELD_DISMATCH = {'key', 'child', 'parent'},
+ TYPE_ERROR_UNION_ALL_DISMATCH = {'child', 'parent'},
+ TYPE_ERROR_UNION_DISMATCH = {'child', 'parent'},
+ TYPE_ERROR_OPTIONAL_DISMATCH = {'parent'},
+ TYPE_ERROR_NUMBER_LITERAL_TO_INTEGER = {'child'},
+ TYPE_ERROR_NUMBER_TYPE_TO_INTEGER = {},
+ TYPE_ERROR_DISMATCH = {'child', 'parent'},
+}
+
+---@param uri uri
+---@param errs typecheck.err[]
+---@return string
+function vm.viewTypeErrorMessage(uri, errs)
+ local lines = {}
+ local index = 1
+ while true do
+ local name = errs[index]
+ if not name then
+ break
+ end
+ index = index + 1
+ local params = ErrorMessageMap[name]
+ local lparams = {}
+ for _, paramName in ipairs(params) do
+ local value = errs[index]
+ if type(value) == 'string'
+ or type(value) == 'number'
+ or type(value) == 'boolean' then
+ lparams[paramName] = util.viewLiteral(value)
+ elseif value.type == 'global' then
+ lparams[paramName] = value.name
+ elseif value.type == 'vm.node' then
+ ---@cast value vm.node
+ lparams[paramName] = vm.getInfer(value):view(uri)
+ elseif value.type == 'table' then
+ lparams[paramName] = 'table'
+ elseif value.type == 'generic' then
+ ---@cast value vm.generic
+ lparams[paramName] = vm.viewObject(value, uri)
+ else
+ ---@cast value -string, -vm.global, -vm.node, -vm.generic
+ if paramName == 'key' then
+ lparams[paramName] = vm.viewKey(value, uri)
+ else
+ lparams[paramName] = vm.viewObject(value, uri)
+ or vm.getInfer(value):view(uri)
+ end
+ end
+ index = index + 1
+ end
+ local line = lang.script(name, lparams)
+ lines[#lines+1] = '- ' .. line
+ end
+ util.revertTable(lines)
+ return table.concat(lines, '\n')
end