diff --git a/augur/api/routes/user.py b/augur/api/routes/user.py index 51220fcf12..dfaeb81f7f 100644 --- a/augur/api/routes/user.py +++ b/augur/api/routes/user.py @@ -2,28 +2,19 @@ """ Creates routes for user functionality """ +from augur.api.routes import AUGUR_API_VERSION import logging -import requests -import os -import base64 -import time import secrets -import pandas as pd -from flask import request, Response, jsonify, session +from flask import request, jsonify, session from flask_login import login_user, logout_user, current_user, login_required from werkzeug.security import check_password_hash -from sqlalchemy.sql import text from sqlalchemy.orm import sessionmaker -from sqlalchemy.orm.exc import NoResultFound from augur.application.db.session import DatabaseSession -from augur.tasks.github.util.github_task_session import GithubTaskSession -from augur.util.repo_load_controller import RepoLoadController from augur.api.util import api_key_required from augur.api.util import ssl_required -from augur.application.db.models import User, UserRepo, UserGroup, UserSessionToken, ClientApplication, RefreshToken -from augur.application.config import get_development_flag +from augur.application.db.models import User, UserSessionToken, RefreshToken from augur.tasks.init.redis_connection import redis_connection as redis from ..server import app, engine @@ -31,9 +22,6 @@ current_user: User = current_user Session = sessionmaker(bind=engine) -from augur.api.routes import AUGUR_API_VERSION - - @app.route(f"/{AUGUR_API_VERSION}/user/validate", methods=['POST']) @ssl_required def validate_user(): @@ -51,7 +39,7 @@ def validate_user(): return jsonify({"status": "Invalid username"}) checkPassword = check_password_hash(user.login_hashword, password) - if checkPassword == False: + if not checkPassword: return jsonify({"status": "Invalid password"}) @@ -89,9 +77,9 @@ def generate_session(application): code = request.args.get("code") or request.form.get("code") if not code: return jsonify({"status": "Missing argument: code"}), 400 - + grant_type = request.args.get("grant_type") or request.form.get("grant_type") - + if "code" not in grant_type: return jsonify({"status": "Invalid grant type"}) @@ -131,7 +119,7 @@ def refresh_session(application): if not refresh_token_str: return jsonify({"status": "Missing argument: refresh_token"}), 400 - + if request.args.get("grant_type") != "refresh_token": return jsonify({"status": "Invalid grant type"}) @@ -139,17 +127,17 @@ def refresh_session(application): refresh_token = session.query(RefreshToken).filter(RefreshToken.id == refresh_token_str).first() if not refresh_token: - return jsonify({"status": "Invalid refresh token"}) + return jsonify({"status": "Invalid refresh token"}), 400 if refresh_token.user_session.application != application: - return jsonify({"status": "Invalid application"}) + return jsonify({"status": "Invalid application"}), 400 user_session = refresh_token.user_session user = user_session.user new_user_session_token = UserSessionToken.create(session, user.user_id, user_session.application.id).token new_refresh_token_id = RefreshToken.create(session, new_user_session_token).id - + session.delete(refresh_token) session.delete(user_session) session.commit() @@ -327,11 +315,11 @@ def group_repos(): result_dict = result[1] if result[0] is not None: - + for repo in result[0]: repo["base64_url"] = str(repo["base64_url"].decode()) - result_dict.update({"repos": result[0]}) + result_dict.update({"repos": result[0]}) return jsonify(result_dict) @@ -436,7 +424,7 @@ def toggle_user_group_favorite(): Returns ------- dict - A dictionairy with key of 'status' that indicates the success or failure of the operation + A dictionary with key of 'status' that indicates the success or failure of the operation """ group_name = request.args.get("group_name") diff --git a/augur/application/db/models/augur_operations.py b/augur/application/db/models/augur_operations.py index e2100196dc..a2e3a6c4d8 100644 --- a/augur/application/db/models/augur_operations.py +++ b/augur/application/db/models/augur_operations.py @@ -1,24 +1,20 @@ -# coding: utf-8 +# encoding: utf-8 from sqlalchemy import BigInteger, SmallInteger, Column, Index, Integer, String, Table, text, UniqueConstraint, Boolean, ForeignKey, update, CheckConstraint -from sqlalchemy.dialects.postgresql import TIMESTAMP, UUID +from sqlalchemy.dialects.postgresql import TIMESTAMP from sqlalchemy.orm.exc import NoResultFound, MultipleResultsFound from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import relationship -from sqlalchemy.sql import text as sql_text from werkzeug.security import generate_password_hash, check_password_hash -from typing import List, Any, Dict +from typing import List -import logging +import logging import secrets import traceback -import importlib from augur.application.db.models import Repo, RepoGroup from augur.application.db.session import DatabaseSession from augur.application.db.models.base import Base - - FRONTEND_REPO_GROUP_NAME = "Frontend Repos" logger = logging.getLogger(__name__) @@ -39,15 +35,15 @@ def retrieve_owner_repos(session, owner: str) -> List[str]: OWNER_INFO_ENDPOINT = f"https://api.github.com/users/{owner}" ORG_REPOS_ENDPOINT = f"https://api.github.com/orgs/{owner}/repos?per_page=100" USER_REPOS_ENDPOINT = f"https://api.github.com/users/{owner}/repos?per_page=100" - + if not session.oauths.list_of_keys: return None, {"status": "No valid github api keys to retrieve data with"} - + # determine whether the owner is a user or an organization data, _ = retrieve_dict_from_endpoint(logger, session.oauths, OWNER_INFO_ENDPOINT) if not data: return None, {"status": "Invalid owner"} - + owner_type = data["type"] @@ -57,11 +53,11 @@ def retrieve_owner_repos(session, owner: str) -> List[str]: url = ORG_REPOS_ENDPOINT else: return None, {"status": f"Invalid owner type: {owner_type}"} - - + + # collect repo urls for the given owner repos = [] - for page_data, page in GithubPaginator(url, session.oauths, logger).iter_pages(): + for page_data in GithubPaginator(url, session.oauths, logger).iter_pages(): if page_data is None: break @@ -72,7 +68,7 @@ def retrieve_owner_repos(session, owner: str) -> List[str]: return repo_urls, {"status": "success", "owner_type": owner_type} - + metadata = Base.metadata t_all = Table( @@ -266,7 +262,7 @@ class User(Base): tool_version = Column(String) data_source = Column(String) data_collection_date = Column(TIMESTAMP(precision=0), server_default=text("CURRENT_TIMESTAMP")) - + __tablename__ = 'users' __table_args__ = ( UniqueConstraint('email', name='user-unique-email'), @@ -333,8 +329,8 @@ def get_user(session, username: str): return user except NoResultFound: return None - - @staticmethod + + @staticmethod def get_by_id(session, user_id: int): if not isinstance(user_id, int): @@ -344,12 +340,12 @@ def get_by_id(session, user_id: int): return user except NoResultFound: return None - + @staticmethod def create_user(username: str, password: str, email: str, first_name:str, last_name:str, admin=False): if username is None or password is None or email is None or first_name is None or last_name is None: - return False, {"status": "Missing field"} + return False, {"status": "Missing field"} with DatabaseSession(logger) as session: @@ -371,7 +367,7 @@ def create_user(username: str, password: str, email: str, first_name:str, last_n return False, {"status": "Failed to add default group for the user"} return True, {"status": "Account successfully created"} - except AssertionError as exception_message: + except AssertionError as exception_message: return False, {"Error": f"{exception_message}."} def delete(self, session): @@ -410,7 +406,7 @@ def update_email(self, session, new_email): if not new_email: print("Need new email to update the email") return False, {"status": "Missing argument"} - + existing_user = session.query(User).filter(User.email == new_email).first() if existing_user is not None: @@ -454,7 +450,7 @@ def remove_group(self, group_name): return result def add_repo(self, group_name, repo_url): - + from augur.tasks.github.util.github_task_session import GithubTaskSession from augur.tasks.github.util.github_api_key_handler import NoValidKeysError try: @@ -497,20 +493,20 @@ def get_group_names(self, search=None, reversed=False): group_names = [group.name for group in user_groups] else: group_names = [group.name for group in user_groups if search.lower() in group.name.lower()] - + group_names.sort(reverse = reversed) return group_names, {"status": "success"} - + def get_groups_info(self, search=None, reversed=False, sort="group_name"): (groups, result) = self.get_groups() if search is not None: groups = [group for group in groups if search.lower() in group.name.lower()] - + for group in groups: group.count = self.get_group_repo_count(group.name)[0] - + def sorting_function(group): if sort == "group_name": return group.name @@ -518,7 +514,7 @@ def sorting_function(group): return group.count elif sort == "favorited": return group.favorited - + groups = sorted(groups, key=sorting_function, reverse=reversed) return groups, {"status": "success"} @@ -612,7 +608,7 @@ def get_favorite_groups(self, session): return None, {"status": "Error when trying to get favorite groups"} return groups, {"status": "Success"} - + @staticmethod def compute_hashsed_password(password): return generate_password_hash(password, method='pbkdf2:sha512', salt_length=32) @@ -621,7 +617,7 @@ def compute_hashsed_password(password): class UserGroup(Base): group_id = Column(BigInteger, primary_key=True) - user_id = Column(Integer, + user_id = Column(Integer, ForeignKey("augur_operations.users.user_id", name="user_group_user_id_fkey") ) name = Column(String, nullable=False) @@ -797,7 +793,7 @@ def add(session, url: List[str], user_id: int, group_name=None, group_id=None, f if not group_name and not group_id: return False, {"status": "Need group name or group id to add a repo"} - + if from_org_list and not repo_type: return False, {"status": "Repo type must be passed if the repo is from an organization's list of repos"} @@ -806,21 +802,21 @@ def add(session, url: List[str], user_id: int, group_name=None, group_id=None, f group_id = UserGroup.convert_group_name_to_id(session, user_id, group_name) if group_id is None: return False, {"status": "Invalid group name"} - + if not from_org_list: result = Repo.is_valid_github_repo(session, url) if not result[0]: return False, {"status": result[1]["status"], "repo_url": url} - + repo_type = result[1]["repo_type"] - + # if no repo_group_id is passed then assign the repo to the frontend repo group if repo_group_id is None: frontend_repo_group = session.query(RepoGroup).filter(RepoGroup.rg_name == FRONTEND_REPO_GROUP_NAME).first() if not frontend_repo_group: - return False, {"status": "Could not find repo group with name 'Frontend Repos'", "repo_url": url} - + return False, {"status": "Could not find repo group with name 'Frontend Repos'", "repo_url": url} + repo_group_id = frontend_repo_group.repo_group_id @@ -877,7 +873,7 @@ def add_org_repos(session, url: List[str], user_id: int, group_name: int): group_id = UserGroup.convert_group_name_to_id(session, user_id, group_name) if group_id is None: return False, {"status": "Invalid group name"} - + # parse github owner url to get owner name owner = Repo.parse_github_org_url(url) if not owner: @@ -888,10 +884,10 @@ def add_org_repos(session, url: List[str], user_id: int, group_name: int): # if the result is returns None or [] if not result[0]: return False, result[1] - + repos = result[0] type = result[1]["owner_type"] - + # get repo group if it exists try: repo_group = RepoGroup.get_by_name(session, owner) @@ -921,7 +917,7 @@ def add_org_repos(session, url: List[str], user_id: int, group_name: int): if not result[0]: failed_repos.append(repo) - # Update repo group id to new org's repo group id if the repo + # Update repo group id to new org's repo group id if the repo # is a part of the org and existed before org added update_stmt = ( update(Repo) @@ -931,7 +927,7 @@ def add_org_repos(session, url: List[str], user_id: int, group_name: int): ) session.execute(update_stmt) session.commit() - + failed_count = len(failed_repos) if failed_count > 0: # this should never happen because an org should never return invalid repos @@ -959,11 +955,11 @@ class UserSessionToken(Base): @staticmethod def create(session, user_id, application_id, seconds_to_expire=86400): - import time + import time user_session_token = secrets.token_hex() expiration = int(time.time()) + seconds_to_expire - + user_session = UserSessionToken(token=user_session_token, user_id=user_id, application_id = application_id, expiration=expiration) session.add(user_session) @@ -999,12 +995,13 @@ class ClientApplication(Base): sessions = relationship("UserSessionToken") subscriptions = relationship("Subscription") + def __eq__(self, other): + return isinstance(other, ClientApplication) and str(self.id) == str(other.id) + @staticmethod def get_by_id(session, client_id): - return session.query(ClientApplication).filter(ClientApplication.id == client_id).first() - class Subscription(Base): __tablename__ = "subscriptions" __table_args__ = ( @@ -1150,7 +1147,7 @@ class CollectionStatus(Base): issue_pr_sum = Column(BigInteger) commit_sum = Column(BigInteger) - + repo = relationship("Repo", back_populates="collection_status") @staticmethod @@ -1173,7 +1170,7 @@ def insert(session, repo_id): session.logger.error( ''.join(traceback.format_exception(None, e, e.__traceback__))) - + record = { "repo_id": repo_id, "issue_pr_sum": pr_issue_count,