Skip to content

Commit

Permalink
Implement correct environment.yaml creation
Browse files Browse the repository at this point in the history
  • Loading branch information
basnijholt committed Nov 22, 2023
1 parent 4341030 commit c135cb6
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 21 deletions.
63 changes: 51 additions & 12 deletions conda_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,14 @@ def scan_dir(path: Path, current_depth: int) -> None:
return requirements_files


def filter_platform_selectors(content: str) -> list[str]:
def filter_platform_selectors(content: str) -> list[Platforms]:
"""Filter out lines from a requirements file that don't match the platform."""
# we support a very limited set of selectors that adhere to platform only
# refs:
# https://docs.conda.io/projects/conda-build/en/latest/resources/define-metadata.html#preprocessing-selectors
# https://github.com/conda/conda-lock/blob/3d2bf356e2cf3f7284407423f7032189677ba9be/conda_lock/src_parser/selectors.py

platform_sel = {
platform_sel: dict[Platforms, set[str]] = {
"linux-64": {"linux64", "unix", "linux"},
"linux-aarch64": {"aarch64", "unix", "linux"},
"linux-ppc64le": {"ppc64le", "unix", "linux"},
Expand All @@ -94,7 +94,7 @@ def filter_platform_selectors(content: str) -> list[str]:
}

# Reverse the platform_sel for easy lookup
reverse_platform_sel: dict[str, list[str]] = {}
reverse_platform_sel: dict[str, list[Platforms]] = {}
for key, values in platform_sel.items():
for value in values:
reverse_platform_sel.setdefault(value, []).append(key)
Expand Down Expand Up @@ -232,6 +232,45 @@ def _parse_requirements_and_filter_duplicates(
return _filter_pip_and_conda(requirements_with_comments, pip_or_conda, platform)


class EnvSpec(NamedTuple):
"""A conda environment."""

channels: list[str]
conda: list[str | dict[str, str]]
pip: list[str]


def _prepare_for_conda_environment(
requirements_with_comments: RequirementsWithComments,
) -> EnvSpec:
r = requirements_with_comments
conda: list[str | dict[str, str]] = []
pip: list[str] = []
for dependency, comment in r.conda.items():
platforms = filter_platform_selectors(comment) if comment is not None else []
if platforms:
unique_platforms = {p.split("-", 1)[0] for p in platforms}
dependencies = [
{f"sel({_platform})": dependency} for _platform in unique_platforms
]
conda.extend(dependencies)
else:
conda.append(dependency)

for dependency, comment in r.pip.items():
platforms = filter_platform_selectors(comment) if comment is not None else []
if platforms:
for _platform in platforms:
selector = pep508_selector([_platform])
dep = f"{dependency}; {selector}"
pip.append(dep)
else:
pip.append(dependency)
# Filter out duplicate packages that are both in conda and pip
pip = [p for p in pip if p not in conda]
return EnvSpec(list(r.channels), conda, pip)


def _to_requirements(
combined_deps: RequirementsWithComments,
) -> Requirements:
Expand All @@ -252,7 +291,7 @@ def _to_requirements(
return Requirements(channels, conda, pip)


def parse_requirements(
def parse_requirements_and_filter_duplicates(
paths: Sequence[Path],
*,
verbose: bool = False,
Expand All @@ -270,19 +309,19 @@ def parse_requirements(


def generate_conda_env_file(
dependencies: Requirements, # actually a CommentedMap with CommentedSeq
env_spec: EnvSpec,
output_file: str | None = "environment.yaml",
name: str = "myenv",
*,
verbose: bool = False,
) -> None:
"""Generate a conda environment.yaml file or print to stdout."""
_dependencies = deepcopy(dependencies.conda)
_dependencies.append({"pip": dependencies.pip}) # type: ignore[arg-type]
_dependencies = deepcopy(env_spec.conda)
_dependencies.append({"pip": env_spec.pip}) # type: ignore[arg-type, dict-item]
env_data = CommentedMap(
{
"name": name,
"channels": dependencies.channels,
"channels": env_spec.channels,
"dependencies": _dependencies,
},
)
Expand Down Expand Up @@ -314,7 +353,7 @@ def extract_python_requires(
msg = f"File {filename} not found."
raise FileNotFoundError(msg)
return []
deps = parse_requirements(
deps = parse_requirements_and_filter_duplicates(
[p],
pip_or_conda="pip",
verbose=verbose,
Expand Down Expand Up @@ -429,10 +468,10 @@ def main() -> None: # pragma: no cover
args.depth,
verbose=verbose,
)
combined_deps = parse_requirements(requirements_files, verbose=verbose)

combined_deps = _initial_parse_requirements(requirements_files, verbose=verbose)
env_spec = _prepare_for_conda_environment(combined_deps)
output_file = None if args.stdout else args.output
generate_conda_env_file(combined_deps, output_file, args.name, verbose=verbose)
generate_conda_env_file(env_spec, output_file, args.name, verbose=verbose)
if output_file:
with open(output_file, "r+") as f: # noqa: PTH123
content = f.read()
Expand Down
26 changes: 17 additions & 9 deletions tests/test_conda_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,19 @@
import yaml

from conda_join import (
Requirements,
EnvSpec,
RequirementsWithComments,
_filter_pip_and_conda,
_filter_unsupported_platforms,
_initial_parse_requirements,
_parse_requirements_and_filter_duplicates,
_prepare_for_conda_environment,
_to_requirements,
detect_platform,
extract_python_requires,
filter_platform_selectors,
generate_conda_env_file,
parse_requirements,
parse_requirements_and_filter_duplicates,
pep508_selector,
scan_requirements,
)
Expand Down Expand Up @@ -87,7 +89,10 @@ def test_parse_requirements(
verbose: bool, # noqa: FBT001
setup_test_files: tuple[Path, Path],
) -> None:
combined_deps = parse_requirements(setup_test_files, verbose=verbose)
combined_deps = parse_requirements_and_filter_duplicates(
setup_test_files,
verbose=verbose,
)
assert "numpy" in combined_deps.conda
assert "mumps" in combined_deps.conda
assert len(combined_deps.conda) == 2 # noqa: PLR2004
Expand All @@ -102,8 +107,10 @@ def test_generate_conda_env_file(
setup_test_files: tuple[Path, Path],
) -> None:
output_file = tmp_path / "environment.yaml"
combined_deps = parse_requirements(setup_test_files, verbose=verbose)
generate_conda_env_file(combined_deps, str(output_file), verbose=verbose)
combined_deps = _initial_parse_requirements(setup_test_files, verbose=verbose)
env_spec = _prepare_for_conda_environment(combined_deps)

generate_conda_env_file(env_spec, str(output_file), verbose=verbose)

with output_file.open() as f:
env_data = yaml.safe_load(f)
Expand All @@ -116,8 +123,9 @@ def test_generate_conda_env_stdout(
setup_test_files: tuple[Path, Path],
capsys: pytest.CaptureFixture,
) -> None:
combined_deps = parse_requirements(setup_test_files, verbose=False)
generate_conda_env_file(combined_deps, None)
combined_deps = _initial_parse_requirements(setup_test_files)
env_spec = _prepare_for_conda_environment(combined_deps)
generate_conda_env_file(env_spec, None)

captured = capsys.readouterr()
assert "dependencies" in captured.out
Expand All @@ -135,13 +143,13 @@ def test_verbose_output(tmp_path: Path, capsys: pytest.CaptureFixture) -> None:
assert "Scanning in" in captured.out
assert str(tmp_path / "dir3") in captured.out

parse_requirements([f], verbose=True)
parse_requirements_and_filter_duplicates([f], verbose=True)
captured = capsys.readouterr()
assert "Parsing" in captured.out
assert str(f) in captured.out

generate_conda_env_file(
Requirements(channels=[], conda=[], pip=[]),
EnvSpec(channels=[], conda=[], pip=[]),
verbose=True,
)
captured = capsys.readouterr()
Expand Down

0 comments on commit c135cb6

Please sign in to comment.