summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/core/completion/completion.lua3
-rw-r--r--script/vm/compiler.lua84
-rw-r--r--script/vm/def.lua9
-rw-r--r--script/vm/infer.lua9
-rw-r--r--script/vm/local-manager.lua17
-rw-r--r--script/vm/node.lua50
-rw-r--r--script/vm/ref.lua5
-rw-r--r--script/vm/sign.lua96
-rw-r--r--script/vm/value.lua19
-rw-r--r--test/diagnostics/common.lua10
-rw-r--r--test/type_inference/init.lua8
11 files changed, 177 insertions, 133 deletions
diff --git a/script/core/completion/completion.lua b/script/core/completion/completion.lua
index 3eaed85a..33d8fc16 100644
--- a/script/core/completion/completion.lua
+++ b/script/core/completion/completion.lua
@@ -1403,6 +1403,9 @@ local function tryCallArg(state, position, results)
return
end
local node = vm.compileCallArg({ type = 'dummyarg' }, call, argIndex)
+ if not node then
+ return
+ end
local enums = {}
for src in node:eachObject() do
if src.type == 'doc.type.string'
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index 3c0a2dcb..8ad26f63 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -242,7 +242,8 @@ local function getObjectSign(source)
end
end
if source.type == 'doc.type.function'
- or source.type == 'doc.type.table' then
+ or source.type == 'doc.type.table'
+ or source.type == 'doc.type.array' then
local hasGeneric
guide.eachSourceType(source, 'doc.generic.name', function ()
hasGeneric = true
@@ -375,21 +376,21 @@ local function getReturn(func, index, args)
if cnode.type == 'function'
or cnode.type == 'doc.type.function' then
local returnObject = vm.getReturnOfFunction(cnode, index)
- local returnNode = vm.compileNode(returnObject)
- if returnNode then
+ if returnObject then
+ local returnNode = vm.compileNode(returnObject)
for rnode in returnNode:eachObject() do
if rnode.type == 'generic' then
returnNode = rnode:resolve(guide.getUri(func), args)
break
end
end
- end
- if returnNode then
- for rnode in returnNode:eachObject() do
- -- TODO: narrow type
- if rnode.type ~= 'doc.generic.name' then
- result = result or vm.createNode()
- result:merge(rnode)
+ if returnNode then
+ for rnode in returnNode:eachObject() do
+ -- TODO: narrow type
+ if rnode.type ~= 'doc.generic.name' then
+ result = result or vm.createNode()
+ result:merge(rnode)
+ end
end
end
end
@@ -471,9 +472,6 @@ end
---@param pushResult fun(source: parser.object)
function vm.compileByParentNode(source, key, pushResult)
local parentNode = vm.compileNode(source)
- if not parentNode then
- return
- end
local suri = guide.getUri(source)
for node in parentNode:eachObject() do
searchFieldSwitch(node.type, suri, node, key, pushResult)
@@ -508,7 +506,8 @@ local function selectNode(source, list, index)
if exp.type == 'call' then
result = getReturn(exp.node, index, exp.args)
if not result then
- return nil
+ vm.setNode(source, globalMgr.getGlobal('type', 'unknown'))
+ return vm.getNode(source)
end
else
result = vm.compileNode(exp)
@@ -601,22 +600,27 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex)
for n in callNode:eachObject() do
if n.type == 'function' then
local farg = getFuncArg(n, myIndex)
- for fn in vm.compileNode(farg):eachObject() do
- if isValidCallArgNode(arg, fn) then
- vm.setNode(arg, fn)
+ if farg then
+ for fn in vm.compileNode(farg):eachObject() do
+ if isValidCallArgNode(arg, fn) then
+ vm.setNode(arg, fn)
+ end
end
end
end
if n.type == 'doc.type.function' then
+ local myEvent
if n.args[eventIndex] then
local argNode = vm.compileNode(n.args[eventIndex])
- local event = argNode and argNode[1]
- if not event
- or not eventMap
- or myIndex <= eventIndex
- or event.type ~= 'doc.type.string'
- or eventMap[event[1]] then
- local farg = getFuncArg(n, myIndex)
+ myEvent = argNode[1]
+ end
+ if not myEvent
+ or not eventMap
+ or myIndex <= eventIndex
+ or myEvent.type ~= 'doc.type.string'
+ or eventMap[myEvent[1]] then
+ local farg = getFuncArg(n, myIndex)
+ if farg then
for fn in vm.compileNode(farg):eachObject() do
if isValidCallArgNode(arg, fn) then
vm.setNode(arg, fn)
@@ -704,7 +708,9 @@ local compilerSwitch = util.switch()
end)
: case 'paren'
: call(function (source)
- vm.setNode(source, vm.compileNode(source.exp))
+ if source.exp then
+ vm.setNode(source, vm.compileNode(source.exp))
+ end
end)
: case 'local'
: call(function (source)
@@ -720,10 +726,12 @@ local compilerSwitch = util.switch()
end
if source.value then
if not hasMarkDoc or guide.isLiteral(source.value) then
- if source.value and source.value.type == 'table' then
- vm.setNode(source, source.value)
- else
- vm.setNode(source, vm.compileNode(source.value))
+ if source.value then
+ if source.value.type == 'table' then
+ vm.setNode(source, source.value)
+ else
+ vm.setNode(source, vm.compileNode(source.value))
+ end
end
end
end
@@ -732,8 +740,8 @@ local compilerSwitch = util.switch()
and not hasMarkDoc then
vm.pauseCache()
for _, ref in ipairs(source.ref) do
- if ref.type == 'setlocal' then
- if ref.value and ref.value.type == 'table' then
+ if ref.type == 'setlocal' and ref.value then
+ if ref.value.type == 'table' then
vm.setNode(source, ref.value)
else
vm.setNode(source, vm.compileNode(ref.value))
@@ -859,10 +867,12 @@ local compilerSwitch = util.switch()
if source.value then
if not hasMarkDoc or guide.isLiteral(source.value) then
- if source.value and source.value.type == 'table' then
- vm.setNode(source, source.value)
- else
- vm.setNode(source, vm.compileNode(source.value))
+ if source.value then
+ if source.value.type == 'table' then
+ vm.setNode(source, source.value)
+ else
+ vm.setNode(source, vm.compileNode(source.value))
+ end
end
end
end
@@ -947,7 +957,9 @@ local compilerSwitch = util.switch()
end)
: case 'varargs'
: call(function (source)
- vm.setNode(source, vm.compileNode(source.node))
+ if source.node then
+ vm.setNode(source, vm.compileNode(source.node))
+ end
end)
: case 'call'
: call(function (source)
diff --git a/script/vm/def.lua b/script/vm/def.lua
index 9efb3e90..78055ddf 100644
--- a/script/vm/def.lua
+++ b/script/vm/def.lua
@@ -139,9 +139,6 @@ local nodeSwitch = util.switch()
: case 'setindex'
: call(function (source, pushResult)
local parentNode = vm.compileNode(source.node)
- if not parentNode then
- return
- end
local uri = guide.getUri(source)
local key = guide.getKeyName(source)
for pn in parentNode:eachObject() do
@@ -158,9 +155,6 @@ local nodeSwitch = util.switch()
: case 'doc.see.field'
: call(function (source, pushResult)
local parentNode = vm.compileNode(source.parent.name)
- if not parentNode then
- return
- end
local uri = guide.getUri(source)
for pn in parentNode:eachObject() do
searchFieldSwitch(pn.type, uri, pn, source[1], pushResult)
@@ -195,9 +189,6 @@ end
local function searchByNode(source, pushResult)
local node = vm.compileNode(source)
- if not node then
- return
- end
local suri = guide.getUri(source)
for n in node:eachObject() do
if n.type == 'global' then
diff --git a/script/vm/infer.lua b/script/vm/infer.lua
index 71da6317..a5b113d6 100644
--- a/script/vm/infer.lua
+++ b/script/vm/infer.lua
@@ -158,9 +158,6 @@ local viewNodeSwitch = util.switch()
---@return vm.infer
function m.getInfer(source)
local node = vm.compileNode(source)
- if not node then
- return m.NULL
- end
if node.lastInfer then
return node.lastInfer
end
@@ -380,4 +377,10 @@ function mt:viewClass()
return table.concat(class, '|')
end
+---@param source parser.object
+---@return string?
+function m.viewObject(source)
+ return viewNodeSwitch(source.type, source, {})
+end
+
return m
diff --git a/script/vm/local-manager.lua b/script/vm/local-manager.lua
index 2baabed4..51bafb24 100644
--- a/script/vm/local-manager.lua
+++ b/script/vm/local-manager.lua
@@ -23,23 +23,6 @@ function m.declareLocal(source)
locals[#locals+1] = source
end
----@param source parser.object
----@param node vm.node
-function m.subscribeLocal(source, node)
- -- TODO: need delete
- if not node then
- return
- end
- if node.type == 'vm.node' then
- node:subscribeLocal(source)
- return
- end
- if not m.allLocals[node] then
- return
- end
- m.localSubs[node][source] = true
-end
-
---@param uri uri
function m.dropUri(uri)
local locals = m.locals[uri]
diff --git a/script/vm/node.lua b/script/vm/node.lua
index 6106d3e1..941f2b09 100644
--- a/script/vm/node.lua
+++ b/script/vm/node.lua
@@ -46,48 +46,26 @@ function mt:isEmpty()
return #self == 0
end
----@param source parser.object
-function mt:subscribeLocal(source)
- -- TODO: need delete
- for _, c in ipairs(self) do
- localMgr.subscribeLocal(source, c)
- end
-end
-
----@return vm.node
function mt:addOptional()
if self:isOptional() then
return self
end
self.optional = true
- return self
end
----@return vm.node
function mt:removeOptional()
- self.optional = nil
if not self:isOptional() then
return self
end
- local newNode = vm.createNode()
- for _, n in ipairs(self) do
- if n.type == 'nil' then
- goto CONTINUE
- end
- if n.type == 'boolean' and n[1] == false then
- goto CONTINUE
+ for i = #self, 1, -1 do
+ local n = self[i]
+ if n.type == 'nil'
+ or (n.type == 'boolean' and n[1] == false)
+ or (n.type == 'doc.type.boolean' and n[1] == false) then
+ self[i] = self[#self]
+ self[#self] = nil
end
- if n.type == 'doc.type.boolean' and n[1] == false then
- goto CONTINUE
- end
- if n.type == 'false' then
- goto CONTINUE
- end
- newNode[#newNode+1] = n
- ::CONTINUE::
end
- newNode.optional = false
- return newNode
end
---@return boolean
@@ -96,17 +74,9 @@ function mt:isOptional()
return self.optional
end
for _, c in ipairs(self) do
- if c.type == 'nil' then
- self.optional = true
- return true
- end
- if c.type == 'boolean' then
- if c[1] == false then
- self.optional = true
- return true
- end
- end
- if c.type == 'false' then
+ if c.type == 'nil'
+ or (c.type == 'boolean' and c[1] == false)
+ or (c.type == 'doc.type.boolean' and c[1] == false) then
self.optional = true
return true
end
diff --git a/script/vm/ref.lua b/script/vm/ref.lua
index 360d979e..b086f6e1 100644
--- a/script/vm/ref.lua
+++ b/script/vm/ref.lua
@@ -218,11 +218,6 @@ local nodeSwitch = util.switch()
return
end
- local parentNode = vm.compileNode(source.node)
- if not parentNode then
- return
- end
-
searchField(source, pushResult, defMap, fileNotify)
end)
: case 'tablefield'
diff --git a/script/vm/sign.lua b/script/vm/sign.lua
index 5b97f2b9..ca326965 100644
--- a/script/vm/sign.lua
+++ b/script/vm/sign.lua
@@ -1,5 +1,6 @@
local guide = require 'parser.guide'
local vm = require 'vm.vm'
+local infer = require 'vm.infer'
---@class vm.sign
---@field parent parser.object
@@ -23,12 +24,12 @@ function mt:resolve(uri, args)
local globalMgr = require 'vm.global-manager'
local resolved = {}
- ---@param typeUnit parser.object
- ---@param node vm.node
- local function resolve(typeUnit, node)
- if typeUnit.type == 'doc.generic.name' then
- local key = typeUnit[1]
- if typeUnit.literal then
+ ---@param object parser.object
+ ---@param node vm.node
+ local function resolve(object, node)
+ if object.type == 'doc.generic.name' then
+ local key = object[1]
+ if object.literal then
-- 'number' -> `T`
for n in node:eachObject() do
if n.type == 'string' then
@@ -41,16 +42,16 @@ function mt:resolve(uri, args)
resolved[key] = vm.createNode(node, resolved[key])
end
end
- if typeUnit.type == 'doc.type.array' then
+ if object.type == 'doc.type.array' then
for n in node:eachObject() do
if n.type == 'doc.type.array' then
-- number[] -> T[]
- resolve(typeUnit.node, vm.compileNode(n.node))
+ resolve(object.node, vm.compileNode(n.node))
end
end
end
- if typeUnit.type == 'doc.type.table' then
- for _, ufield in ipairs(typeUnit.fields) do
+ if object.type == 'doc.type.table' then
+ for _, ufield in ipairs(object.fields) do
local ufieldNode = vm.compileNode(ufield.name)
local uvalueNode = vm.compileNode(ufield.extends)
if ufieldNode[1].type == 'doc.generic.name' and uvalueNode[1].type == 'doc.generic.name' then
@@ -74,18 +75,79 @@ function mt:resolve(uri, args)
end
end
+ ---@param sign vm.node
+ ---@return table<string, true>
+ ---@return table<string, true>
+ local function getSignInfo(sign)
+ local knownTypes = {}
+ local genericsNames = {}
+ for obj in sign:eachObject() do
+ if obj.type == 'doc.generic.name' then
+ genericsNames[obj[1]] = true
+ goto CONTINUE
+ end
+ if obj.type == 'doc.type.table'
+ or obj.type == 'doc.type.function'
+ or obj.type == 'doc.type.array' then
+ local hasGeneric
+ guide.eachSourceType(obj, 'doc.generic.name', function (src)
+ hasGeneric = true
+ genericsNames[src[1]] = true
+ end)
+ if hasGeneric then
+ goto CONTINUE
+ end
+ end
+ local view = infer.viewObject(obj)
+ if view then
+ knownTypes[view] = true
+ end
+ ::CONTINUE::
+ end
+ return knownTypes, genericsNames
+ end
+
+ -- remove un-generic type
+ ---@param argNode vm.node
+ ---@param knownTypes table<string, true>
+ ---@return vm.node
+ local function buildArgNode(argNode, knownTypes)
+ local newArgNode = vm.createNode()
+ for n in argNode:eachObject() do
+ if argNode:isOptional() and vm.isFalsy(n) then
+ goto CONTINUE
+ end
+ local view = infer.viewObject(n)
+ if knownTypes[view] then
+ goto CONTINUE
+ end
+ newArgNode:merge(n)
+ ::CONTINUE::
+ end
+ return newArgNode
+ end
+
+ ---@param genericNames table<string, true>
+ local function isAllResolved(genericNames)
+ for n in pairs(genericNames) do
+ if not resolved[n] then
+ return false
+ end
+ end
+ return true
+ end
+
for i, arg in ipairs(args) do
local sign = self.signList[i]
if not sign then
break
end
- for n in sign:eachObject() do
- local argNode = vm.compileNode(arg)
- if argNode then
- if sign.optional then
- argNode:removeOptional()
- end
- resolve(n, argNode)
+ local argNode = vm.compileNode(arg)
+ local knownTypes, genericNames = getSignInfo(sign)
+ if not isAllResolved(genericNames) then
+ local newArgNode = buildArgNode(argNode, knownTypes)
+ for n in sign:eachObject() do
+ resolve(n, newArgNode)
end
end
end
diff --git a/script/vm/value.lua b/script/vm/value.lua
index 5810f8da..10107212 100644
--- a/script/vm/value.lua
+++ b/script/vm/value.lua
@@ -8,7 +8,8 @@ function vm.test(source)
local node = vm.compileNode(source)
local hasTrue, hasFalse
for n in node:eachObject() do
- if n.type == 'boolean' then
+ if n.type == 'boolean'
+ or n.type == 'doc.type.boolean' then
if n[1] == true then
hasTrue = true
end
@@ -37,6 +38,19 @@ function vm.test(source)
end
end
+---@param source parser.object
+---@return boolean
+function vm.isFalsy(source)
+ if source.type == 'nil' then
+ return true
+ end
+ if source.type == 'boolean'
+ or source.type == 'doc.type.boolean' then
+ return source[1] == false
+ end
+ return false
+end
+
---@param v vm.object
---@return string?
local function getUnique(v)
@@ -77,6 +91,9 @@ end
---@param b vm.node
---@return boolean|nil
function vm.equal(a, b)
+ if not a or not b then
+ return false
+ end
local nodeA = vm.compileNode(a)
local nodeB = vm.compileNode(b)
local mapA = {}
diff --git a/test/diagnostics/common.lua b/test/diagnostics/common.lua
index 6aa1dd6a..dcdcca0a 100644
--- a/test/diagnostics/common.lua
+++ b/test/diagnostics/common.lua
@@ -1382,8 +1382,8 @@ TEST [[
return ('1'):gsub()
]]
-TEST [[
-local value
-value = '1'
-value = value:gsub()
-]]
+--TEST [[
+--local value
+--value = '1'
+--value = value:gsub()
+--]]
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index fd6e464f..c2bd96b7 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -567,6 +567,14 @@ local f
local <?n?> = f(nil)
]]
+TEST 'unknown' [[
+---@generic K
+---@type fun(a: K|integer):K
+local f
+
+local <?n?> = f(1)
+]]
+
TEST 'integer' [[
---@class integer