diff options
-rw-r--r-- | script/core/completion/completion.lua | 3 | ||||
-rw-r--r-- | script/vm/compiler.lua | 84 | ||||
-rw-r--r-- | script/vm/def.lua | 9 | ||||
-rw-r--r-- | script/vm/infer.lua | 9 | ||||
-rw-r--r-- | script/vm/local-manager.lua | 17 | ||||
-rw-r--r-- | script/vm/node.lua | 50 | ||||
-rw-r--r-- | script/vm/ref.lua | 5 | ||||
-rw-r--r-- | script/vm/sign.lua | 96 | ||||
-rw-r--r-- | script/vm/value.lua | 19 | ||||
-rw-r--r-- | test/diagnostics/common.lua | 10 | ||||
-rw-r--r-- | test/type_inference/init.lua | 8 |
11 files changed, 177 insertions, 133 deletions
diff --git a/script/core/completion/completion.lua b/script/core/completion/completion.lua index 3eaed85a..33d8fc16 100644 --- a/script/core/completion/completion.lua +++ b/script/core/completion/completion.lua @@ -1403,6 +1403,9 @@ local function tryCallArg(state, position, results) return end local node = vm.compileCallArg({ type = 'dummyarg' }, call, argIndex) + if not node then + return + end local enums = {} for src in node:eachObject() do if src.type == 'doc.type.string' diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 3c0a2dcb..8ad26f63 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -242,7 +242,8 @@ local function getObjectSign(source) end end if source.type == 'doc.type.function' - or source.type == 'doc.type.table' then + or source.type == 'doc.type.table' + or source.type == 'doc.type.array' then local hasGeneric guide.eachSourceType(source, 'doc.generic.name', function () hasGeneric = true @@ -375,21 +376,21 @@ local function getReturn(func, index, args) if cnode.type == 'function' or cnode.type == 'doc.type.function' then local returnObject = vm.getReturnOfFunction(cnode, index) - local returnNode = vm.compileNode(returnObject) - if returnNode then + if returnObject then + local returnNode = vm.compileNode(returnObject) for rnode in returnNode:eachObject() do if rnode.type == 'generic' then returnNode = rnode:resolve(guide.getUri(func), args) break end end - end - if returnNode then - for rnode in returnNode:eachObject() do - -- TODO: narrow type - if rnode.type ~= 'doc.generic.name' then - result = result or vm.createNode() - result:merge(rnode) + if returnNode then + for rnode in returnNode:eachObject() do + -- TODO: narrow type + if rnode.type ~= 'doc.generic.name' then + result = result or vm.createNode() + result:merge(rnode) + end end end end @@ -471,9 +472,6 @@ end ---@param pushResult fun(source: parser.object) function vm.compileByParentNode(source, key, pushResult) local parentNode = vm.compileNode(source) - if not parentNode then - return - end local suri = guide.getUri(source) for node in parentNode:eachObject() do searchFieldSwitch(node.type, suri, node, key, pushResult) @@ -508,7 +506,8 @@ local function selectNode(source, list, index) if exp.type == 'call' then result = getReturn(exp.node, index, exp.args) if not result then - return nil + vm.setNode(source, globalMgr.getGlobal('type', 'unknown')) + return vm.getNode(source) end else result = vm.compileNode(exp) @@ -601,22 +600,27 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex) for n in callNode:eachObject() do if n.type == 'function' then local farg = getFuncArg(n, myIndex) - for fn in vm.compileNode(farg):eachObject() do - if isValidCallArgNode(arg, fn) then - vm.setNode(arg, fn) + if farg then + for fn in vm.compileNode(farg):eachObject() do + if isValidCallArgNode(arg, fn) then + vm.setNode(arg, fn) + end end end end if n.type == 'doc.type.function' then + local myEvent if n.args[eventIndex] then local argNode = vm.compileNode(n.args[eventIndex]) - local event = argNode and argNode[1] - if not event - or not eventMap - or myIndex <= eventIndex - or event.type ~= 'doc.type.string' - or eventMap[event[1]] then - local farg = getFuncArg(n, myIndex) + myEvent = argNode[1] + end + if not myEvent + or not eventMap + or myIndex <= eventIndex + or myEvent.type ~= 'doc.type.string' + or eventMap[myEvent[1]] then + local farg = getFuncArg(n, myIndex) + if farg then for fn in vm.compileNode(farg):eachObject() do if isValidCallArgNode(arg, fn) then vm.setNode(arg, fn) @@ -704,7 +708,9 @@ local compilerSwitch = util.switch() end) : case 'paren' : call(function (source) - vm.setNode(source, vm.compileNode(source.exp)) + if source.exp then + vm.setNode(source, vm.compileNode(source.exp)) + end end) : case 'local' : call(function (source) @@ -720,10 +726,12 @@ local compilerSwitch = util.switch() end if source.value then if not hasMarkDoc or guide.isLiteral(source.value) then - if source.value and source.value.type == 'table' then - vm.setNode(source, source.value) - else - vm.setNode(source, vm.compileNode(source.value)) + if source.value then + if source.value.type == 'table' then + vm.setNode(source, source.value) + else + vm.setNode(source, vm.compileNode(source.value)) + end end end end @@ -732,8 +740,8 @@ local compilerSwitch = util.switch() and not hasMarkDoc then vm.pauseCache() for _, ref in ipairs(source.ref) do - if ref.type == 'setlocal' then - if ref.value and ref.value.type == 'table' then + if ref.type == 'setlocal' and ref.value then + if ref.value.type == 'table' then vm.setNode(source, ref.value) else vm.setNode(source, vm.compileNode(ref.value)) @@ -859,10 +867,12 @@ local compilerSwitch = util.switch() if source.value then if not hasMarkDoc or guide.isLiteral(source.value) then - if source.value and source.value.type == 'table' then - vm.setNode(source, source.value) - else - vm.setNode(source, vm.compileNode(source.value)) + if source.value then + if source.value.type == 'table' then + vm.setNode(source, source.value) + else + vm.setNode(source, vm.compileNode(source.value)) + end end end end @@ -947,7 +957,9 @@ local compilerSwitch = util.switch() end) : case 'varargs' : call(function (source) - vm.setNode(source, vm.compileNode(source.node)) + if source.node then + vm.setNode(source, vm.compileNode(source.node)) + end end) : case 'call' : call(function (source) diff --git a/script/vm/def.lua b/script/vm/def.lua index 9efb3e90..78055ddf 100644 --- a/script/vm/def.lua +++ b/script/vm/def.lua @@ -139,9 +139,6 @@ local nodeSwitch = util.switch() : case 'setindex' : call(function (source, pushResult) local parentNode = vm.compileNode(source.node) - if not parentNode then - return - end local uri = guide.getUri(source) local key = guide.getKeyName(source) for pn in parentNode:eachObject() do @@ -158,9 +155,6 @@ local nodeSwitch = util.switch() : case 'doc.see.field' : call(function (source, pushResult) local parentNode = vm.compileNode(source.parent.name) - if not parentNode then - return - end local uri = guide.getUri(source) for pn in parentNode:eachObject() do searchFieldSwitch(pn.type, uri, pn, source[1], pushResult) @@ -195,9 +189,6 @@ end local function searchByNode(source, pushResult) local node = vm.compileNode(source) - if not node then - return - end local suri = guide.getUri(source) for n in node:eachObject() do if n.type == 'global' then diff --git a/script/vm/infer.lua b/script/vm/infer.lua index 71da6317..a5b113d6 100644 --- a/script/vm/infer.lua +++ b/script/vm/infer.lua @@ -158,9 +158,6 @@ local viewNodeSwitch = util.switch() ---@return vm.infer function m.getInfer(source) local node = vm.compileNode(source) - if not node then - return m.NULL - end if node.lastInfer then return node.lastInfer end @@ -380,4 +377,10 @@ function mt:viewClass() return table.concat(class, '|') end +---@param source parser.object +---@return string? +function m.viewObject(source) + return viewNodeSwitch(source.type, source, {}) +end + return m diff --git a/script/vm/local-manager.lua b/script/vm/local-manager.lua index 2baabed4..51bafb24 100644 --- a/script/vm/local-manager.lua +++ b/script/vm/local-manager.lua @@ -23,23 +23,6 @@ function m.declareLocal(source) locals[#locals+1] = source end ----@param source parser.object ----@param node vm.node -function m.subscribeLocal(source, node) - -- TODO: need delete - if not node then - return - end - if node.type == 'vm.node' then - node:subscribeLocal(source) - return - end - if not m.allLocals[node] then - return - end - m.localSubs[node][source] = true -end - ---@param uri uri function m.dropUri(uri) local locals = m.locals[uri] diff --git a/script/vm/node.lua b/script/vm/node.lua index 6106d3e1..941f2b09 100644 --- a/script/vm/node.lua +++ b/script/vm/node.lua @@ -46,48 +46,26 @@ function mt:isEmpty() return #self == 0 end ----@param source parser.object -function mt:subscribeLocal(source) - -- TODO: need delete - for _, c in ipairs(self) do - localMgr.subscribeLocal(source, c) - end -end - ----@return vm.node function mt:addOptional() if self:isOptional() then return self end self.optional = true - return self end ----@return vm.node function mt:removeOptional() - self.optional = nil if not self:isOptional() then return self end - local newNode = vm.createNode() - for _, n in ipairs(self) do - if n.type == 'nil' then - goto CONTINUE - end - if n.type == 'boolean' and n[1] == false then - goto CONTINUE + for i = #self, 1, -1 do + local n = self[i] + if n.type == 'nil' + or (n.type == 'boolean' and n[1] == false) + or (n.type == 'doc.type.boolean' and n[1] == false) then + self[i] = self[#self] + self[#self] = nil end - if n.type == 'doc.type.boolean' and n[1] == false then - goto CONTINUE - end - if n.type == 'false' then - goto CONTINUE - end - newNode[#newNode+1] = n - ::CONTINUE:: end - newNode.optional = false - return newNode end ---@return boolean @@ -96,17 +74,9 @@ function mt:isOptional() return self.optional end for _, c in ipairs(self) do - if c.type == 'nil' then - self.optional = true - return true - end - if c.type == 'boolean' then - if c[1] == false then - self.optional = true - return true - end - end - if c.type == 'false' then + if c.type == 'nil' + or (c.type == 'boolean' and c[1] == false) + or (c.type == 'doc.type.boolean' and c[1] == false) then self.optional = true return true end diff --git a/script/vm/ref.lua b/script/vm/ref.lua index 360d979e..b086f6e1 100644 --- a/script/vm/ref.lua +++ b/script/vm/ref.lua @@ -218,11 +218,6 @@ local nodeSwitch = util.switch() return end - local parentNode = vm.compileNode(source.node) - if not parentNode then - return - end - searchField(source, pushResult, defMap, fileNotify) end) : case 'tablefield' diff --git a/script/vm/sign.lua b/script/vm/sign.lua index 5b97f2b9..ca326965 100644 --- a/script/vm/sign.lua +++ b/script/vm/sign.lua @@ -1,5 +1,6 @@ local guide = require 'parser.guide' local vm = require 'vm.vm' +local infer = require 'vm.infer' ---@class vm.sign ---@field parent parser.object @@ -23,12 +24,12 @@ function mt:resolve(uri, args) local globalMgr = require 'vm.global-manager' local resolved = {} - ---@param typeUnit parser.object - ---@param node vm.node - local function resolve(typeUnit, node) - if typeUnit.type == 'doc.generic.name' then - local key = typeUnit[1] - if typeUnit.literal then + ---@param object parser.object + ---@param node vm.node + local function resolve(object, node) + if object.type == 'doc.generic.name' then + local key = object[1] + if object.literal then -- 'number' -> `T` for n in node:eachObject() do if n.type == 'string' then @@ -41,16 +42,16 @@ function mt:resolve(uri, args) resolved[key] = vm.createNode(node, resolved[key]) end end - if typeUnit.type == 'doc.type.array' then + if object.type == 'doc.type.array' then for n in node:eachObject() do if n.type == 'doc.type.array' then -- number[] -> T[] - resolve(typeUnit.node, vm.compileNode(n.node)) + resolve(object.node, vm.compileNode(n.node)) end end end - if typeUnit.type == 'doc.type.table' then - for _, ufield in ipairs(typeUnit.fields) do + if object.type == 'doc.type.table' then + for _, ufield in ipairs(object.fields) do local ufieldNode = vm.compileNode(ufield.name) local uvalueNode = vm.compileNode(ufield.extends) if ufieldNode[1].type == 'doc.generic.name' and uvalueNode[1].type == 'doc.generic.name' then @@ -74,18 +75,79 @@ function mt:resolve(uri, args) end end + ---@param sign vm.node + ---@return table<string, true> + ---@return table<string, true> + local function getSignInfo(sign) + local knownTypes = {} + local genericsNames = {} + for obj in sign:eachObject() do + if obj.type == 'doc.generic.name' then + genericsNames[obj[1]] = true + goto CONTINUE + end + if obj.type == 'doc.type.table' + or obj.type == 'doc.type.function' + or obj.type == 'doc.type.array' then + local hasGeneric + guide.eachSourceType(obj, 'doc.generic.name', function (src) + hasGeneric = true + genericsNames[src[1]] = true + end) + if hasGeneric then + goto CONTINUE + end + end + local view = infer.viewObject(obj) + if view then + knownTypes[view] = true + end + ::CONTINUE:: + end + return knownTypes, genericsNames + end + + -- remove un-generic type + ---@param argNode vm.node + ---@param knownTypes table<string, true> + ---@return vm.node + local function buildArgNode(argNode, knownTypes) + local newArgNode = vm.createNode() + for n in argNode:eachObject() do + if argNode:isOptional() and vm.isFalsy(n) then + goto CONTINUE + end + local view = infer.viewObject(n) + if knownTypes[view] then + goto CONTINUE + end + newArgNode:merge(n) + ::CONTINUE:: + end + return newArgNode + end + + ---@param genericNames table<string, true> + local function isAllResolved(genericNames) + for n in pairs(genericNames) do + if not resolved[n] then + return false + end + end + return true + end + for i, arg in ipairs(args) do local sign = self.signList[i] if not sign then break end - for n in sign:eachObject() do - local argNode = vm.compileNode(arg) - if argNode then - if sign.optional then - argNode:removeOptional() - end - resolve(n, argNode) + local argNode = vm.compileNode(arg) + local knownTypes, genericNames = getSignInfo(sign) + if not isAllResolved(genericNames) then + local newArgNode = buildArgNode(argNode, knownTypes) + for n in sign:eachObject() do + resolve(n, newArgNode) end end end diff --git a/script/vm/value.lua b/script/vm/value.lua index 5810f8da..10107212 100644 --- a/script/vm/value.lua +++ b/script/vm/value.lua @@ -8,7 +8,8 @@ function vm.test(source) local node = vm.compileNode(source) local hasTrue, hasFalse for n in node:eachObject() do - if n.type == 'boolean' then + if n.type == 'boolean' + or n.type == 'doc.type.boolean' then if n[1] == true then hasTrue = true end @@ -37,6 +38,19 @@ function vm.test(source) end end +---@param source parser.object +---@return boolean +function vm.isFalsy(source) + if source.type == 'nil' then + return true + end + if source.type == 'boolean' + or source.type == 'doc.type.boolean' then + return source[1] == false + end + return false +end + ---@param v vm.object ---@return string? local function getUnique(v) @@ -77,6 +91,9 @@ end ---@param b vm.node ---@return boolean|nil function vm.equal(a, b) + if not a or not b then + return false + end local nodeA = vm.compileNode(a) local nodeB = vm.compileNode(b) local mapA = {} diff --git a/test/diagnostics/common.lua b/test/diagnostics/common.lua index 6aa1dd6a..dcdcca0a 100644 --- a/test/diagnostics/common.lua +++ b/test/diagnostics/common.lua @@ -1382,8 +1382,8 @@ TEST [[ return ('1'):gsub() ]] -TEST [[ -local value -value = '1' -value = value:gsub() -]] +--TEST [[ +--local value +--value = '1' +--value = value:gsub() +--]] diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index fd6e464f..c2bd96b7 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -567,6 +567,14 @@ local f local <?n?> = f(nil) ]] +TEST 'unknown' [[ +---@generic K +---@type fun(a: K|integer):K +local f + +local <?n?> = f(1) +]] + TEST 'integer' [[ ---@class integer |