summaryrefslogtreecommitdiff
path: root/script/vm/type.lua
diff options
context:
space:
mode:
Diffstat (limited to 'script/vm/type.lua')
-rw-r--r--script/vm/type.lua372
1 files changed, 316 insertions, 56 deletions
diff --git a/script/vm/type.lua b/script/vm/type.lua
index c3264993..d112be2c 100644
--- a/script/vm/type.lua
+++ b/script/vm/type.lua
@@ -1,67 +1,252 @@
---@class vm
local vm = require 'vm.vm'
+local guide = require 'parser.guide'
+local config = require 'config.config'
+local util = require 'utility'
----@param uri uri
----@param child vm.node|string
----@param parent vm.node|string
----@param mark? table
----@return boolean
-function vm.isSubType(uri, child, parent, mark)
- if type(parent) == 'string' then
- parent = vm.createNode(vm.getGlobal('type', parent))
+---@param object vm.node.object
+---@return string?
+local function getNodeName(object)
+ if object.type == 'global' and object.cate == 'type' then
+ ---@cast object vm.global
+ return object.name
end
- if type(child) == 'string' then
- child = vm.createNode(vm.getGlobal('type', child))
+ if object.type == 'nil'
+ or object.type == 'boolean'
+ or object.type == 'number'
+ or object.type == 'string'
+ or object.type == 'table'
+ or object.type == 'function'
+ or object.type == 'integer' then
+ return object.type
+ end
+ if object.type == 'doc.type.boolean' then
+ return 'boolean'
end
+ if object.type == 'doc.type.integer' then
+ return 'integer'
+ end
+ if object.type == 'doc.type.function' then
+ return 'function'
+ end
+ if object.type == 'doc.type.table' then
+ return 'table'
+ end
+ if object.type == 'doc.type.array' then
+ return 'table'
+ end
+ if object.type == 'doc.type.string' then
+ return 'string'
+ end
+ return nil
+end
- if not child or not parent then
- return false
+---@param parentName string
+---@param child vm.node.object
+---@param uri uri
+---@return boolean?
+local function checkEnum(parentName, child, uri)
+ local parentClass = vm.getGlobal('type', parentName)
+ if not parentClass then
+ return nil
+ end
+ for _, set in ipairs(parentClass:getSets(uri)) do
+ if set.type == 'doc.enum' then
+ if not set._enums then
+ return false
+ end
+ if child.type ~= 'string'
+ and child.type ~= 'doc.type.string'
+ and child.type ~= 'integer'
+ and child.type ~= 'number'
+ and child.type ~= 'doc.type.integer' then
+ return false
+ end
+ return util.arrayHas(set._enums, child[1])
+ end
+ end
+
+ return nil
+end
+
+---@param parent vm.node.object
+---@param child vm.node.object
+---@return boolean
+local function checkValue(parent, child)
+ if parent.type == 'doc.type.integer' then
+ if child.type == 'integer'
+ or child.type == 'doc.type.integer'
+ or child.type == 'number' then
+ return parent[1] == child[1]
+ end
+ elseif parent.type == 'doc.type.string' then
+ if child.type == 'string'
+ or child.type == 'doc.type.string' then
+ return parent[1] == child[1]
+ end
end
+ return true
+end
+
+---@param uri uri
+---@param child vm.node|string|vm.node.object
+---@param parent vm.node|string|vm.node.object
+---@param mark? table
+---@return boolean
+function vm.isSubType(uri, child, parent, mark)
mark = mark or {}
- for obj in child:eachObject() do
- if obj.type ~= 'global'
- or obj.cate ~= 'type' then
- goto CONTINUE_CHILD
+
+ if type(child) == 'string' then
+ local global = vm.getGlobal('type', child)
+ if not global then
+ return false
+ end
+ child = global
+ elseif child.type == 'vm.node' then
+ if config.get(uri, 'Lua.type.weakUnionCheck') then
+ local hasKnownType
+ for n in child:eachObject() do
+ if getNodeName(n) then
+ hasKnownType = true
+ if vm.isSubType(uri, n, parent, mark) then
+ return true
+ end
+ end
+ end
+ return not hasKnownType
+ else
+ local weakNil = config.get(uri, 'Lua.type.weakNilCheck')
+ for n in child:eachObject() do
+ local nodeName = getNodeName(n)
+ if nodeName
+ and not (nodeName == 'nil' and weakNil)
+ and not vm.isSubType(uri, n, parent, mark) then
+ return false
+ end
+ end
+ if not weakNil and child:isOptional() then
+ if not vm.isSubType(uri, 'nil', parent, mark) then
+ return false
+ end
+ end
+ return true
end
- if mark[obj.name] then
+ end
+
+ if type(parent) == 'string' then
+ local global = vm.getGlobal('type', parent)
+ if not global then
return false
end
- mark[obj.name] = true
- for parentNode in parent:eachObject() do
- if parentNode.type ~= 'global'
- or parentNode.cate ~= 'type' then
- goto CONTINUE_PARENT
+ parent = global
+ elseif parent.type == 'vm.node' then
+ for n in parent:eachObject() do
+ if getNodeName(n)
+ and vm.isSubType(uri, child, n, mark) then
+ return true
end
- if parentNode.name == 'any' or obj.name == 'any' then
+ if n.type == 'doc.generic.name' then
return true
end
-
- if parentNode.name == obj.name then
+ end
+ if parent:isOptional() then
+ if vm.isSubType(uri, child, 'nil', mark) then
return true
end
+ end
+ return false
+ end
+
+ ---@cast child vm.node.object
+ ---@cast parent vm.node.object
- for _, set in ipairs(obj:getSets(uri)) do
+ local childName = getNodeName(child)
+ local parentName = getNodeName(parent)
+ if childName == 'any'
+ or parentName == 'any'
+ or childName == 'unknown'
+ or parentName == 'unknown'
+ or not childName
+ or not parentName then
+ return true
+ end
+
+ if childName == parentName then
+ if not checkValue(parent, child) then
+ return false
+ end
+ return true
+ end
+
+ if parentName == 'number' and childName == 'integer' then
+ return true
+ end
+
+ if parentName == 'integer' and childName == 'number' then
+ if config.get(uri, 'Lua.type.castNumberToInteger') then
+ return true
+ end
+ if child.type == 'number'
+ and child[1]
+ and not math.tointeger(child[1]) then
+ return false
+ end
+ if child.type == 'global'
+ and child.cate == 'type' then
+ return false
+ end
+ return true
+ end
+
+ local isEnum = checkEnum(parentName, child, uri)
+ if isEnum ~= nil then
+ return isEnum
+ end
+
+ -- TODO: check duck
+ if parentName == 'table' and not guide.isBasicType(childName) then
+ return true
+ end
+ if childName == 'table' and not guide.isBasicType(parentName) then
+ return true
+ end
+
+ -- check class parent
+ if childName and not mark[childName] then
+ mark[childName] = true
+ local isBasicType = guide.isBasicType(childName)
+ local childClass = vm.getGlobal('type', childName)
+ if childClass then
+ for _, set in ipairs(childClass:getSets(uri)) do
if set.type == 'doc.class' and set.extends then
for _, ext in ipairs(set.extends) do
if ext.type == 'doc.extends.name'
- and vm.isSubType(uri, ext[1], parentNode.name, mark) then
+ and (not isBasicType or guide.isBasicType(ext[1]))
+ and vm.isSubType(uri, ext[1], parent, mark) then
return true
end
end
end
- if set.type == 'doc.alias' and set.extends then
- for _, ext in ipairs(set.extends.types) do
- if ext.type == 'doc.type.name'
- and vm.isSubType(uri, ext[1], parentNode.name, mark) then
- return true
- end
- end
+ if set.type == 'doc.alias'
+ or set.type == 'doc.enum' then
+ return true
end
end
- ::CONTINUE_PARENT::
end
- ::CONTINUE_CHILD::
+ mark[childName] = nil
+ end
+
+ --[[
+ ---@class A: string
+
+ ---@type A
+ local x = '' --> `string` set to `A`
+ ]]
+ if guide.isBasicType(childName)
+ and guide.isLiteral(child)
+ and vm.isSubType(uri, parentName, childName) then
+ return true
end
return false
@@ -69,16 +254,24 @@ end
---@param uri uri
---@param tnode vm.node
----@param knode vm.node
+---@param knode vm.node|string
+---@param inversion? boolean
---@return vm.node?
-function vm.getTableValue(uri, tnode, knode)
+function vm.getTableValue(uri, tnode, knode, inversion)
local result = vm.createNode()
for tn in tnode:eachObject() do
if tn.type == 'doc.type.table' then
for _, field in ipairs(tn.fields) do
- if vm.isSubType(uri, vm.compileNode(field.name), knode) then
- if field.extends then
- result:merge(vm.compileNode(field.extends))
+ if field.name.type ~= 'doc.field.name'
+ and field.extends then
+ if inversion then
+ if vm.isSubType(uri, vm.compileNode(field.name), knode) then
+ result:merge(vm.compileNode(field.extends))
+ end
+ else
+ if vm.isSubType(uri, knode, vm.compileNode(field.name)) then
+ result:merge(vm.compileNode(field.extends))
+ end
end
end
end
@@ -88,25 +281,38 @@ function vm.getTableValue(uri, tnode, knode)
end
if tn.type == 'table' then
for _, field in ipairs(tn) do
- if field.type == 'tableindex' then
- if field.value then
- result:merge(vm.compileNode(field.value))
- end
+ if field.type == 'tableindex'
+ and field.value then
+ result:merge(vm.compileNode(field.value))
end
- if field.type == 'tablefield' then
- if vm.isSubType(uri, knode, 'string') then
- if field.value then
+ if field.type == 'tablefield'
+ and field.value then
+ if inversion then
+ if vm.isSubType(uri, 'string', knode) then
+ result:merge(vm.compileNode(field.value))
+ end
+ else
+ if vm.isSubType(uri, knode, 'string') then
result:merge(vm.compileNode(field.value))
end
end
end
- if field.type == 'tableexp' then
- if vm.isSubType(uri, knode, 'integer') and field.tindex == 1 then
- if field.value then
+ if field.type == 'tableexp'
+ and field.value
+ and field.tindex == 1 then
+ if inversion then
+ if vm.isSubType(uri, 'integer', knode) then
+ result:merge(vm.compileNode(field.value))
+ end
+ else
+ if vm.isSubType(uri, knode, 'integer') then
result:merge(vm.compileNode(field.value))
end
end
end
+ if field.type == 'varargs' then
+ result:merge(vm.compileNode(field))
+ end
end
end
end
@@ -118,16 +324,24 @@ end
---@param uri uri
---@param tnode vm.node
----@param vnode vm.node
+---@param vnode vm.node|string|vm.object
+---@param reverse? boolean
---@return vm.node?
-function vm.getTableKey(uri, tnode, vnode)
+function vm.getTableKey(uri, tnode, vnode, reverse)
local result = vm.createNode()
for tn in tnode:eachObject() do
if tn.type == 'doc.type.table' then
for _, field in ipairs(tn.fields) do
- if field.extends then
- if vm.isSubType(uri, vm.compileNode(field.extends), vnode) then
- result:merge(vm.compileNode(field.name))
+ if field.name.type ~= 'doc.field.name'
+ and field.extends then
+ if reverse then
+ if vm.isSubType(uri, vm.compileNode(field.extends), vnode) then
+ result:merge(vm.compileNode(field.name))
+ end
+ else
+ if vm.isSubType(uri, vnode, vm.compileNode(field.extends)) then
+ result:merge(vm.compileNode(field.name))
+ end
end
end
end
@@ -156,3 +370,49 @@ function vm.getTableKey(uri, tnode, vnode)
end
return result
end
+
+---@param uri uri
+---@param defNode vm.node
+---@param refNode vm.node
+---@return boolean
+function vm.canCastType(uri, defNode, refNode)
+ local defInfer = vm.getInfer(defNode)
+ local refInfer = vm.getInfer(refNode)
+
+ if defInfer:hasAny(uri) then
+ return true
+ end
+ if refInfer:hasAny(uri) then
+ return true
+ end
+ if defInfer:view(uri) == 'unknown' then
+ return true
+ end
+ if refInfer:view(uri) == 'unknown' then
+ return true
+ end
+
+ if vm.isSubType(uri, refNode, 'nil') then
+ -- allow `local x = {};x = nil`,
+ -- but not allow `local x ---@type table;x = nil`
+ if defInfer:hasType(uri, 'table')
+ and not defNode:hasType 'table' then
+ return true
+ end
+ end
+
+ if vm.isSubType(uri, refNode, 'number') then
+ -- allow `local x = 0;x = 1.0`,
+ -- but not allow `local x ---@type integer;x = 1.0`
+ if defInfer:hasType(uri, 'integer')
+ and not defNode:hasType 'integer' then
+ return true
+ end
+ end
+
+ if vm.isSubType(uri, refNode, defNode) then
+ return true
+ end
+
+ return false
+end