Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@ type API struct {
logger logging.Logger
session *project.Session

projects map[Handle[project.Project]]tspath.Path
filesMu sync.Mutex
files handleMap[ast.SourceFile]
symbolsMu sync.Mutex
symbols handleMap[ast.Symbol]
typesMu sync.Mutex
types handleMap[checker.Type]
projectsMu sync.RWMutex
projects map[Handle[project.Project]]tspath.Path
filesMu sync.Mutex
files handleMap[ast.SourceFile]
symbolsMu sync.Mutex
symbols handleMap[ast.Symbol]
typesMu sync.Mutex
types handleMap[checker.Type]
}

func NewAPI(init *APIInit) *API {
Expand Down Expand Up @@ -145,12 +146,18 @@ func (api *API) LoadProject(ctx context.Context, configFileName string) (*Projec
return nil, err
}
data := NewProjectResponse(project)
// Acquire write lock to safely add project to the map
api.projectsMu.Lock()
api.projects[data.Id] = project.ConfigFilePath()
api.projectsMu.Unlock()
return data, nil
}

func (api *API) GetSymbolAtPosition(ctx context.Context, projectId Handle[project.Project], fileName string, position int) (*SymbolResponse, error) {
// Acquire read lock to safely access projects map
api.projectsMu.RLock()
projectPath, ok := api.projects[projectId]
api.projectsMu.RUnlock()
if !ok {
return nil, errors.New("project ID not found")
}
Expand All @@ -174,7 +181,10 @@ func (api *API) GetSymbolAtPosition(ctx context.Context, projectId Handle[projec
}

func (api *API) GetSymbolAtLocation(ctx context.Context, projectId Handle[project.Project], location Handle[ast.Node]) (*SymbolResponse, error) {
// Acquire read lock to safely access projects map
api.projectsMu.RLock()
projectPath, ok := api.projects[projectId]
api.projectsMu.RUnlock()
if !ok {
return nil, errors.New("project ID not found")
}
Expand Down Expand Up @@ -216,7 +226,10 @@ func (api *API) GetSymbolAtLocation(ctx context.Context, projectId Handle[projec
}

func (api *API) GetTypeOfSymbol(ctx context.Context, projectId Handle[project.Project], symbolHandle Handle[ast.Symbol]) (*TypeResponse, error) {
// Acquire read lock to safely access projects map
api.projectsMu.RLock()
projectPath, ok := api.projects[projectId]
api.projectsMu.RUnlock()
if !ok {
return nil, errors.New("project ID not found")
}
Expand All @@ -242,7 +255,10 @@ func (api *API) GetTypeOfSymbol(ctx context.Context, projectId Handle[project.Pr
}

func (api *API) GetSourceFile(projectId Handle[project.Project], fileName string) (*ast.SourceFile, error) {
// Acquire read lock to safely access projects map
api.projectsMu.RLock()
projectPath, ok := api.projects[projectId]
api.projectsMu.RUnlock()
if !ok {
return nil, errors.New("project ID not found")
}
Expand All @@ -267,11 +283,15 @@ func (api *API) releaseHandle(handle string) error {
switch handle[0] {
case handlePrefixProject:
projectId := Handle[project.Project](handle)
// Acquire write lock to safely delete project from the map
api.projectsMu.Lock()
_, ok := api.projects[projectId]
if !ok {
api.projectsMu.Unlock()
return fmt.Errorf("project %q not found", handle)
}
delete(api.projects, projectId)
api.projectsMu.Unlock()
case handlePrefixFile:
fileId := Handle[ast.SourceFile](handle)
api.filesMu.Lock()
Expand Down