package mock

import (
	"github.com/anchore/grype/grype/db/v6/name"
	grypePkg "github.com/anchore/grype/grype/pkg"
	"github.com/anchore/grype/grype/search"
	"github.com/anchore/grype/grype/vulnerability"
)

// VulnerabilityProvider builds a mock vulnerability.Provider that serves exactly the vulnerabilities passed in.
// The variadic slice is stored as given (it is not copied).
func VulnerabilityProvider(vulnerabilities ...vulnerability.Vulnerability) vulnerability.Provider {
	provider := mockProvider{Vulnerabilities: vulnerabilities}
	return &provider
}

// mockProvider is the in-memory implementation returned by VulnerabilityProvider.
type mockProvider struct {
	// Vulnerabilities is the set of records supplied at construction; it is what
	// VulnerabilityMetadata and FindVulnerabilities search through.
	Vulnerabilities []vulnerability.Vulnerability
}

// Close is a no-op for the mock; it always returns nil.
func (s *mockProvider) Close() error {
	return nil
}

// PackageSearchNames returns the names for p as computed by name.PackageNames.
// The provider's stored vulnerabilities are not consulted.
func (s *mockProvider) PackageSearchNames(p grypePkg.Package) []string {
	return name.PackageNames(p)
}

// VulnerabilityMetadata looks up metadata for ref among the stored vulnerabilities.
//
// It returns the metadata of the first vulnerability whose ID and Namespace equal those of ref and
// whose Internal field holds a vulnerability.Metadata, either by value or as a non-nil pointer.
// Before returning, the metadata's ID and Namespace are made to agree with the vulnerability's.
// When Internal holds a pointer, that same pointer is returned (and updated in place); when it
// holds a value, a pointer to a copy is returned. If no vulnerability qualifies, the result is
// nil with a nil error.
func (s *mockProvider) VulnerabilityMetadata(ref vulnerability.Reference) (*vulnerability.Metadata, error) {
	for _, candidate := range s.Vulnerabilities {
		if candidate.ID != ref.ID || candidate.Namespace != ref.Namespace {
			continue
		}

		var found *vulnerability.Metadata
		switch md := candidate.Internal.(type) {
		case vulnerability.Metadata:
			found = &md
		case *vulnerability.Metadata:
			found = md
		}
		if found == nil {
			// matching vulnerability without usable metadata: keep looking
			continue
		}

		// only write when the values actually differ
		if found.ID != candidate.ID {
			found.ID = candidate.ID
		}
		if found.Namespace != candidate.Namespace {
			found.Namespace = candidate.Namespace
		}
		return found, nil
	}
	return nil, nil
}

// FindVulnerabilities returns the stored vulnerabilities that satisfy the given criteria.
//
// The criteria are first checked with search.ValidateCriteria; a validation error is returned
// as-is. A vulnerability is kept only when every criterion in every row produced by
// search.CriteriaIterator reports a match (the second value returned by MatchesVulnerability is
// ignored). The first error reported by a criterion aborts the search and is returned with a
// nil slice.
func (s *mockProvider) FindVulnerabilities(criteria ...vulnerability.Criteria) ([]vulnerability.Vulnerability, error) {
	if err := search.ValidateCriteria(criteria); err != nil {
		return nil, err
	}

	// filterE edits the slice it is given, so hand it a copy rather than the provider's own slice
	candidates := append([]vulnerability.Vulnerability(nil), s.Vulnerabilities...)

	matchesAll := func(v vulnerability.Vulnerability) (bool, error) {
		for _, row := range search.CriteriaIterator(criteria) {
			for _, criterion := range row {
				ok, _, err := criterion.MatchesVulnerability(v)
				if err != nil {
					return false, err
				}
				if !ok {
					return false, nil
				}
			}
		}
		return true, nil
	}

	return filterE(candidates, matchesAll)
}

// filterE filters out in place, keeping the elements for which keep returns true, and returns
// the shortened slice. Relative order is preserved and keep is called once per element, in order.
//
// The returned slice shares out's backing array, so the caller's view of out must be treated as
// invalid afterwards. If keep returns an error, filtering stops immediately and filterE returns
// nil together with that error; out may already have been partially rewritten at that point.
func filterE[T any](out []T, keep func(v T) (bool, error)) ([]T, error) {
	// Compact survivors toward the front in a single pass. The write position never
	// overtakes the read position, so no unread element is overwritten.
	kept := out[:0]
	for _, v := range out {
		ok, err := keep(v)
		if err != nil {
			return nil, err
		}
		if ok {
			kept = append(kept, v)
		}
	}
	return kept, nil
}
