diff options
author | 最萌小汐 <sumneko@hotmail.com> | 2023-11-14 16:49:27 +0800 |
---|---|---|
committer | 最萌小汐 <sumneko@hotmail.com> | 2023-11-14 16:49:27 +0800 |
commit | 8c1669c85594e245eadfd1dbf1e537349640f2fb (patch) | |
tree | 9bacc4f2b1a57e71a21596deecdf58a355676026 /script | |
parent | ee590a4cd1bc972ffe19e232b176aa1ffaba2d47 (diff) | |
download | lua-language-server-8c1669c85594e245eadfd1dbf1e537349640f2fb.zip |
fix type infer in overload
Diffstat (limited to 'script')
-rw-r--r-- | script/core/completion/completion.lua | 2 | ||||
-rw-r--r-- | script/vm/compiler.lua | 87 | ||||
-rw-r--r-- | script/vm/sign.lua | 2 | ||||
-rw-r--r-- | script/vm/type.lua | 22 |
4 files changed, 76 insertions, 37 deletions
diff --git a/script/core/completion/completion.lua b/script/core/completion/completion.lua index 4462bf64..b7d4650c 100644 --- a/script/core/completion/completion.lua +++ b/script/core/completion/completion.lua @@ -1659,7 +1659,7 @@ local function tryCallArg(state, position, results) return end ---@diagnostic disable-next-line: missing-fields - local node = vm.compileCallArg({ type = 'dummyarg' }, call, argIndex) + local node = vm.compileCallArg({ type = 'dummyarg', uri = state.uri }, call, argIndex) if not node then return end diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 7e026474..8a1fa96a 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -882,52 +882,69 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex) end end - for n in callNode:eachObject() do - if n.type == 'function' then - ---@cast n parser.object - local sign = vm.getSign(n) + ---@param n parser.object + local function dealDocFunc(n) + local myEvent + if n.args[eventIndex] then + local argNode = vm.compileNode(n.args[eventIndex]) + myEvent = argNode:get(1) + end + if not myEvent + or not eventMap + or myIndex <= eventIndex + or myEvent.type ~= 'doc.type.string' + or eventMap[myEvent[1]] then local farg = getFuncArg(n, myIndex) if farg then for fn in vm.compileNode(farg):eachObject() do if isValidCallArgNode(arg, fn) then - if fn.type == 'doc.type.function' then - ---@cast fn parser.object - if sign then - local generic = vm.createGeneric(fn, sign) - local args = {} - for i = fixIndex + 1, myIndex - 1 do - args[#args+1] = call.args[i] - end - local resolvedNode = generic:resolve(guide.getUri(call), args) - vm.setNode(arg, resolvedNode) - goto CONTINUE - end - end vm.setNode(arg, fn) - ::CONTINUE:: end end end end - if n.type == 'doc.type.function' then - ---@cast n parser.object - local myEvent - if n.args[eventIndex] then - local argNode = vm.compileNode(n.args[eventIndex]) - myEvent = argNode:get(1) - end - if not myEvent - or not eventMap - or myIndex <= eventIndex - or myEvent.type ~= 'doc.type.string' - or eventMap[myEvent[1]] then - local farg = getFuncArg(n, myIndex) - if farg then - for fn in vm.compileNode(farg):eachObject() do - if isValidCallArgNode(arg, fn) then - vm.setNode(arg, fn) + end + + ---@param n parser.object + local function dealFunction(n) + local sign = vm.getSign(n) + local farg = getFuncArg(n, myIndex) + if farg then + for fn in vm.compileNode(farg):eachObject() do + if isValidCallArgNode(arg, fn) then + if fn.type == 'doc.type.function' then + ---@cast fn parser.object + if sign then + local generic = vm.createGeneric(fn, sign) + local args = {} + for i = fixIndex + 1, myIndex - 1 do + args[#args+1] = call.args[i] + end + local resolvedNode = generic:resolve(guide.getUri(call), args) + vm.setNode(arg, resolvedNode) + goto CONTINUE end end + vm.setNode(arg, fn) + ::CONTINUE:: + end + end + end + end + + for n in callNode:eachObject() do + if n.type == 'function' then + ---@cast n parser.object + dealFunction(n) + elseif n.type == 'doc.type.function' then + ---@cast n parser.object + dealDocFunc(n) + elseif n.type == 'global' and n.cate == 'type' then + ---@cast n vm.global + local overloads = vm.getOverloadsByTypeName(n.name, guide.getUri(arg)) + if overloads then + for _, func in ipairs(overloads) do + dealDocFunc(func) end end end diff --git a/script/vm/sign.lua b/script/vm/sign.lua index 1f434475..3cd6bc5d 100644 --- a/script/vm/sign.lua +++ b/script/vm/sign.lua @@ -254,7 +254,7 @@ function mt:resolve(uri, args) local argNode = vm.compileNode(arg) local knownTypes, genericNames = getSignInfo(sign) if not isAllResolved(genericNames) then - local newArgNode = buildArgNode(argNode,sign, knownTypes) + local newArgNode = buildArgNode(argNode, sign, knownTypes) resolve(sign, newArgNode) end end diff --git a/script/vm/type.lua b/script/vm/type.lua index 910d7960..545d2de5 100644 --- a/script/vm/type.lua +++ b/script/vm/type.lua @@ -767,3 +767,25 @@ function vm.viewTypeErrorMessage(uri, errs) return table.concat(lines, '\n') end end + +---@param name string +---@param uri uri +---@return parser.object[]? +function vm.getOverloadsByTypeName(name, uri) + local global = vm.getGlobal('type', name) + if not global then + return nil + end + local results + for _, set in ipairs(global:getSets(uri)) do + for _, doc in ipairs(set.bindGroup) do + if doc.type == 'doc.overload' then + if not results then + results = {} + end + results[#results+1] = doc.overload + end + end + end + return results +end |