summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/vm/compiler.lua55
-rw-r--r--test/type_inference/init.lua12
2 files changed, 44 insertions, 23 deletions
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index f16cadd8..446c357e 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -836,10 +836,13 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex)
for i = fixIndex + 1, myIndex - 1 do
args[#args+1] = call.args[i]
end
- fn = generic:resolve(guide.getUri(call), args)
+ local resolvedNode = generic:resolve(guide.getUri(call), args)
+ vm.setNode(arg, resolvedNode)
+ goto CONTINUE
end
end
vm.setNode(arg, fn)
+ ::CONTINUE::
end
end
end
@@ -907,9 +910,10 @@ end
---@param source parser.object
---@param target parser.object
+---@return boolean
local function compileForVars(source, target)
if not source.exps then
- return
+ return false
end
-- for k, v in pairs(t) do
--> for k, v in iterator, status, initValue do
@@ -940,9 +944,11 @@ local function compileForVars(source, target)
local node = getReturn(source._iterator, i, source._iterArgs)
node:removeOptional()
vm.setNode(loc, node)
+ return true
end
end
end
+ return false
end
---@param source parser.object
@@ -972,27 +978,6 @@ local function compileLocal(source)
vm.setNode(source, vm.compileNode(source.value))
end
end
- if not hasMarkDoc
- and not hasMarkValue
- and source.ref then
- local firstSet
- local myFunction = guide.getParentFunction(source)
- for _, ref in ipairs(source.ref) do
- if ref.type == 'setlocal' then
- firstSet = ref
- break
- end
- if ref.type == 'getlocal' then
- if guide.getParentFunction(ref) == myFunction then
- break
- end
- end
- end
- if firstSet
- and guide.getBlock(firstSet) == guide.getBlock(source) then
- vm.setNode(source, vm.compileNode(firstSet))
- end
- end
-- function x.y(self, ...) --> function x:y(...)
if source[1] == 'self'
and not hasMarkDoc
@@ -1031,6 +1016,7 @@ local function compileLocal(source)
-- for x in ... do
if source.parent.type == 'in' then
compileForVars(source.parent, source)
+ hasMarkDoc = true
end
-- for x = ... do
@@ -1040,6 +1026,29 @@ local function compileLocal(source)
return
end
vm.setNode(source, vm.declareGlobal('type', 'integer'))
+ hasMarkDoc = true
+ end
+ end
+
+ if not hasMarkDoc
+ and not hasMarkValue
+ and source.ref then
+ local firstSet
+ local myFunction = guide.getParentFunction(source)
+ for _, ref in ipairs(source.ref) do
+ if ref.type == 'setlocal' then
+ firstSet = ref
+ break
+ end
+ if ref.type == 'getlocal' then
+ if guide.getParentFunction(ref) == myFunction then
+ break
+ end
+ end
+ end
+ if firstSet
+ and guide.getBlock(firstSet) == guide.getBlock(source) then
+ vm.setNode(source, vm.compileNode(firstSet))
end
end
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index 4479e031..0b69a34c 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -4136,3 +4136,15 @@ end
print(<?x?>)
]]
+
+TEST 'number' [[
+---@type number[]
+local xs
+
+---@type fun(x): number?
+local f
+
+for _, <?x?> in ipairs(xs) do
+ x = f(x)
+end
+]]