summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/parser/compile.lua45
-rw-r--r--script/parser/guide.lua3
-rw-r--r--script/vm/tracer.lua69
-rw-r--r--test/type_inference/init.lua11
4 files changed, 105 insertions, 23 deletions
diff --git a/script/parser/compile.lua b/script/parser/compile.lua
index b8040382..73aef048 100644
--- a/script/parser/compile.lua
+++ b/script/parser/compile.lua
@@ -3136,6 +3136,22 @@ local function parseGoTo()
return action
end
+local function parseFilter()
+ local exp = parseExp()
+ if exp then
+ local filter = {
+ type = 'filter',
+ start = exp.start,
+ finish = exp.finish,
+ exp = exp,
+ }
+ exp.parent = filter
+ return filter
+ else
+ missExp()
+ end
+end
+
local function parseIfBlock(parent)
local ifLeft = getPosition(Tokens[Index], 'left')
local ifRight = getPosition(Tokens[Index] + 1, 'right')
@@ -3151,13 +3167,11 @@ local function parseIfBlock(parent)
}
}
skipSpace()
- local filter = parseExp()
+ local filter = parseFilter()
if filter then
ifblock.filter = filter
ifblock.finish = filter.finish
filter.parent = ifblock
- else
- missExp()
end
skipSpace()
local thenToken = Tokens[Index + 1]
@@ -3210,13 +3224,11 @@ local function parseElseIfBlock(parent)
}
Index = Index + 2
skipSpace()
- local filter = parseExp()
+ local filter = parseFilter()
if filter then
elseifblock.filter = filter
elseifblock.finish = filter.finish
filter.parent = elseifblock
- else
- missExp()
end
skipSpace()
local thenToken = Tokens[Index + 1]
@@ -3524,15 +3536,16 @@ local function parseWhile()
skipSpace()
local nextToken = Tokens[Index + 1]
- local filter = nextToken ~= 'do'
- and nextToken ~= 'then'
- and parseExp()
- if filter then
- action.filter = filter
- action.finish = filter.finish
- filter.parent = action
- else
+ if nextToken == 'do'
+ or nextToken == 'then' then
missExp()
+ else
+ local filter = parseFilter()
+ if filter then
+ action.filter = filter
+ action.finish = filter.finish
+ filter.parent = action
+ end
end
skipSpace()
@@ -3611,12 +3624,10 @@ local function parseRepeat()
Index = Index + 2
skipSpace()
- local filter = parseExp()
+ local filter = parseFilter()
if filter then
action.filter = filter
filter.parent = action
- else
- missExp()
end
else
diff --git a/script/parser/guide.lua b/script/parser/guide.lua
index 147e6237..c74593a4 100644
--- a/script/parser/guide.lua
+++ b/script/parser/guide.lua
@@ -139,7 +139,8 @@ local childMap = {
['getfield'] = {'node', 'field'},
['list'] = {'#'},
['binary'] = {1, 2},
- ['unary'] = {1},
+ ['unary'] = { 1 },
+ ['filter'] = {'exp'},
['doc'] = {'#'},
['doc.class'] = {'class', '#extends', '#signs', 'comment'},
diff --git a/script/vm/tracer.lua b/script/vm/tracer.lua
index e304bbd1..61ad89ee 100644
--- a/script/vm/tracer.lua
+++ b/script/vm/tracer.lua
@@ -102,21 +102,80 @@ function mt:getLastAssign(block, pos)
return assign
end
+---@param filter parser.object
+---@param node vm.node?
+---@return vm.node
+function mt:narrowByFilter(filter, node)
+ if not node then
+ node = vm.createNode()
+ end
+ if filter.type == 'filter' then
+ node = self:narrowByFilter(filter.exp, node)
+ return node
+ end
+ if filter.type == 'getlocal' then
+ if filter.node == self.source then
+ node = node:copy()
+ node:removeOptional()
+ end
+ return node
+ end
+ return node
+end
+
---@param source parser.object
---@return vm.node?
function mt:calcNode(source)
+ if source.type == 'getlocal' then
+ return nil
+ end
+ if source.type == 'local' then
+ if source ~= self.source then
+ return nil
+ end
+ end
+ if source.type == 'setlocal' then
+ if source.node ~= self.source then
+ return nil
+ end
+ end
if guide.isSet(source) then
local node = vm.compileNode(source)
return node
end
if source.type == 'do' then
local lastAssign = self:getLastAssign(source, source.finish)
- if lastAssign then
- return self:getNode(lastAssign)
- else
- return nil
+ return self:getNode(lastAssign or source.parent)
+ end
+ if source.type == 'ifblock' then
+ local currentNode = self:getNode(source.parent)
+ local narrowedNode = self:narrowByFilter(source.filter, currentNode)
+ return narrowedNode
+ end
+ if source.type == 'filter' then
+ local parent = source.parent
+ ---@type parser.object
+ local outBlock
+ if parent.type == 'ifblock' then
+ outBlock = parent.parent.parent
+ local lastAssign = self:getLastAssign(outBlock, parent.start)
+ return self:getNode(lastAssign or source.parent)
+ elseif parent.type == 'elseifblock' then
+ outBlock = parent.parent.parent
+ local lastAssign = self:getLastAssign(outBlock, parent.start)
+ return self:getNode(lastAssign or source.parent)
+ elseif parent.type == 'while' then
+ outBlock = parent.parent
+ local lastAssign = self:getLastAssign(outBlock, parent.start)
+ return self:getNode(lastAssign or source.parent)
+ elseif parent.type == 'repeat' then
+ outBlock = parent.parent
+ local lastAssign = self:getLastAssign(outBlock, parent.start)
+ return self:getNode(lastAssign or source.parent)
end
+ assert(outBlock, parent.type)
end
+ return nil
end
---@param source parser.object
@@ -140,7 +199,7 @@ function mt:getNode(source)
return node
end
local lastAssign = self:getLastAssign(parentBlock, source.start)
- local parentNode = self:getNode(lastAssign or parentBlock)
+ local parentNode = self:getNode(lastAssign or source.parent)
self.nodes[source] = parentNode or false
return parentNode
end
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index cabff606..eb2d2e5c 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -1829,6 +1829,17 @@ end
print(<?x?>)
]]
+TEST 'nil' [[
+---@type integer?
+local x
+
+if not x then
+ print(<?x?>)
+end
+
+print(x)
+]]
+
TEST 'integer' [[
---@type integer?
local x