summaryrefslogtreecommitdiff
path: root/script/vm
diff options
context:
space:
mode:
author最萌小汐 <sumneko@hotmail.com>2023-11-14 16:49:27 +0800
committer最萌小汐 <sumneko@hotmail.com>2023-11-14 16:49:27 +0800
commit8c1669c85594e245eadfd1dbf1e537349640f2fb (patch)
tree9bacc4f2b1a57e71a21596deecdf58a355676026 /script/vm
parentee590a4cd1bc972ffe19e232b176aa1ffaba2d47 (diff)
downloadlua-language-server-8c1669c85594e245eadfd1dbf1e537349640f2fb.zip
fix type infer in overload
Diffstat (limited to 'script/vm')
-rw-r--r--script/vm/compiler.lua87
-rw-r--r--script/vm/sign.lua2
-rw-r--r--script/vm/type.lua22
3 files changed, 75 insertions, 36 deletions
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index 7e026474..8a1fa96a 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -882,52 +882,69 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex)
end
end
- for n in callNode:eachObject() do
- if n.type == 'function' then
- ---@cast n parser.object
- local sign = vm.getSign(n)
+ ---@param n parser.object
+ local function dealDocFunc(n)
+ local myEvent
+ if n.args[eventIndex] then
+ local argNode = vm.compileNode(n.args[eventIndex])
+ myEvent = argNode:get(1)
+ end
+ if not myEvent
+ or not eventMap
+ or myIndex <= eventIndex
+ or myEvent.type ~= 'doc.type.string'
+ or eventMap[myEvent[1]] then
local farg = getFuncArg(n, myIndex)
if farg then
for fn in vm.compileNode(farg):eachObject() do
if isValidCallArgNode(arg, fn) then
- if fn.type == 'doc.type.function' then
- ---@cast fn parser.object
- if sign then
- local generic = vm.createGeneric(fn, sign)
- local args = {}
- for i = fixIndex + 1, myIndex - 1 do
- args[#args+1] = call.args[i]
- end
- local resolvedNode = generic:resolve(guide.getUri(call), args)
- vm.setNode(arg, resolvedNode)
- goto CONTINUE
- end
- end
vm.setNode(arg, fn)
- ::CONTINUE::
end
end
end
end
- if n.type == 'doc.type.function' then
- ---@cast n parser.object
- local myEvent
- if n.args[eventIndex] then
- local argNode = vm.compileNode(n.args[eventIndex])
- myEvent = argNode:get(1)
- end
- if not myEvent
- or not eventMap
- or myIndex <= eventIndex
- or myEvent.type ~= 'doc.type.string'
- or eventMap[myEvent[1]] then
- local farg = getFuncArg(n, myIndex)
- if farg then
- for fn in vm.compileNode(farg):eachObject() do
- if isValidCallArgNode(arg, fn) then
- vm.setNode(arg, fn)
+ end
+
+ ---@param n parser.object
+ local function dealFunction(n)
+ local sign = vm.getSign(n)
+ local farg = getFuncArg(n, myIndex)
+ if farg then
+ for fn in vm.compileNode(farg):eachObject() do
+ if isValidCallArgNode(arg, fn) then
+ if fn.type == 'doc.type.function' then
+ ---@cast fn parser.object
+ if sign then
+ local generic = vm.createGeneric(fn, sign)
+ local args = {}
+ for i = fixIndex + 1, myIndex - 1 do
+ args[#args+1] = call.args[i]
+ end
+ local resolvedNode = generic:resolve(guide.getUri(call), args)
+ vm.setNode(arg, resolvedNode)
+ goto CONTINUE
end
end
+ vm.setNode(arg, fn)
+ ::CONTINUE::
+ end
+ end
+ end
+ end
+
+ for n in callNode:eachObject() do
+ if n.type == 'function' then
+ ---@cast n parser.object
+ dealFunction(n)
+ elseif n.type == 'doc.type.function' then
+ ---@cast n parser.object
+ dealDocFunc(n)
+ elseif n.type == 'global' and n.cate == 'type' then
+ ---@cast n vm.global
+ local overloads = vm.getOverloadsByTypeName(n.name, guide.getUri(arg))
+ if overloads then
+ for _, func in ipairs(overloads) do
+ dealDocFunc(func)
end
end
end
diff --git a/script/vm/sign.lua b/script/vm/sign.lua
index 1f434475..3cd6bc5d 100644
--- a/script/vm/sign.lua
+++ b/script/vm/sign.lua
@@ -254,7 +254,7 @@ function mt:resolve(uri, args)
local argNode = vm.compileNode(arg)
local knownTypes, genericNames = getSignInfo(sign)
if not isAllResolved(genericNames) then
- local newArgNode = buildArgNode(argNode,sign, knownTypes)
+ local newArgNode = buildArgNode(argNode, sign, knownTypes)
resolve(sign, newArgNode)
end
end
diff --git a/script/vm/type.lua b/script/vm/type.lua
index 910d7960..545d2de5 100644
--- a/script/vm/type.lua
+++ b/script/vm/type.lua
@@ -767,3 +767,25 @@ function vm.viewTypeErrorMessage(uri, errs)
return table.concat(lines, '\n')
end
end
+
+---@param name string
+---@param uri uri
+---@return parser.object[]?
+function vm.getOverloadsByTypeName(name, uri)
+ local global = vm.getGlobal('type', name)
+ if not global then
+ return nil
+ end
+ local results
+ for _, set in ipairs(global:getSets(uri)) do
+ for _, doc in ipairs(set.bindGroup) do
+ if doc.type == 'doc.overload' then
+ if not results then
+ results = {}
+ end
+ results[#results+1] = doc.overload
+ end
+ end
+ end
+ return results
+end