summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author最萌小汐 <sumneko@hotmail.com>2022-06-14 21:01:25 +0800
committer最萌小汐 <sumneko@hotmail.com>2022-06-14 21:01:25 +0800
commitd1567d7e47a7c418d7fa072e41e78b2c8cbf29a5 (patch)
tree89abe580e7b61aae80b916ee83cfaf0d71f6ade9
parent11085b8e4172f054191d3cb9cffac3cbce1803f5 (diff)
parent329cd44aedec15f1048742d98baacdd71d434a6e (diff)
downloadlua-language-server-d1567d7e47a7c418d7fa072e41e78b2c8cbf29a5.zip
Merge branch 'runner'
-rw-r--r--script/core/completion/completion.lua4
-rw-r--r--script/core/diagnostics/need-check-nil.lua2
-rw-r--r--script/core/hover/description.lua9
-rw-r--r--script/vm/compiler.lua3
-rw-r--r--script/vm/infer.lua5
-rw-r--r--script/vm/node.lua30
-rw-r--r--script/vm/runner-bak.lua444
-rw-r--r--script/vm/runner.lua590
-rw-r--r--script/vm/sign.lua4
-rw-r--r--test/type_inference/init.lua116
10 files changed, 818 insertions, 389 deletions
diff --git a/script/core/completion/completion.lua b/script/core/completion/completion.lua
index d7c210c6..285f5bf2 100644
--- a/script/core/completion/completion.lua
+++ b/script/core/completion/completion.lua
@@ -1132,7 +1132,7 @@ local function checkTypingEnum(state, position, defs, str, results)
or def.type == 'doc.type.integer'
or def.type == 'doc.type.boolean' then
enums[#enums+1] = {
- label = vm.viewObject(def),
+ label = vm.viewObject(def, state.uri),
description = def.comment and def.comment.text,
kind = define.CompletionItemKind.EnumMember,
}
@@ -1427,7 +1427,7 @@ local function tryCallArg(state, position, results)
or src.type == 'doc.type.integer'
or src.type == 'doc.type.boolean' then
enums[#enums+1] = {
- label = vm.viewObject(src),
+ label = vm.viewObject(src, state.uri),
description = src.comment,
kind = define.CompletionItemKind.EnumMember,
}
diff --git a/script/core/diagnostics/need-check-nil.lua b/script/core/diagnostics/need-check-nil.lua
index 98fdfd08..56cb1eae 100644
--- a/script/core/diagnostics/need-check-nil.lua
+++ b/script/core/diagnostics/need-check-nil.lua
@@ -28,7 +28,7 @@ return function (uri, callback)
return
end
local node = vm.compileNode(src)
- if node:hasFalsy() then
+ if node:hasFalsy() and not vm.getInfer(src):hasType(uri, 'any') then
callback {
start = src.start,
finish = src.finish,
diff --git a/script/core/hover/description.lua b/script/core/hover/description.lua
index 3fef1a21..c96aaae3 100644
--- a/script/core/hover/description.lua
+++ b/script/core/hover/description.lua
@@ -144,7 +144,7 @@ local function tryDocModule(source)
return collectRequire('require', source.module, guide.getUri(source))
end
-local function buildEnumChunk(docType, name)
+local function buildEnumChunk(docType, name, uri)
if not docType then
return nil
end
@@ -175,7 +175,7 @@ local function buildEnumChunk(docType, name)
(enum.default and '->')
or (enum.additional and '+>')
or ' |',
- vm.viewObject(enum)
+ vm.viewObject(enum, uri)
)
if enum.comment then
local first = true
@@ -199,6 +199,7 @@ local function getBindEnums(source, docGroup)
return
end
+ local uri = guide.getUri(source)
local mark = {}
local chunks = {}
local returnIndex = 0
@@ -209,7 +210,7 @@ local function getBindEnums(source, docGroup)
goto CONTINUE
end
mark[name] = true
- chunks[#chunks+1] = buildEnumChunk(doc.extends, name)
+ chunks[#chunks+1] = buildEnumChunk(doc.extends, name, uri)
elseif doc.type == 'doc.return' then
for _, rtn in ipairs(doc.returns) do
returnIndex = returnIndex + 1
@@ -218,7 +219,7 @@ local function getBindEnums(source, docGroup)
goto CONTINUE
end
mark[name] = true
- chunks[#chunks+1] = buildEnumChunk(rtn, name)
+ chunks[#chunks+1] = buildEnumChunk(rtn, name, uri)
end
end
::CONTINUE::
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index 9b871553..0cbe2b7a 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -998,8 +998,7 @@ local compilerSwitch = util.switch()
local hasMark = vm.getNode(source):getData 'hasDefined'
- local runner = vm.createRunner(source)
- runner:launch(function (src, node)
+ vm.launchRunner(source, function (src, node)
if src.type == 'setlocal' then
if src.bindDocs then
for _, doc in ipairs(src.bindDocs) do
diff --git a/script/vm/infer.lua b/script/vm/infer.lua
index 712cff10..9bcd3963 100644
--- a/script/vm/infer.lua
+++ b/script/vm/infer.lua
@@ -417,7 +417,8 @@ function mt:viewClass()
end
---@param source parser.object
+---@param uri uri
---@return string?
-function vm.viewObject(source)
- return viewNodeSwitch(source.type, source, {}, guide.getUri(source))
+function vm.viewObject(source, uri)
+ return viewNodeSwitch(source.type, source, {}, uri)
end
diff --git a/script/vm/node.lua b/script/vm/node.lua
index 6a86350a..5086f66d 100644
--- a/script/vm/node.lua
+++ b/script/vm/node.lua
@@ -2,6 +2,7 @@ local files = require 'files'
---@class vm
local vm = require 'vm.vm'
local ws = require 'workspace.workspace'
+local guide = require 'parser.guide'
---@type table<vm.object, vm.node>
vm.nodeCache = {}
@@ -105,6 +106,18 @@ function mt:hasFalsy()
return false
end
+function mt:hasKnownType()
+ for _, c in ipairs(self) do
+ if c.type == 'global' and c.cate == 'type' then
+ return true
+ end
+ if guide.isLiteral(c) then
+ return true
+ end
+ end
+ return false
+end
+
---@return boolean
function mt:isNullable()
if self.optional then
@@ -152,6 +165,7 @@ function mt:setTruthy()
if hasBoolean then
self[#self+1] = vm.declareGlobal('type', 'true')
end
+ return self
end
---@return vm.node
@@ -174,12 +188,19 @@ function mt:setFalsy()
hasBoolean = true
table.remove(self, index)
self[c] = nil
+ goto CONTINUE
+ end
+ if (c.type == 'global' and c.cate == 'type') then
+ table.remove(self, index)
+ self[c] = nil
+ goto CONTINUE
end
::CONTINUE::
end
if hasBoolean then
self[#self+1] = vm.declareGlobal('type', 'false')
end
+ return self
end
---@param name string
@@ -200,6 +221,7 @@ function mt:remove(name)
self[c] = nil
end
end
+ return self
end
---@param node vm.node
@@ -207,6 +229,14 @@ function mt:removeNode(node)
for _, c in ipairs(node) do
if c.type == 'global' and c.cate == 'type' then
self:remove(c.name)
+ elseif c.type == 'nil' then
+ self:remove 'nil'
+ elseif c.type == 'boolean' then
+ if c[1] == true then
+ self:remove 'true'
+ else
+ self:remove 'false'
+ end
end
end
end
diff --git a/script/vm/runner-bak.lua b/script/vm/runner-bak.lua
new file mode 100644
index 00000000..5535f0ca
--- /dev/null
+++ b/script/vm/runner-bak.lua
@@ -0,0 +1,444 @@
+---@class vm
+local vm = require 'vm.vm'
+local guide = require 'parser.guide'
+
+---@class vm.runner-bak
+---@field loc parser.object
+---@field mainBlock parser.object
+---@field blocks table<parser.object, true>
+---@field steps vm.runner.step[]
+local mt = {}
+mt.__index = mt
+mt.index = 1
+
+---@class parser.object
+---@field _casts parser.object[]
+
+---@class vm.runner.step
+---@field type 'truthy' | 'falsy' | 'as' | 'add' | 'remove' | 'object' | 'save' | 'push' | 'merge' | 'cast'
+---@field pos integer
+---@field order? integer
+---@field node? vm.node
+---@field object? parser.object
+---@field name? string
+---@field cast? parser.object
+---@field tag? string
+---@field copy? boolean
+---@field new? boolean
+---@field ref1? vm.runner.step
+---@field ref2? vm.runner.step
+
+---@param filter parser.object
+---@param outStep vm.runner.step
+---@param blockStep vm.runner.step
+function mt:_compileNarrowByFilter(filter, outStep, blockStep)
+ if not filter then
+ return
+ end
+ if filter.type == 'paren' then
+ if filter.exp then
+ self:_compileNarrowByFilter(filter.exp, outStep, blockStep)
+ end
+ return
+ end
+ if filter.type == 'unary' then
+ if not filter.op
+ or not filter[1] then
+ return
+ end
+ if filter.op.type == 'not' then
+ local exp = filter[1]
+ if exp.type == 'getlocal' and exp.node == self.loc then
+ self.steps[#self.steps+1] = {
+ type = 'falsy',
+ pos = filter.finish,
+ new = true,
+ }
+ self.steps[#self.steps+1] = {
+ type = 'truthy',
+ pos = filter.finish,
+ ref1 = outStep,
+ }
+ end
+ end
+ elseif filter.type == 'binary' then
+ if not filter.op
+ or not filter[1]
+ or not filter[2] then
+ return
+ end
+ if filter.op.type == 'and' then
+ local dummyStep = {
+ type = 'save',
+ copy = true,
+ ref1 = outStep,
+ pos = filter.start - 1,
+ }
+ self.steps[#self.steps+1] = dummyStep
+ self:_compileNarrowByFilter(filter[1], dummyStep, blockStep)
+ self:_compileNarrowByFilter(filter[2], dummyStep, blockStep)
+ end
+ if filter.op.type == 'or' then
+ self:_compileNarrowByFilter(filter[1], outStep, blockStep)
+ local dummyStep = {
+ type = 'push',
+ copy = true,
+ ref1 = outStep,
+ pos = filter.op.finish,
+ }
+ self.steps[#self.steps+1] = dummyStep
+ self:_compileNarrowByFilter(filter[2], outStep, dummyStep)
+ self.steps[#self.steps+1] = {
+ type = 'push',
+ tag = 'or reset',
+ ref1 = blockStep,
+ pos = filter.finish,
+ }
+ end
+ if filter.op.type == '=='
+ or filter.op.type == '~=' then
+ local loc, exp
+ for i = 1, 2 do
+ loc = filter[i]
+ if loc.type == 'getlocal' and loc.node == self.loc then
+ exp = filter[i % 2 + 1]
+ break
+ end
+ end
+ if not loc or not exp then
+ return
+ end
+ if guide.isLiteral(exp) then
+ if filter.op.type == '==' then
+ self.steps[#self.steps+1] = {
+ type = 'remove',
+ name = exp.type,
+ pos = filter.finish,
+ ref1 = outStep,
+ }
+ self.steps[#self.steps+1] = {
+ type = 'as',
+ name = exp.type,
+ pos = filter.finish,
+ new = true,
+ }
+ end
+ if filter.op.type == '~=' then
+ self.steps[#self.steps+1] = {
+ type = 'as',
+ name = exp.type,
+ pos = filter.finish,
+ ref1 = outStep,
+ }
+ self.steps[#self.steps+1] = {
+ type = 'remove',
+ name = exp.type,
+ pos = filter.finish,
+ new = true,
+ }
+ end
+ end
+ end
+ else
+ if filter.type == 'getlocal' and filter.node == self.loc then
+ self.steps[#self.steps+1] = {
+ type = 'truthy',
+ pos = filter.finish,
+ new = true,
+ }
+ self.steps[#self.steps+1] = {
+ type = 'falsy',
+ pos = filter.finish,
+ ref1 = outStep,
+ }
+ end
+ end
+end
+
+---@param block parser.object
+function mt:_compileBlock(block)
+ if self.blocks[block] then
+ return
+ end
+ self.blocks[block] = true
+ if block == self.mainBlock then
+ return
+ end
+
+ local parentBlock = guide.getParentBlock(block)
+ self:_compileBlock(parentBlock)
+
+ if block.type == 'if' then
+ ---@type vm.runner.step[]
+ local finals = {}
+ for i, childBlock in ipairs(block) do
+ local blockStep = {
+ type = 'save',
+ tag = 'block',
+ copy = true,
+ pos = childBlock.start,
+ }
+ local outStep = {
+ type = 'save',
+ tag = 'out',
+ copy = true,
+ pos = childBlock.start,
+ }
+ self.steps[#self.steps+1] = blockStep
+ self.steps[#self.steps+1] = outStep
+ self.steps[#self.steps+1] = {
+ type = 'push',
+ ref1 = blockStep,
+ pos = childBlock.start,
+ }
+ self:_compileNarrowByFilter(childBlock.filter, outStep, blockStep)
+ if not childBlock.hasReturn
+ and not childBlock.hasGoTo
+ and not childBlock.hasBreak then
+ local finalStep = {
+ type = 'save',
+ pos = childBlock.finish,
+ tag = 'final #' .. i,
+ }
+ finals[#finals+1] = finalStep
+ self.steps[#self.steps+1] = finalStep
+ end
+ self.steps[#self.steps+1] = {
+ type = 'push',
+ tag = 'reset child',
+ ref1 = outStep,
+ pos = childBlock.finish,
+ }
+ end
+ self.steps[#self.steps+1] = {
+ type = 'push',
+ tag = 'reset if',
+ pos = block.finish,
+ copy = true,
+ }
+ for _, final in ipairs(finals) do
+ self.steps[#self.steps+1] = {
+ type = 'merge',
+ ref2 = final,
+ pos = block.finish,
+ }
+ end
+ end
+
+ if block.type == 'function'
+ or block.type == 'while'
+ or block.type == 'loop'
+ or block.type == 'in'
+ or block.type == 'repeat'
+ or block.type == 'for' then
+ local savePoint = {
+ type = 'save',
+ copy = true,
+ pos = block.start,
+ }
+ self.steps[#self.steps+1] = {
+ type = 'push',
+ copy = true,
+ pos = block.start,
+ }
+ self.steps[#self.steps+1] = savePoint
+ self.steps[#self.steps+1] = {
+ type = 'push',
+ pos = block.finish,
+ ref1 = savePoint,
+ }
+ end
+end
+
+---@return parser.object[]
+function mt:_getCasts()
+ local root = guide.getRoot(self.loc)
+ if not root._casts then
+ root._casts = {}
+ local docs = root.docs
+ for _, doc in ipairs(docs) do
+ if doc.type == 'doc.cast' and doc.loc then
+ root._casts[#root._casts+1] = doc
+ end
+ end
+ end
+ return root._casts
+end
+
+function mt:_preCompile()
+ local startPos = self.loc.start
+ local finishPos = 0
+
+ for _, ref in ipairs(self.loc.ref) do
+ self.steps[#self.steps+1] = {
+ type = 'object',
+ object = ref,
+ pos = ref.range or ref.start,
+ }
+ if ref.start > finishPos then
+ finishPos = ref.start
+ end
+ local block = guide.getParentBlock(ref)
+ self:_compileBlock(block)
+ end
+
+ for i, step in ipairs(self.steps) do
+ if step.type ~= 'object' then
+ step.order = i
+ end
+ end
+
+ local casts = self:_getCasts()
+ for _, cast in ipairs(casts) do
+ if cast.loc[1] == self.loc[1]
+ and cast.start > startPos
+ and cast.finish < finishPos
+ and guide.getLocal(self.loc, self.loc[1], cast.start) == self.loc then
+ self.steps[#self.steps+1] = {
+ type = 'cast',
+ cast = cast,
+ pos = cast.start,
+ }
+ end
+ end
+
+ table.sort(self.steps, function (a, b)
+ if a.pos == b.pos then
+ return (a.order or 0) < (b.order or 0)
+ else
+ return a.pos < b.pos
+ end
+ end)
+end
+
+---@param loc parser.object
+---@param node vm.node
+---@return vm.node
+local function checkAssert(loc, node)
+ local parent = loc.parent
+ if parent.type == 'binary' then
+ if parent.op and (parent.op.type == '~=' or parent.op.type == '==') then
+ local exp
+ for i = 1, 2 do
+ if parent[i] == loc then
+ exp = parent[i % 2 + 1]
+ end
+ end
+ if exp and guide.isLiteral(exp) then
+ local callargs = parent.parent
+ if callargs.type == 'callargs'
+ and callargs.parent.node.special == 'assert'
+ and callargs[1] == parent then
+ if parent.op.type == '~=' then
+ node:remove(exp.type)
+ end
+ if parent.op.type == '==' then
+ node = vm.compileNode(exp)
+ end
+ end
+ end
+ end
+ end
+ if parent.type == 'callargs'
+ and parent.parent.node.special == 'assert'
+ and parent[1] == loc then
+ node:setTruthy()
+ end
+ return node
+end
+
+---@param callback fun(src: parser.object, node: vm.node)
+function mt:launch(callback)
+ local topNode = vm.getNode(self.loc):copy()
+ for _, step in ipairs(self.steps) do
+ local node = step.ref1 and step.ref1.node or topNode
+ if step.type == 'truthy' then
+ if step.new then
+ node = node:copy()
+ topNode = node
+ end
+ node:setTruthy()
+ elseif step.type == 'falsy' then
+ if step.new then
+ node = node:copy()
+ topNode = node
+ end
+ node:setFalsy()
+ elseif step.type == 'as' then
+ if step.new then
+ topNode = vm.createNode(vm.getGlobal('type', step.name))
+ else
+ node:clear()
+ node:merge(vm.getGlobal('type', step.name))
+ end
+ elseif step.type == 'add' then
+ if step.new then
+ node = node:copy()
+ topNode = node
+ end
+ node:merge(vm.getGlobal('type', step.name))
+ elseif step.type == 'remove' then
+ if step.new then
+ node = node:copy()
+ topNode = node
+ end
+ node:remove(step.name)
+ elseif step.type == 'object' then
+ topNode = callback(step.object, node) or node
+ if step.object.type == 'getlocal' then
+ topNode = checkAssert(step.object, node)
+ end
+ elseif step.type == 'save' then
+ if step.copy then
+ node = node:copy()
+ end
+ step.node = node
+ elseif step.type == 'push' then
+ if step.copy then
+ node = node:copy()
+ end
+ topNode = node
+ elseif step.type == 'merge' then
+ node:merge(step.ref2.node)
+ elseif step.type == 'cast' then
+ topNode = node:copy()
+ for _, cast in ipairs(step.cast.casts) do
+ if cast.mode == '+' then
+ if cast.optional then
+ topNode:addOptional()
+ end
+ if cast.extends then
+ topNode:merge(vm.compileNode(cast.extends))
+ end
+ elseif cast.mode == '-' then
+ if cast.optional then
+ topNode:removeOptional()
+ end
+ if cast.extends then
+ topNode:removeNode(vm.compileNode(cast.extends))
+ end
+ else
+ if cast.extends then
+ topNode:clear()
+ topNode:merge(vm.compileNode(cast.extends))
+ end
+ end
+ end
+ end
+ end
+end
+
+---@param loc parser.object
+---@return vm.runner
+function vm.launchRunner(loc)
+ local self = setmetatable({
+ loc = loc,
+ mainBlock = guide.getParentBlock(loc),
+ blocks = {},
+ steps = {},
+ }, mt)
+
+ self:_preCompile()
+
+ return self
+end
diff --git a/script/vm/runner.lua b/script/vm/runner.lua
index 9fe0f172..a9c38a87 100644
--- a/script/vm/runner.lua
+++ b/script/vm/runner.lua
@@ -2,257 +2,19 @@
local vm = require 'vm.vm'
local guide = require 'parser.guide'
+---@alias vm.runner.callback fun(src: parser.object, node: vm.node)
+
---@class vm.runner
----@field loc parser.object
----@field mainBlock parser.object
----@field blocks table<parser.object, true>
----@field steps vm.runner.step[]
+---@field _loc parser.object
+---@field _objs parser.object[]
+---@field _callback vm.runner.callback
local mt = {}
mt.__index = mt
-mt.index = 1
-
----@class parser.object
----@field _casts parser.object[]
-
----@class vm.runner.step
----@field type 'truthy' | 'falsy' | 'as' | 'add' | 'remove' | 'object' | 'save' | 'push' | 'merge' | 'cast'
----@field pos integer
----@field order? integer
----@field node? vm.node
----@field object? parser.object
----@field name? string
----@field cast? parser.object
----@field tag? string
----@field copy? boolean
----@field new? boolean
----@field ref1? vm.runner.step
----@field ref2? vm.runner.step
-
----@param filter parser.object
----@param outStep vm.runner.step
----@param blockStep vm.runner.step
-function mt:_compileNarrowByFilter(filter, outStep, blockStep)
- if not filter then
- return
- end
- if filter.type == 'paren' then
- if filter.exp then
- self:_compileNarrowByFilter(filter.exp, outStep, blockStep)
- end
- return
- end
- if filter.type == 'unary' then
- if not filter.op
- or not filter[1] then
- return
- end
- if filter.op.type == 'not' then
- local exp = filter[1]
- if exp.type == 'getlocal' and exp.node == self.loc then
- self.steps[#self.steps+1] = {
- type = 'falsy',
- pos = filter.finish,
- new = true,
- }
- self.steps[#self.steps+1] = {
- type = 'truthy',
- pos = filter.finish,
- ref1 = outStep,
- }
- end
- end
- elseif filter.type == 'binary' then
- if not filter.op
- or not filter[1]
- or not filter[2] then
- return
- end
- if filter.op.type == 'and' then
- local dummyStep = {
- type = 'save',
- copy = true,
- ref1 = outStep,
- pos = filter.start - 1,
- }
- self.steps[#self.steps+1] = dummyStep
- self:_compileNarrowByFilter(filter[1], dummyStep, blockStep)
- self:_compileNarrowByFilter(filter[2], dummyStep, blockStep)
- end
- if filter.op.type == 'or' then
- self:_compileNarrowByFilter(filter[1], outStep, blockStep)
- local dummyStep = {
- type = 'push',
- copy = true,
- ref1 = outStep,
- pos = filter.op.finish,
- }
- self.steps[#self.steps+1] = dummyStep
- self:_compileNarrowByFilter(filter[2], outStep, dummyStep)
- self.steps[#self.steps+1] = {
- type = 'push',
- tag = 'or reset',
- ref1 = blockStep,
- pos = filter.finish,
- }
- end
- if filter.op.type == '=='
- or filter.op.type == '~=' then
- local loc, exp
- for i = 1, 2 do
- loc = filter[i]
- if loc.type == 'getlocal' and loc.node == self.loc then
- exp = filter[i % 2 + 1]
- break
- end
- end
- if not loc or not exp then
- return
- end
- if guide.isLiteral(exp) then
- if filter.op.type == '==' then
- self.steps[#self.steps+1] = {
- type = 'remove',
- name = exp.type,
- pos = filter.finish,
- ref1 = outStep,
- }
- self.steps[#self.steps+1] = {
- type = 'as',
- name = exp.type,
- pos = filter.finish,
- new = true,
- }
- end
- if filter.op.type == '~=' then
- self.steps[#self.steps+1] = {
- type = 'as',
- name = exp.type,
- pos = filter.finish,
- ref1 = outStep,
- }
- self.steps[#self.steps+1] = {
- type = 'remove',
- name = exp.type,
- pos = filter.finish,
- new = true,
- }
- end
- end
- end
- else
- if filter.type == 'getlocal' and filter.node == self.loc then
- self.steps[#self.steps+1] = {
- type = 'truthy',
- pos = filter.finish,
- new = true,
- }
- self.steps[#self.steps+1] = {
- type = 'falsy',
- pos = filter.finish,
- ref1 = outStep,
- }
- end
- end
-end
-
----@param block parser.object
-function mt:_compileBlock(block)
- if self.blocks[block] then
- return
- end
- self.blocks[block] = true
- if block == self.mainBlock then
- return
- end
-
- local parentBlock = guide.getParentBlock(block)
- self:_compileBlock(parentBlock)
-
- if block.type == 'if' then
- ---@type vm.runner.step[]
- local finals = {}
- for i, childBlock in ipairs(block) do
- local blockStep = {
- type = 'save',
- tag = 'block',
- copy = true,
- pos = childBlock.start,
- }
- local outStep = {
- type = 'save',
- tag = 'out',
- copy = true,
- pos = childBlock.start,
- }
- self.steps[#self.steps+1] = blockStep
- self.steps[#self.steps+1] = outStep
- self.steps[#self.steps+1] = {
- type = 'push',
- ref1 = blockStep,
- pos = childBlock.start,
- }
- self:_compileNarrowByFilter(childBlock.filter, outStep, blockStep)
- if not childBlock.hasReturn
- and not childBlock.hasGoTo
- and not childBlock.hasBreak then
- local finalStep = {
- type = 'save',
- pos = childBlock.finish,
- tag = 'final #' .. i,
- }
- finals[#finals+1] = finalStep
- self.steps[#self.steps+1] = finalStep
- end
- self.steps[#self.steps+1] = {
- type = 'push',
- tag = 'reset child',
- ref1 = outStep,
- pos = childBlock.finish,
- }
- end
- self.steps[#self.steps+1] = {
- type = 'push',
- tag = 'reset if',
- pos = block.finish,
- copy = true,
- }
- for _, final in ipairs(finals) do
- self.steps[#self.steps+1] = {
- type = 'merge',
- ref2 = final,
- pos = block.finish,
- }
- end
- end
-
- if block.type == 'function'
- or block.type == 'while'
- or block.type == 'loop'
- or block.type == 'in'
- or block.type == 'repeat'
- or block.type == 'for' then
- local savePoint = {
- type = 'save',
- copy = true,
- pos = block.start,
- }
- self.steps[#self.steps+1] = {
- type = 'push',
- copy = true,
- pos = block.start,
- }
- self.steps[#self.steps+1] = savePoint
- self.steps[#self.steps+1] = {
- type = 'push',
- pos = block.finish,
- ref1 = savePoint,
- }
- end
-end
+mt._index = 1
---@return parser.object[]
function mt:_getCasts()
- local root = guide.getRoot(self.loc)
+ local root = guide.getRoot(self._loc)
if not root._casts then
root._casts = {}
local docs = root.docs
@@ -265,180 +27,256 @@ function mt:_getCasts()
return root._casts
end
-function mt:_preCompile()
- local startPos = self.loc.start
+function mt:_collect()
+ local startPos = self._loc.start
local finishPos = 0
- for _, ref in ipairs(self.loc.ref) do
- self.steps[#self.steps+1] = {
- type = 'object',
- object = ref,
- pos = ref.range or ref.start,
- }
- if ref.start > finishPos then
- finishPos = ref.start
+ for _, ref in ipairs(self._loc.ref) do
+ if ref.type == 'getlocal'
+ or ref.type == 'setlocal' then
+ self._objs[#self._objs+1] = ref
+ if ref.start > finishPos then
+ finishPos = ref.start
+ end
end
- local block = guide.getParentBlock(ref)
- self:_compileBlock(block)
end
- for i, step in ipairs(self.steps) do
- if step.type ~= 'object' then
- step.order = i
- end
+ if #self._objs == 0 then
+ return
end
local casts = self:_getCasts()
for _, cast in ipairs(casts) do
- if cast.loc[1] == self.loc[1]
+ if cast.loc[1] == self._loc[1]
and cast.start > startPos
and cast.finish < finishPos
- and guide.getLocal(self.loc, self.loc[1], cast.start) == self.loc then
- self.steps[#self.steps+1] = {
- type = 'cast',
- cast = cast,
- pos = cast.start,
- }
+ and guide.getLocal(self._loc, self._loc[1], cast.start) == self._loc then
+ self._objs[#self._objs+1] = cast
end
end
- table.sort(self.steps, function (a, b)
- if a.pos == b.pos then
- return (a.order or 0) < (b.order or 0)
- else
- return a.pos < b.pos
- end
+ table.sort(self._objs, function (a, b)
+ return (a.range or a.finish) < (b.range or b.start)
end)
end
----@param loc parser.object
+
+---@param pos integer
---@param node vm.node
---@return vm.node
-local function checkAssert(loc, node)
- local parent = loc.parent
- if parent.type == 'binary' then
- if parent.op and (parent.op.type == '~=' or parent.op.type == '==') then
- local exp
- for i = 1, 2 do
- if parent[i] == loc then
- exp = parent[i % 2 + 1]
- end
- end
- if exp and guide.isLiteral(exp) then
- local callargs = parent.parent
- if callargs.type == 'callargs'
- and callargs.parent.node.special == 'assert'
- and callargs[1] == parent then
- if parent.op.type == '~=' then
- node:remove(exp.type)
- end
- if parent.op.type == '==' then
- node = vm.compileNode(exp)
- end
- end
- end
+---@return parser.object?
+function mt:_fastWard(pos, node)
+ for i = self._index, #self._objs do
+ local obj = self._objs[i]
+ if (obj.range or obj.finish) > pos then
+ self._index = i
+ return node, obj
end
- end
- if parent.type == 'callargs'
- and parent.parent.node.special == 'assert'
- and parent[1] == loc then
- node:setTruthy()
- end
- return node
-end
-
----@param callback fun(src: parser.object, node: vm.node)
-function mt:launch(callback)
- local topNode = vm.getNode(self.loc):copy()
- for _, step in ipairs(self.steps) do
- local node = step.ref1 and step.ref1.node or topNode
- if step.type == 'truthy' then
- if step.new then
- node = node:copy()
- topNode = node
- end
- node:setTruthy()
- elseif step.type == 'falsy' then
- if step.new then
- node = node:copy()
- topNode = node
- end
- node:setFalsy()
- elseif step.type == 'as' then
- if step.new then
- topNode = vm.createNode(vm.getGlobal('type', step.name))
- else
- node:clear()
- node:merge(vm.getGlobal('type', step.name))
- end
- elseif step.type == 'add' then
- if step.new then
- node = node:copy()
- topNode = node
- end
- node:merge(vm.getGlobal('type', step.name))
- elseif step.type == 'remove' then
- if step.new then
- node = node:copy()
- topNode = node
- end
- node:remove(step.name)
- elseif step.type == 'object' then
- topNode = callback(step.object, node) or node
- if step.object.type == 'getlocal' then
- topNode = checkAssert(step.object, node)
- end
- elseif step.type == 'save' then
- if step.copy then
- node = node:copy()
- end
- step.node = node
- elseif step.type == 'push' then
- if step.copy then
- node = node:copy()
+ if obj.type == 'getlocal' then
+ self._callback(obj, node)
+ elseif obj.type == 'setlocal' then
+ local newNode = self._callback(obj, node)
+ if newNode then
+ node = newNode:copy()
end
- topNode = node
- elseif step.type == 'merge' then
- node:merge(step.ref2.node)
- elseif step.type == 'cast' then
- topNode = node:copy()
- for _, cast in ipairs(step.cast.casts) do
+ elseif obj.type == 'doc.cast' then
+ node = node:copy()
+ for _, cast in ipairs(obj.casts) do
if cast.mode == '+' then
if cast.optional then
- topNode:addOptional()
+ node:addOptional()
end
if cast.extends then
- topNode:merge(vm.compileNode(cast.extends))
+ node:merge(vm.compileNode(cast.extends))
end
elseif cast.mode == '-' then
if cast.optional then
- topNode:removeOptional()
+ node:removeOptional()
end
if cast.extends then
- topNode:removeNode(vm.compileNode(cast.extends))
+ node:removeNode(vm.compileNode(cast.extends))
end
else
if cast.extends then
- topNode:clear()
- topNode:merge(vm.compileNode(cast.extends))
+ node:clear()
+ node:merge(vm.compileNode(cast.extends))
end
end
end
end
end
+ self._index = #self._objs + 1
+ return node, nil
+end
+
+---@param action parser.object
+---@param topNode vm.node
+---@param outNode? vm.node
+---@return vm.node
+function mt:_lookInto(action, topNode, outNode)
+ local set
+ local value = vm.getObjectValue(action)
+ if value then
+ set = action
+ action = value
+ end
+ if action.type == 'function'
+ or action.type == 'loop'
+ or action.type == 'in'
+ or action.type == 'repeat'
+ or action.type == 'for' then
+ self:_launchBlock(action, topNode:copy())
+ elseif action.type == 'while' then
+ local blockNode, mainNode = self:_lookInto(action.filter, topNode:copy(), topNode:copy())
+ self:_fastWard(action.filter.finish, blockNode)
+ self:_launchBlock(action, blockNode:copy())
+ topNode = mainNode
+ elseif action.type == 'if' then
+ local hasElse
+ local mainNode = topNode:copy()
+ local blockNodes = {}
+ for _, subBlock in ipairs(action) do
+ local blockNode = mainNode:copy()
+ if subBlock.filter then
+ blockNode, mainNode = self:_lookInto(subBlock.filter, blockNode, mainNode)
+ self:_fastWard(subBlock.filter.finish, blockNode)
+ else
+ hasElse = true
+ mainNode:clear()
+ end
+ blockNode = self:_launchBlock(subBlock, blockNode:copy())
+ local neverReturn = subBlock.hasReturn
+ or subBlock.hasGoTo
+ or subBlock.hasBreak
+ if not neverReturn then
+ blockNodes[#blockNodes+1] = blockNode
+ end
+ end
+ if not hasElse and not topNode:hasKnownType() then
+ mainNode:merge(vm.declareGlobal('type', 'unknown'))
+ end
+ for _, blockNode in ipairs(blockNodes) do
+ mainNode:merge(blockNode)
+ end
+ topNode = mainNode
+ elseif action.type == 'getlocal' then
+ if action.node == self._loc then
+ topNode = self:_fastWard(action.finish, topNode)
+ topNode = topNode:copy():setTruthy()
+ if outNode then
+ outNode:setFalsy()
+ end
+ end
+ elseif action.type == 'unary' then
+ if not action[1] then
+ goto RETURN
+ end
+ if action.op.type == 'not' then
+ outNode = outNode or topNode:copy()
+ outNode, topNode = self:_lookInto(action[1], topNode, outNode)
+ end
+ elseif action.type == 'binary' then
+ if not action[1] or not action[2] then
+ goto RETURN
+ end
+ if action.op.type == 'and' then
+ topNode = self:_lookInto(action[1], topNode)
+ topNode = self:_lookInto(action[2], topNode)
+ elseif action.op.type == 'or' then
+ outNode = outNode or topNode:copy()
+ local topNode1, outNode1 = self:_lookInto(action[1], topNode, outNode)
+ local topNode2, outNode2 = self:_lookInto(action[2], outNode1, outNode1:copy())
+ topNode = vm.createNode(topNode1, topNode2)
+ outNode = outNode2
+ elseif action.op.type == '=='
+ or action.op.type == '~=' then
+ local loc, checker
+ for i = 1, 2 do
+ if action[i].type == 'getlocal' and action[i].node == self._loc then
+ loc = action[i]
+ checker = action[3-i] -- Copilot tells me use `3-i` instead of `i%2+1`
+ elseif action[2].type == 'getlocal' and action[2].node == self._loc then
+ loc = action[3-i]
+ checker = action[i]
+ end
+ end
+ if loc then
+ self:_fastWard(loc.finish, topNode)
+ if guide.isLiteral(checker) then
+ local checkerNode = vm.compileNode(checker)
+ if action.op.type == '==' then
+ topNode = checkerNode
+ if outNode then
+ outNode:removeNode(topNode)
+ end
+ else
+ topNode:removeNode(checkerNode)
+ if outNode then
+ outNode = checkerNode
+ end
+ end
+ end
+ end
+ end
+ elseif action.type == 'call' then
+ if action.node.special == 'assert' and action.args and action.args[1] then
+ topNode = self:_lookInto(action.args[1], topNode)
+ elseif action.args then
+ for _, arg in ipairs(action.args) do
+ self:_lookInto(arg, topNode)
+ end
+ end
+ elseif action.type == 'return' then
+ for _, rtn in ipairs(action) do
+ self:_lookInto(rtn, topNode)
+ end
+ end
+ ::RETURN::
+ topNode = self:_fastWard(action.finish, topNode)
+ if set then
+ topNode = self:_fastWard(set.range or set.finish, topNode)
+ end
+ return topNode, outNode
+end
+
+---@param block parser.object
+---@param node vm.node
+---@return vm.node
+function mt:_launchBlock(block, node)
+ local topNode, top = self:_fastWard(block.start, node)
+ if not top then
+ return topNode
+ end
+ for _, action in ipairs(block) do
+ if (action.range or action.finish) < (top.range or top.finish) then
+ goto CONTINUE
+ end
+ topNode = self:_lookInto(action, topNode)
+ topNode, top = self:_fastWard(action.range or action.finish, topNode)
+ if not top then
+ return topNode
+ end
+ ::CONTINUE::
+ end
+ -- `x = function () end`: don't touch `x` in the end of function
+ topNode = self:_fastWard(block.finish - 1, topNode)
+ return topNode
end
---@param loc parser.object
----@return vm.runner
-function vm.createRunner(loc)
+---@param callback vm.runner.callback
+function vm.launchRunner(loc, callback)
local self = setmetatable({
- loc = loc,
- mainBlock = guide.getParentBlock(loc),
- blocks = {},
- steps = {},
+ _loc = loc,
+ _objs = {},
+ _callback = callback,
}, mt)
- self:_preCompile()
+ self:_collect()
+
+ if #self._objs == 0 then
+ return
+ end
- return self
+ self:_launchBlock(guide.getParentBlock(loc), vm.getNode(loc):copy())
end
diff --git a/script/vm/sign.lua b/script/vm/sign.lua
index fe112bc2..5c17aa3d 100644
--- a/script/vm/sign.lua
+++ b/script/vm/sign.lua
@@ -111,7 +111,7 @@ function mt:resolve(uri, args, removeGeneric)
goto CONTINUE
end
end
- local view = vm.viewObject(obj)
+ local view = vm.viewObject(obj, uri)
if view then
knownTypes[view] = true
end
@@ -130,7 +130,7 @@ function mt:resolve(uri, args, removeGeneric)
if argNode:hasFalsy() then
goto CONTINUE
end
- local view = vm.viewObject(n)
+ local view = vm.viewObject(n, uri)
if knownTypes[view] then
goto CONTINUE
end
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index 566e847a..efcc5fea 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -1813,6 +1813,17 @@ TEST 'integer' [[
---@type integer?
local x
+if not x then
+ return
+end
+
+print(<?x?>)
+]]
+
+TEST 'integer' [[
+---@type integer?
+local x
+
if xxx and x then
print(<?x?>)
end
@@ -2153,6 +2164,15 @@ print(<?x?>)
]]
TEST 'integer' [[
+---@type integer?
+local x
+
+while x do
+ print(<?x?>)
+end
+]]
+
+TEST 'integer' [[
---@type fun():integer?
local iter
@@ -2237,6 +2257,19 @@ local x
print(<?x?>)
]]
+TEST 'unknown?' [[
+---@type string?
+local x
+
+if x then
+ return
+else
+ print(<?x?>)
+end
+
+print(x)
+]]
+
TEST 'string' [[
---@type string?
local x
@@ -2298,6 +2331,18 @@ end
print(<?t?>)
]]
+TEST 'unknown?' [[
+---@type integer?
+local t
+
+if t then
+else
+ print(<?t?>)
+end
+
+print(t)
+]]
+
TEST 'table|unknown' [[
local function f()
if x then
@@ -2375,3 +2420,74 @@ TEST '`1`|`true`' [[
---@type `1` | `true`
local <?x?>
]]
+
+TEST 'function' [[
+local x
+
+function x() end
+
+print(<?x?>)
+]]
+
+TEST 'unknown' [[
+local x
+
+if x.field == 'haha' then
+ print(<?x?>)
+end
+]]
+
+TEST 'string' [[
+---@type string?
+local t
+
+if not t or xxx then
+ return
+end
+
+print(<?t?>)
+]]
+
+TEST 'table' [[
+---@type table|nil
+local t
+
+return function ()
+ if not t then
+ return
+ end
+
+ print(<?t?>)
+end
+]]
+
+TEST 'table' [[
+---@type table|nil
+local t
+
+f(function ()
+ if not t then
+ return
+ end
+
+ print(<?t?>)
+end)
+]]
+
+TEST 'table' [[
+---@type table?
+local t
+
+t = t or {}
+
+print(<?t?>)
+]]
+
+TEST 'unknown|nil' [[
+local x
+
+if x == nil then
+end
+
+print(<?x?>)
+]]